module GHC.Core.Opt.Monad (
CoreToDo(..), runWhen, runMaybe,
SimplMode(..),
FloatOutSwitches(..),
pprPassDetails,
CorePluginPass, bindsOnlyPass,
SimplCount, doSimplTick, doFreeSimplTick, simplCountN,
pprSimplCount, plusSimplCount, zeroSimplCount,
isZeroSimplCount, hasDetailedCounts, Tick(..),
CoreM, runCoreM,
getHscEnv, getRuleBase, getModule,
getDynFlags, getPackageFamInstEnv,
getVisibleOrphanMods, getUniqMask,
getPrintUnqualified, getSrcSpanM,
addSimplCount,
liftIO, liftIOWithCount,
getAnnotations, getFirstAnnotations,
putMsg, putMsgS, errorMsg, errorMsgS, warnMsg,
fatalErrorMsg, fatalErrorMsgS,
debugTraceMsg, debugTraceMsgS,
dumpIfSet_dyn
) where
import GHC.Prelude hiding ( read )
import GHC.Core
import GHC.Driver.Types
import GHC.Unit.Module
import GHC.Driver.Session
import GHC.Types.Basic ( CompilerPhase(..) )
import GHC.Types.Annotations
import GHC.Data.IOEnv hiding ( liftIO, failM, failWithM )
import qualified GHC.Data.IOEnv as IOEnv
import GHC.Types.Var
import GHC.Utils.Outputable as Outputable
import GHC.Data.FastString
import GHC.Utils.Error( Severity(..), DumpFormat (..), dumpOptionsFromFlag )
import GHC.Types.Unique.Supply
import GHC.Utils.Monad
import GHC.Types.Name.Env
import GHC.Types.SrcLoc
import Data.Bifunctor ( bimap )
import GHC.Utils.Error (dumpAction)
import Data.List (intersperse, groupBy, sortBy)
import Data.Ord
import Data.Dynamic
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Strict as MapStrict
import Data.Word
import Control.Monad
import Control.Applicative ( Alternative(..) )
import GHC.Utils.Panic (throwGhcException, GhcException(..))
-- | One step of the Core-to-Core pipeline.  The 'Outputable' instance for
-- this type gives the human-readable name of each step.
data CoreToDo
  = CoreDoSimplify      -- Simplifier run
      Int               -- iteration limit (shown as "Max iterations" by 'pprPassDetails')
      SimplMode         -- mode record, also shown by 'pprPassDetails'
  | CoreDoPluginPass String CorePluginPass   -- display name and the pass function itself
  | CoreDoFloatInwards
  | CoreDoFloatOutwards FloatOutSwitches
  | CoreLiberateCase
  | CoreDoPrintCore
  | CoreDoStaticArgs
  | CoreDoCallArity
  | CoreDoExitify
  | CoreDoDemand
  | CoreDoCpr
  | CoreDoWorkerWrapper
  | CoreDoSpecialising
  | CoreDoSpecConstr
  | CoreCSE
  | CoreDoRuleCheck CompilerPhase String
  | CoreDoNothing                 -- placeholder produced by 'runWhen' / 'runMaybe' when a pass is disabled
  | CoreDoPasses [CoreToDo]       -- a nested list of steps
  | CoreDesugar
  | CoreDesugarOpt
  | CoreTidy
  | CorePrep
  | CoreOccurAnal
