{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TupleSections #-}

-- | Adds cost centres after the Core pipeline has run.
module GHC.Core.LateCC
    ( addLateCostCentresMG
    , addLateCostCentresPgm
    , addLateCostCentres -- Might be useful for API users
    , Env(..)
    ) where

import Control.Applicative
import GHC.Utils.Monad.State.Strict
import Control.Monad

import GHC.Prelude
import GHC.Driver.Session
import GHC.Types.CostCentre
import GHC.Types.CostCentre.State
import GHC.Types.Name hiding (varName)
import GHC.Types.Tickish
import GHC.Unit.Module.ModGuts
import GHC.Types.Var
import GHC.Unit.Types
import GHC.Data.FastString
import GHC.Core
import GHC.Core.Opt.Monad
import GHC.Types.Id
import GHC.Core.Utils (mkTick)

import qualified Data.Set as S
import GHC.Utils.Logger
import GHC.Utils.Outputable
import GHC.Utils.Misc
import GHC.Utils.Error (withTiming)


{- Note [Collecting late cost centres]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Usually cost centres defined by a module are collected
during tidy by collectCostCentres. However with `-fprof-late`
we insert cost centres after inlining. So we keep a list of
all the cost centres we inserted and combine that with the list
of cost centres found during tidy.

To avoid overhead when using -fprof-inline, the `collectCCs` field of
`Env` lets us skip collecting them here when we run this pass before tidy.

Note [Adding late cost centres]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The basic idea is very simple. For every top level binder
`f = rhs` we compile it as if the user had written
`f = {-# SCC f #-} rhs`.

If we do this after unfoldings for `f` have been created, this
doesn't impact core-level optimizations at all. If we do it
before, the cost centre will be included in the unfolding and
might inhibit optimizations at the call site. For this reason
we provide flags for both approaches, as they have different
tradeoffs.

We also don't add a cost centre for any binder that is a constructor
worker or wrapper. These will never meaningfully enrich the resulting
profile so we improve efficiency by omitting those.

-}

-- | Run the late cost-centre pass as a 'CoreM' pass over a module's bindings.
--
-- Every top-level binder of @mg_binds@ gets its RHS wrapped in a cost centre
-- attributed to @mg_module@. The inserted cost centres are /not/ collected
-- here (@collectCCs = False@); see Note [Collecting late cost centres].
addLateCostCentresMG :: ModGuts -> CoreM ModGuts
addLateCostCentresMG guts = do
  dflags <- getDynFlags
  let env :: Env
      env = Env
        { thisModule   = mg_module guts
        , ccState      = newCostCentreState
        , countEntries = gopt Opt_ProfCountEntries dflags
        , collectCCs   = False -- See Note [Collecting late cost centres]
        }
      -- The collected set is always empty here, so only the bindings matter.
      binds' = fst (addLateCostCentres env (mg_binds guts))
  return (guts { mg_binds = binds' })

-- | Run the late cost-centre pass over a whole program in 'IO', timing it
-- with 'withTiming' and dumping the result when @-ddump-late-cc@ or
-- @-dverbose-core2core@ is enabled.
--
-- Returns the annotated bindings together with the set of cost centres that
-- were inserted (@collectCCs = True@), so the caller can combine them with
-- the ones found during tidy; see Note [Collecting late cost centres].
addLateCostCentresPgm :: DynFlags -> Logger -> Module -> CoreProgram -> IO (CoreProgram, S.Set CostCentre)
addLateCostCentresPgm dflags logger mod binds =
  withTiming logger
             (text "LateCC" <+> brackets (ppr mod))
             -- Force the bindings list spine and the set so the timing
             -- covers the actual work.
             (\(a, b) -> a `seqList` (b `seq` ())) $ do
    let env = Env
          { thisModule   = mod
          , ccState      = newCostCentreState
          , countEntries = gopt Opt_ProfCountEntries dflags
          , collectCCs   = True -- See Note [Collecting late cost centres]
          }
        (binds', ccs) = addLateCostCentres env binds
    when (dopt Opt_D_dump_late_cc dflags || dopt Opt_D_verbose_core2core dflags) $
      putDumpFileMaybe logger Opt_D_dump_late_cc "LateCC" FormatCore
        (vcat (map ppr binds'))
    return (binds', ccs)

-- | Wrap the RHS of every top-level binder (other than data constructor
-- workers/wrappers) in a late cost centre; see Note [Adding late cost centres].
--
-- Returns the rewritten program and the set of inserted cost centres. The
-- set is only populated when 'collectCCs' is set in the given 'Env'.
addLateCostCentres :: Env -> CoreProgram -> (CoreProgram, S.Set CostCentre)
addLateCostCentres env binds =
  let (binds', state) = runState (mapM (doBind env) binds) initLateCCState
  in (binds', lcs_ccs state)


-- | Add late cost centres to every binder of a top-level binding group,
-- handling both non-recursive and recursive groups.
doBind :: Env -> CoreBind -> M CoreBind
doBind env (NonRec b rhs) = NonRec b <$> doBndr env b rhs
doBind env (Rec bs)       = Rec <$> mapM doPair bs
  where
    doPair :: (Id, CoreExpr) -> M (Id, CoreExpr)
    doPair (b, rhs) = (b,) <$> doBndr env b rhs

-- | Add a late cost centre to the RHS of a single binder. Binders that are
-- data constructor worker/wrapper ids are left unchanged.
doBndr :: Env -> Id -> CoreExpr -> M CoreExpr
doBndr env bndr rhs
  -- Cost centres on constructor workers are pretty much useless,
  -- so we don't emit them if we are looking at the rhs of a constructor
  -- binding. See Note [Adding late cost centres].
  | Just _ <- isDataConId_maybe bndr = pure rhs
  | otherwise                        = doBndr' env bndr rhs


-- | Wrap @rhs@ in a 'ProfNote' for a fresh 'LateCC' cost centre named after
-- @bndr@ (its occurrence name and source span, in 'thisModule'), and record
-- that cost centre with 'addCC'.
--
-- The tick is placed below any leading lambdas, as we only care about
-- executions of the RHS.
doBndr' :: Env -> Id -> CoreExpr -> M CoreExpr
doBndr' env bndr (Lam b body) = Lam b <$> doBndr' env bndr body
doBndr' env bndr rhs = do
    cc_flavour <- getCCFlavour cc_name
    let bndrCC = NormalCC cc_flavour cc_name (thisModule env) name_loc
        note   = ProfNote bndrCC (countEntries env) True
    addCC env bndrCC
    return (mkTick note rhs)
  where
    name     = idName bndr
    name_loc = nameSrcSpan name
    cc_name  = getOccFS name

-- | State threaded through the late cost-centre pass.
data LateCCState = LateCCState
    { lcs_state :: !CostCentreState
      -- ^ Passed to 'getCCIndex' to obtain per-name cost-centre indices.
    , lcs_ccs   :: !(S.Set CostCentre)
      -- ^ Cost centres inserted so far (only populated when 'collectCCs' is
      --   set). Strict so repeated inserts do not build a chain of thunks.
    }

-- | The monad the late cost-centre pass runs in.
type M = State LateCCState

-- | Initial pass state: a fresh 'CostCentreState' and no collected
-- cost centres.
initLateCCState :: LateCCState
initLateCCState = LateCCState
  { lcs_state = newCostCentreState
  , lcs_ccs   = mempty
  }

-- | Build a 'LateCC' cost-centre flavour for @name@, using an index obtained
-- from the pass state via 'getCCIndex''.
getCCFlavour :: FastString -> M CCFlavour
getCCFlavour name = LateCC <$> getCCIndex' name

-- | Obtain a cost-centre index for @name@ from the 'CostCentreState' held in
-- 'lcs_state', writing the updated 'CostCentreState' back into the pass state.
getCCIndex' :: FastString -> M CostCentreIndex
getCCIndex' name = do
  state <- get
  let (index, cc_state') = getCCIndex name (lcs_state state)
  put (state { lcs_state = cc_state' })
  return index

-- | Record a cost centre in 'lcs_ccs', but only when 'collectCCs' is set
-- (see Note [Collecting late cost centres]); otherwise this is a no-op.
addCC :: Env -> CostCentre -> M ()
addCC !env cc =
    when (collectCCs env) $ do
        state <- get
        -- Force the insert now; otherwise each call would wrap the previous
        -- set in another lazy 'S.insert' thunk until the set is demanded.
        let !ccs' = S.insert cc (lcs_ccs state)
        put (state { lcs_ccs = ccs' })

-- | Read-only configuration for the late cost-centre pass.
data Env = Env
  { thisModule   :: !Module
    -- ^ Module the inserted cost centres are attributed to.
  , countEntries :: !Bool
    -- ^ Passed as the entry-count flag of each inserted 'ProfNote'
    --   (set from 'Opt_ProfCountEntries' by the callers in this module).
  , ccState      :: !CostCentreState
    -- ^ NOTE(review): not read anywhere in this module; the pass draws
    --   indices from 'lcs_state' instead. Kept because 'Env' is exported.
  , collectCCs   :: !Bool
    -- ^ Whether to accumulate inserted cost centres in the result set
    --   (see Note [Collecting late cost centres]).
  }