-- | Human-readable name of each pipeline step.
instance Outputable CoreToDo where
  ppr todo = case todo of
    CoreDoSimplify _ _         -> text "Simplifier"
    CoreDoPluginPass name _    -> text "Core plugin: " <+> text name
    CoreDoFloatInwards         -> text "Float inwards"
    CoreDoFloatOutwards sw     -> text "Float out" <> parens (ppr sw)
    CoreLiberateCase           -> text "Liberate case"
    CoreDoStaticArgs           -> text "Static argument"
    CoreDoCallArity            -> text "Called arity analysis"
    CoreDoExitify              -> text "Exitification transformation"
    CoreDoDemand               -> text "Demand analysis"
    CoreDoCpr                  -> text "Constructed Product Result analysis"
    CoreDoWorkerWrapper        -> text "Worker Wrapper binds"
    CoreDoSpecialising         -> text "Specialise"
    CoreDoSpecConstr           -> text "SpecConstr"
    CoreCSE                    -> text "Common sub-expression"
    CoreDesugar                -> text "Desugar (before optimization)"
    CoreDesugarOpt             -> text "Desugar (after optimization)"
    CoreTidy                   -> text "Tidy Core"
    CorePrep                   -> text "CorePrep"
    CoreOccurAnal              -> text "Occurrence analysis"
    CoreDoPrintCore            -> text "Print core"
    CoreDoRuleCheck {}         -> text "Rule check"
    CoreDoNothing              -> text "CoreDoNothing"
    CoreDoPasses sub_passes    -> text "CoreDoPasses" <+> ppr sub_passes
-- | Extra detail for a pass, beyond its name.  Only the simplifier has any:
-- its iteration limit followed by its 'SimplMode'.  Every other pass yields
-- an empty document.
pprPassDetails :: CoreToDo -> SDoc
pprPassDetails pass = case pass of
  CoreDoSimplify max_iters mode ->
    vcat [text "Max iterations =" <+> int max_iters, ppr mode]
  _ -> Outputable.empty
-- | Settings for one simplifier run.  The 'Outputable' instance prints the
-- phase, the names and the four Boolean switches; 'sm_dflags' is not printed.
data SimplMode
  = SimplMode
    { sm_names :: [String]         -- printed comma-separated in brackets after the phase
    , sm_phase :: CompilerPhase
    , sm_dflags :: DynFlags
    , sm_rules :: Bool             -- printed as "rules" / "no rules"
    , sm_inline :: Bool            -- printed as "inline" / "no inline"
    , sm_case_case :: Bool         -- printed as "case-of-case" / "no case-of-case"
    , sm_eta_expand :: Bool        -- printed as "eta-expand" / "no eta-expand"
    }
-- | Prints @SimplMode {Phase = p [names], inline, rules, eta-expand,
-- case-of-case}@, prefixing each switch that is off with @no@.
instance Outputable SimplMode where
  ppr mode = text "SimplMode" <+> braces (sep entries)
    where
      entries =
        [ text "Phase =" <+> ppr (sm_phase mode) <+> brackets name_list <> comma
        , switch (sm_inline mode)     (sLit "inline") <> comma
        , switch (sm_rules mode)      (sLit "rules") <> comma
        , switch (sm_eta_expand mode) (sLit "eta-expand") <> comma
        , switch (sm_case_case mode)  (sLit "case-of-case")
        ]

      -- The mode names joined by commas, without surrounding brackets.
      name_list = text (concat (intersperse "," (sm_names mode)))

      -- A switch label, preceded by "no" when the switch is disabled.
      switch enabled label = ppUnless enabled (text "no") <+> ptext label
-- | Switches carried by 'CoreDoFloatOutwards'.
-- NOTE(review): the precise meaning of each switch is defined by the
-- float-out pass, which is not visible in this module.
data FloatOutSwitches = FloatOutSwitches {
  floatOutLambdas :: Maybe Int,
  floatOutConstants :: Bool,
  floatOutOverSatApps :: Bool,
  floatToTopLevelOnly :: Bool
  }
-- | Delegates to 'pprFloatOutSwitches'.
instance Outputable FloatOutSwitches where
  ppr switches = pprFloatOutSwitches switches
-- | Render the float-out switches as @FOS {Lam = .., Consts = .., ..}@.
--
-- This text is what the 'Outputable' instance of 'FloatOutSwitches' shows,
-- and therefore also what appears in the name of a 'CoreDoFloatOutwards'
-- pass.  Every field of the record is printed; 'floatToTopLevelOnly' used to
-- be left out, which made two different switch settings print identically.
pprFloatOutSwitches :: FloatOutSwitches -> SDoc
pprFloatOutSwitches sw
  = text "FOS" <+> braces (sep (punctuate comma fields))
  where
    fields =
      [ text "Lam ="          <+> ppr (floatOutLambdas sw)
      , text "Consts ="       <+> ppr (floatOutConstants sw)
      , text "OverSatApps ="  <+> ppr (floatOutOverSatApps sw)
      , text "TopLevelOnly =" <+> ppr (floatToTopLevelOnly sw)
      ]
-- | Keep the given pass when the condition holds; otherwise substitute
-- 'CoreDoNothing'.
runWhen :: Bool -> CoreToDo -> CoreToDo
runWhen enabled pass
  | enabled   = pass
  | otherwise = CoreDoNothing
-- | Build a pass from the value inside a 'Just'; a 'Nothing' yields
-- 'CoreDoNothing'.
runMaybe :: Maybe a -> (a -> CoreToDo) -> CoreToDo
runMaybe mb_arg mk_pass = maybe CoreDoNothing mk_pass mb_arg
-- | The type of a plugin-supplied pass: a 'ModGuts' transformation in 'CoreM'.
type CorePluginPass = ModGuts -> CoreM ModGuts
-- | Lift a transformation of the program's bindings to a transformation of
-- the whole 'ModGuts': only the 'mg_binds' field is replaced.
bindsOnlyPass :: (CoreProgram -> CoreM CoreProgram) -> ModGuts -> CoreM ModGuts
bindsOnlyPass pass guts =
  pass (mg_binds guts) >>= \new_binds ->
    return (guts { mg_binds = new_binds })
-- | Build a document that depends on whether verbose statistics are wanted;
-- the decision is delegated to 'getPprDebug'.
getVerboseSimplStats :: (Bool -> SDoc) -> SDoc
getVerboseSimplStats mk_doc = getPprDebug mk_doc
-- | An empty counter.  It is the detailed form when
-- @Opt_D_dump_simpl_stats@ is set, and the plain total-only form otherwise.
zeroSimplCount :: DynFlags -> SimplCount
-- | True when the total tick count is zero.
isZeroSimplCount :: SimplCount -> Bool
-- | True for the detailed form of the counter.
hasDetailedCounts :: SimplCount -> Bool
-- | Render the total and, for the detailed form, the per-tick breakdown.
pprSimplCount :: SimplCount -> SDoc
-- | Record a tick: bumps the total and, for the detailed form, the per-tick
-- count and the recent-tick log.
doSimplTick :: DynFlags -> Tick -> SimplCount -> SimplCount
-- | Record a tick in the per-tick counts only; the total is left unchanged.
doFreeSimplTick :: Tick -> SimplCount -> SimplCount
-- | Combine two counters of the same form; mixing forms throws.
plusSimplCount :: SimplCount -> SimplCount -> SimplCount
-- | Tick statistics.  Either just a running total, or a total together with
-- per-tick counts and a log of recent ticks.
data SimplCount
  = VerySimplCount !Int   -- total only
  | SimplCount {
      ticks :: !Int,            -- total number of ticks
      details :: !TickCounts,   -- how often each individual tick occurred
      n_log :: !Int,            -- number of entries currently in log1
      log1 :: [Tick],           -- most recent ticks, newest first
      log2 :: [Tick]            -- the previous contents of log1 (see 'doSimplTick')
    }
-- | Number of occurrences recorded for each 'Tick'.
type TickCounts = Map Tick Int
-- | The total number of ticks, for either form of the counter.
simplCountN :: SimplCount -> Int
simplCountN sc = case sc of
  VerySimplCount total         -> total
  SimplCount { ticks = total } -> total
-- Detailed statistics are only collected when they are going to be dumped;
-- otherwise a plain total is enough.
zeroSimplCount dflags =
  if dopt Opt_D_dump_simpl_stats dflags
    then SimplCount { ticks = 0
                    , details = Map.empty
                    , n_log = 0
                    , log1 = []
                    , log2 = [] }
    else VerySimplCount 0
-- Both forms of the counter carry a total; zero means nothing was counted.
isZeroSimplCount sc = simplCountN sc == 0
-- Only the record form keeps per-tick information.
hasDetailedCounts sc = case sc of
  SimplCount {}     -> True
  VerySimplCount {} -> False
-- A "free" tick is added to the per-tick counts but not to the total, and is
-- not logged.  The total-only form has nowhere to record it.
doFreeSimplTick tick sc = case sc of
  SimplCount { details = counts } -> sc { details = counts `addTick` tick }
  VerySimplCount {}               -> sc
-- Count one tick.  For the detailed form the tick is also pushed onto the
-- recent-tick log; once log1 holds 'historySize' entries it is moved to
-- log2 and a fresh log1 is started with this tick.
doSimplTick dflags tick sc = case sc of
  VerySimplCount n -> VerySimplCount (n + 1)
  SimplCount { ticks = total, details = counts, n_log = logged, log1 = recent }
    | logged >= historySize dflags ->
        bumped { n_log = 1, log1 = [tick], log2 = recent }
    | otherwise ->
        bumped { n_log = logged + 1, log1 = tick : recent }
    where
      -- The counter with the total and per-tick count already updated.
      bumped = sc { ticks = total + 1, details = counts `addTick` tick }
-- | Increment the count of one tick, starting at 1 if it has not been seen.
addTick :: TickCounts -> Tick -> TickCounts
addTick counts tick = MapStrict.insertWith combine tick 1 counts
  where
    combine new old = new + old
-- Totals and per-tick counts are added.  The recent-tick log is taken from
-- the right-hand counter when it has one; if its log2 is empty, the
-- left-hand log1 is used to fill it.  Combining a detailed counter with a
-- total-only one is a program error.
plusSimplCount lhs rhs = case (lhs, rhs) of
  ( SimplCount { ticks = n1, details = d1 }
    , SimplCount { ticks = n2, details = d2 } ) ->
    let base
          | null (log1 rhs) = lhs
          | null (log2 rhs) = rhs { log2 = log1 lhs }
          | otherwise       = rhs
    in base { ticks = n1 + n2
            , details = MapStrict.unionWith (+) d1 d2 }

  (VerySimplCount n, VerySimplCount m) -> VerySimplCount (n + m)

  _ -> throwGhcException
         (PprProgramError "plusSimplCount"
            (vcat [ text "lhs"
                  , pprSimplCount lhs
                  , text "rhs"
                  , pprSimplCount rhs
                  ]))
-- The total-only form prints just the total.  The detailed form adds the
-- per-tick breakdown and, when 'getVerboseSimplStats' says so, the log of
-- recent ticks (log1 followed by log2).
pprSimplCount sc = case sc of
  VerySimplCount n -> text "Total ticks:" <+> int n
  SimplCount { ticks = total, details = counts, log1 = newer, log2 = older } ->
    vcat
      [ text "Total ticks: " <+> int total
      , blankLine
      , pprTickCounts counts
      , getVerboseSimplStats (recentLog newer older)
      ]
  where
    recentLog newer older verbose
      | verbose =
          vcat [ blankLine
               , text "Log (most recent first)"
               , nest 4 (vcat (map ppr newer) $$ vcat (map ppr older))
               ]
      | otherwise = Outputable.empty
-- | Print the per-tick counts, one group per kind of tick.  The map lists
-- its entries in 'Tick' order, which sorts by tag first, so entries with the
-- same tag are adjacent and 'groupBy' collects them.
pprTickCounts :: Map Tick Int -> SDoc
pprTickCounts counts =
  vcat [ pprTickGroup grp | grp <- groupBy sameTag (Map.toList counts) ]
  where
    sameTag :: (Tick, Int) -> (Tick, Int) -> Bool
    sameTag lhs rhs = tickToTag (fst lhs) == tickToTag (fst rhs)
-- | Print one group of same-kind ticks: a header with the group's total and
-- the kind's name, then one row per tick, largest count first.  Panics on an
-- empty group.
pprTickGroup :: [(Tick, Int)] -> SDoc
pprTickGroup [] = panic "pprTickGroup"
pprTickGroup entries@((leader, _) : _) = hang header 2 (vcat rows)
  where
    header = int (sum (map snd entries)) <+> text (tickString leader)
    rows   = map row (sortBy (flip (comparing snd)) entries)
    row (tick, n) = int n <+> pprTickCts tick
-- | A single countable event.  Most constructors carry the 'Id' involved
-- ('RuleFired' carries a 'FastString' instead); that payload is what
-- 'pprTickCts' prints and what distinguishes ticks of the same kind.
data Tick
  = PreInlineUnconditionally Id
  | PostInlineUnconditionally Id
  | UnfoldingDone Id
  | RuleFired FastString
  | LetFloatFromLet
  | EtaExpansion Id
  | EtaReduction Id
  | BetaReduction Id
  | CaseOfCase Id
  | KnownBranch Id
  | CaseMerge Id
  | AltMerge Id
  | CaseElim Id
  | CaseIdentity Id
  | FillInCaseDefault Id
  | SimplifierDone
-- | A tick prints as its kind name followed by its payload, if any.
instance Outputable Tick where
  ppr tick = kind <+> payload
    where
      kind    = text (tickString tick)
      payload = pprTickCts tick
-- | Equality is defined through 'cmpTick', so it agrees with 'Ord'.
instance Eq Tick where
  a == b = cmpTick a b == EQ
-- | Ordering is given by 'cmpTick': by kind tag first, then by payload.
instance Ord Tick where
  compare a b = cmpTick a b
-- | Numeric tag of a tick's kind.  'cmpTick' orders ticks by this tag and
-- 'pprTickCounts' groups by it.  The numbering is not contiguous (14 and 15
-- are unused) and 'AltMerge' sorts last; the values are kept exactly as they
-- are so that the ordering does not change.
tickToTag :: Tick -> Int
tickToTag tick = case tick of
  PreInlineUnconditionally _  -> 0
  PostInlineUnconditionally _ -> 1
  UnfoldingDone _             -> 2
  RuleFired _                 -> 3
  LetFloatFromLet             -> 4
  EtaExpansion _              -> 5
  EtaReduction _              -> 6
  BetaReduction _             -> 7
  CaseOfCase _                -> 8
  KnownBranch _               -> 9
  CaseMerge _                 -> 10
  CaseElim _                  -> 11
  CaseIdentity _              -> 12
  FillInCaseDefault _         -> 13
  SimplifierDone              -> 16
  AltMerge _                  -> 17
-- | The name of a tick's kind, which is the constructor name.
tickString :: Tick -> String
tickString tick = case tick of
  PreInlineUnconditionally _  -> "PreInlineUnconditionally"
  PostInlineUnconditionally _ -> "PostInlineUnconditionally"
  UnfoldingDone _             -> "UnfoldingDone"
  RuleFired _                 -> "RuleFired"
  LetFloatFromLet             -> "LetFloatFromLet"
  EtaExpansion _              -> "EtaExpansion"
  EtaReduction _              -> "EtaReduction"
  BetaReduction _             -> "BetaReduction"
  CaseOfCase _                -> "CaseOfCase"
  KnownBranch _               -> "KnownBranch"
  CaseMerge _                 -> "CaseMerge"
  AltMerge _                  -> "AltMerge"
  CaseElim _                  -> "CaseElim"
  CaseIdentity _              -> "CaseIdentity"
  FillInCaseDefault _         -> "FillInCaseDefault"
  SimplifierDone              -> "SimplifierDone"
-- | Print the payload of a tick.  Kinds without a payload give an empty
-- document.
pprTickCts :: Tick -> SDoc
pprTickCts tick = case tick of
  PreInlineUnconditionally v  -> ppr v
  PostInlineUnconditionally v -> ppr v
  UnfoldingDone v             -> ppr v
  RuleFired rule              -> ppr rule
  EtaExpansion v              -> ppr v
  EtaReduction v              -> ppr v
  BetaReduction v             -> ppr v
  CaseOfCase v                -> ppr v
  KnownBranch v               -> ppr v
  CaseMerge v                 -> ppr v
  AltMerge v                  -> ppr v
  CaseElim v                  -> ppr v
  CaseIdentity v              -> ppr v
  FillInCaseDefault v         -> ppr v
  LetFloatFromLet             -> Outputable.empty
  SimplifierDone              -> Outputable.empty
-- | Total order on ticks: compare the kind tags, and only when those are
-- equal compare the payloads.
cmpTick :: Tick -> Tick -> Ordering
cmpTick a b
  | tag_order == EQ = cmpEqTick a b
  | otherwise       = tag_order
  where
    tag_order = comparing tickToTag a b
-- | Compare the payloads of two ticks of the same kind.  Any other pair,
-- including the payload-free kinds, compares as equal.
cmpEqTick :: Tick -> Tick -> Ordering
cmpEqTick lhs rhs = case (lhs, rhs) of
  (PreInlineUnconditionally a,  PreInlineUnconditionally b)  -> compare a b
  (PostInlineUnconditionally a, PostInlineUnconditionally b) -> compare a b
  (UnfoldingDone a,             UnfoldingDone b)             -> compare a b
  (RuleFired a,                 RuleFired b)                 -> compare a b
  (EtaExpansion a,              EtaExpansion b)              -> compare a b
  (EtaReduction a,              EtaReduction b)              -> compare a b
  (BetaReduction a,             BetaReduction b)             -> compare a b
  (CaseOfCase a,                CaseOfCase b)                -> compare a b
  (KnownBranch a,               KnownBranch b)               -> compare a b
  (CaseMerge a,                 CaseMerge b)                 -> compare a b
  (AltMerge a,                  AltMerge b)                  -> compare a b
  (CaseElim a,                  CaseElim b)                  -> compare a b
  (CaseIdentity a,              CaseIdentity b)              -> compare a b
  (FillInCaseDefault a,         FillInCaseDefault b)         -> compare a b
  _                                                          -> EQ
-- | The read-only environment of a 'CoreM' computation, filled in by
-- 'runCoreM'.
data CoreReader = CoreReader {
  cr_hsc_env :: HscEnv,
  cr_rule_base :: RuleBase,
  cr_module :: Module,
  cr_print_unqual :: PrintUnqualified,   -- used to build the print style of messages and dumps
  cr_loc :: SrcSpan,                     -- location attached to messages logged via 'msg'
  cr_visible_orphan_mods :: !ModuleSet,
  cr_uniq_mask :: !Char                  -- mask passed to the unique-supply functions
  }
-- | The output accumulated alongside a 'CoreM' result: the tick statistics.
newtype CoreWriter = CoreWriter {
  cw_simpl_count :: SimplCount
  }
-- | A writer holding the zero counter for the given flags.
emptyWriter :: DynFlags -> CoreWriter
emptyWriter dflags = CoreWriter (zeroSimplCount dflags)
-- | Combine two writers by adding their counters with 'plusSimplCount'.
plusWriter :: CoreWriter -> CoreWriter -> CoreWriter
plusWriter (CoreWriter count1) (CoreWriter count2) =
  CoreWriter (count1 `plusSimplCount` count2)
-- | The underlying monad of 'CoreM': 'IOEnv' over the 'CoreReader' environment.
type CoreIOEnv = IOEnv CoreReader
-- | The monad of Core-to-Core passes: a computation in 'CoreIOEnv' that
-- returns its result paired with a 'CoreWriter'.
newtype CoreM a = CoreM { unCoreM :: CoreIOEnv (a, CoreWriter) }
  deriving (Functor)
-- | Sequencing runs both computations and combines their writers.  The
-- combined writer is forced whenever the result pair is forced, so no chain
-- of unevaluated 'plusWriter' calls builds up.
instance Monad CoreM where
  (CoreM first) >>= next = CoreM $
    first >>= \(x, w1) ->
    unCoreM (next x) >>= \(y, w2) ->
    let combined = plusWriter w1 w2
    in return (combined `seq` (y, combined))
-- | 'pure' pairs the value with an empty writer; the combinators are
-- defined through the 'Monad' instance.
instance Applicative CoreM where
  pure = CoreM . nop
  mf <*> mx = ap mf mx
  m *> k = m >>= const k
-- | Failure and choice are those of the underlying 'CoreIOEnv'.
instance Alternative CoreM where
  empty = CoreM Control.Applicative.empty
  (CoreM m) <|> (CoreM n) = CoreM (m <|> n)
-- | No methods are given, so the class defaults are used.
instance MonadPlus CoreM
-- | Uniques are made in IO from the mask stored in the reader environment.
instance MonadUnique CoreM where
  getUniqueSupplyM =
    read cr_uniq_mask >>= \mask -> liftIO $! mkSplitUniqSupply mask

  getUniqueM =
    read cr_uniq_mask >>= \mask -> liftIO $! uniqFromMask mask
-- | Run a 'CoreM' computation in IO.  The arguments populate the
-- 'CoreReader' environment; the result is paired with the tick statistics
-- accumulated during the run.
runCoreM :: HscEnv
         -> RuleBase
         -> Char            -- mask for unique generation
         -> Module
         -> ModuleSet       -- visible orphan modules
         -> PrintUnqualified
         -> SrcSpan
         -> CoreM a
         -> IO (a, SimplCount)
runCoreM hsc_env rule_base mask mod orph_imps print_unqual loc m =
  dropWriter <$> runIOEnv env (unCoreM m)
  where
    env = CoreReader
      { cr_hsc_env             = hsc_env
      , cr_rule_base           = rule_base
      , cr_module              = mod
      , cr_visible_orphan_mods = orph_imps
      , cr_print_unqual        = print_unqual
      , cr_loc                 = loc
      , cr_uniq_mask           = mask
      }

    -- Keep only the counter out of the final writer.
    dropWriter :: (a, CoreWriter) -> (a, SimplCount)
    dropWriter (result, w) = (result, cw_simpl_count w)
-- | Pair a value with the empty writer for the session's 'DynFlags'.
nop :: a -> CoreIOEnv (a, CoreWriter)
nop x =
  getEnv >>= \env ->
    return (x, emptyWriter (hsc_dflags (cr_hsc_env env)))
-- | Project a value out of the reader environment.  This is the module's
-- own @read@; the Prelude one is hidden by the import list.
read :: (CoreReader -> a) -> CoreM a
read select = CoreM $ do
  env <- getEnv
  nop (select env)
-- | Emit a writer value; '>>=' combines it with the surrounding ones.
write :: CoreWriter -> CoreM ()
write w = CoreM (return ((), w))
-- | Lift a 'CoreIOEnv' action into 'CoreM' with an empty writer.
liftIOEnv :: CoreIOEnv a -> CoreM a
liftIOEnv action = CoreM (action >>= nop)
-- | IO actions are lifted through the underlying 'CoreIOEnv'.
instance MonadIO CoreM where
  liftIO io = liftIOEnv (IOEnv.liftIO io)
-- | Lift an IO action that also produces tick statistics: the statistics
-- are added to the writer and the other component is returned.
liftIOWithCount :: IO (SimplCount, a) -> CoreM a
liftIOWithCount what = do
  (count, x) <- liftIO what
  addSimplCount count
  return x
-- | The 'HscEnv' stored in the reader environment.
getHscEnv :: CoreM HscEnv
getHscEnv = read cr_hsc_env
-- | The 'RuleBase' stored in the reader environment.
getRuleBase :: CoreM RuleBase
getRuleBase = read cr_rule_base
-- | The set of visible orphan modules stored in the reader environment.
getVisibleOrphanMods :: CoreM ModuleSet
getVisibleOrphanMods = read cr_visible_orphan_mods
-- | The 'PrintUnqualified' stored in the reader environment.
getPrintUnqualified :: CoreM PrintUnqualified
getPrintUnqualified = read cr_print_unqual
-- | The source span stored in the reader environment.
getSrcSpanM :: CoreM SrcSpan
getSrcSpanM = read cr_loc
-- | Add tick statistics to those accumulated so far.
addSimplCount :: SimplCount -> CoreM ()
addSimplCount = write . CoreWriter
-- | The unique-supply mask stored in the reader environment.
getUniqMask :: CoreM Char
getUniqMask = read cr_uniq_mask
-- | The 'DynFlags' are those of the session's 'HscEnv'.
instance HasDynFlags CoreM where
  getDynFlags = hsc_dflags <$> getHscEnv
-- | The module stored in the reader environment.
instance HasModule CoreM where
  getModule = read cr_module
-- | The family-instance environment of the external package state, which is
-- fetched from the session via 'hscEPS'.
getPackageFamInstEnv :: CoreM PackageFamInstEnv
getPackageFamInstEnv =
  getHscEnv >>= \hsc_env ->
  liftIO (hscEPS hsc_env) >>= \eps ->
  return (eps_fam_inst_env eps)
-- | Collect the annotations of the requested type, keyed by module and by
-- name.  The annotation environment is prepared from the session and the
-- given 'ModGuts', then decoded with the supplied deserializer.
getAnnotations :: Typeable a => ([Word8] -> a) -> ModGuts -> CoreM (ModuleEnv [a], NameEnv [a])
getAnnotations deserialize guts =
  getHscEnv >>= \hsc_env ->
  liftIO (prepareAnnotations hsc_env (Just guts)) >>= \ann_env ->
  return (deserializeAnns deserialize ann_env)
-- | Like 'getAnnotations', but keeps only the first annotation for each
-- module and name.  Entries with no annotations are dropped first, so the
-- 'head' calls never see an empty list.
getFirstAnnotations :: Typeable a => ([Word8] -> a) -> ModGuts -> CoreM (ModuleEnv a, NameEnv a)
getFirstAnnotations deserialize guts =
  fmap (bimap firstPerModule firstPerName) (getAnnotations deserialize guts)
  where
    firstPerModule env =
      mapModuleEnv head (filterModuleEnv (\_ anns -> not (null anns)) env)
    firstPerName env =
      mapNameEnv head (filterNameEnv (not . null) env)
-- | Log a message at the pass's source location.  Errors and warnings use
-- the error style, dumps use the dump style, and everything else uses the
-- user style.
msg :: Severity -> WarnReason -> SDoc -> CoreM ()
msg sev reason doc = do
  dflags <- getDynFlags
  loc    <- getSrcSpanM
  unqual <- getPrintUnqualified
  liftIO (putLogMsg dflags reason sev loc (withPprStyle (styleFor unqual) doc))
  where
    -- The print style appropriate for this message's severity.
    styleFor unqual = case sev of
      SevError   -> mkErrStyle unqual
      SevWarning -> mkErrStyle unqual
      SevDump    -> mkDumpStyle unqual
      _          -> mkUserStyle unqual AllTheWay
-- | 'putMsg' for a plain string.
putMsgS :: String -> CoreM ()
putMsgS str = putMsg (text str)
-- | Log a message with severity 'SevInfo'.
putMsg :: SDoc -> CoreM ()
putMsg doc = msg SevInfo NoReason doc
-- | 'errorMsg' for a plain string.
errorMsgS :: String -> CoreM ()
errorMsgS str = errorMsg (text str)
-- | Log a message with severity 'SevError'.
errorMsg :: SDoc -> CoreM ()
errorMsg doc = msg SevError NoReason doc
-- | Log a message with severity 'SevWarning' and the given reason.
warnMsg :: WarnReason -> SDoc -> CoreM ()
warnMsg reason doc = msg SevWarning reason doc
-- | 'fatalErrorMsg' for a plain string.
fatalErrorMsgS :: String -> CoreM ()
fatalErrorMsgS str = fatalErrorMsg (text str)
-- | Log a message with severity 'SevFatal'.
fatalErrorMsg :: SDoc -> CoreM ()
fatalErrorMsg doc = msg SevFatal NoReason doc
-- | 'debugTraceMsg' for a plain string.
debugTraceMsgS :: String -> CoreM ()
debugTraceMsgS str = debugTraceMsg (text str)
-- | Log a message with severity 'SevDump'.
debugTraceMsg :: SDoc -> CoreM ()
debugTraceMsg doc = msg SevDump NoReason doc
-- | Emit a dump through 'dumpAction', but only when the given dump flag is
-- set.  The document is rendered in the dump style.
dumpIfSet_dyn :: DumpFlag -> String -> DumpFormat -> SDoc -> CoreM ()
dumpIfSet_dyn flag str fmt doc = do
  dflags <- getDynFlags
  unqual <- getPrintUnqualified
  let emit = dumpAction dflags (mkDumpStyle unqual)
                        (dumpOptionsFromFlag flag) str fmt doc
  when (dopt flag dflags) (liftIO emit)