module SimplCore ( core2core, simplifyExpr ) where
#include "HsVersions.h"
import GhcPrelude
import DynFlags
import CoreSyn
import HscTypes
import CSE ( cseProgram )
import Rules ( mkRuleBase, unionRuleBase,
extendRuleBaseList, ruleCheckProgram, addRuleInfo,
getRules )
import PprCore ( pprCoreBindings, pprCoreExpr )
import OccurAnal ( occurAnalysePgm, occurAnalyseExpr )
import IdInfo
import CoreStats ( coreBindsSize, coreBindsStats, exprSize )
import CoreUtils ( mkTicks, stripTicksTop )
import CoreLint ( endPass, lintPassResult, dumpPassResult,
lintAnnots )
import Simplify ( simplTopBinds, simplExpr, simplRules )
import SimplUtils ( simplEnvForGHCi, activeRule, activeUnfolding )
import SimplEnv
import SimplMonad
import CoreMonad
import qualified ErrUtils as Err
import FloatIn ( floatInwards )
import FloatOut ( floatOutwards )
import FamInstEnv
import Id
import ErrUtils ( withTiming, withTimingD )
import BasicTypes ( CompilerPhase(..), isDefaultInlinePragma, defaultInlinePragma )
import VarSet
import VarEnv
import LiberateCase ( liberateCase )
import SAT ( doStaticArgs )
import Specialise ( specProgram)
import SpecConstr ( specConstrProgram)
import DmdAnal ( dmdAnalProgram )
import CallArity ( callArityAnalProgram )
import Exitify ( exitifyProgram )
import WorkWrap ( wwTopBinds )
import SrcLoc
import Util
import Module
import Plugins ( withPlugins, installCoreToDos )
import DynamicLoading
import UniqSupply ( UniqSupply, mkSplitUniqSupply, splitUniqSupply )
import UniqFM
import Outputable
import Control.Monad
import qualified GHC.LanguageExtensions as LangExt
-- | Run the Core-to-Core pipeline over one module.
--
-- The built-in pass list is computed from the 'DynFlags' of the supplied
-- 'HscEnv'; inside the 'CoreM' computation the plugins are initialised and
-- given the chance to edit that list via 'installCoreToDos' before the passes
-- are run.  Afterwards the accumulated simplifier statistics are dumped under
-- @Opt_D_dump_simpl_stats@.
core2core :: HscEnv -> ModGuts -> IO ModGuts
core2core hsc_env guts@(ModGuts { mg_module  = this_mod
                                , mg_loc     = loc
                                , mg_deps    = deps
                                , mg_rdr_env = rdr_env })
  = do
      (guts2, stats) <-
        runCoreM hsc_env hpt_rule_base uniq_mask this_mod
                 orph_mods print_unqual loc
                 pipeline
      Err.dumpIfSet_dyn dflags Opt_D_dump_simpl_stats
        "Grand total simplifier statistics"
        (pprSimplCount stats)
      return guts2
  where
    dflags         = hsc_dflags hsc_env
    home_pkg_rules = hptRules hsc_env (dep_mods deps)
    hpt_rule_base  = mkRuleBase home_pkg_rules
    print_unqual   = mkPrintUnqualified dflags rdr_env

    builtin_passes = getCoreToDo dflags
    orph_mods      = mkModuleSet (this_mod : dep_orphs deps)
    uniq_mask      = 's'

    -- The CoreM action handed to runCoreM: load plugins, let them adjust
    -- the pass list, then run every pass over the incoming guts.
    pipeline = do
      hsc_env'   <- getHscEnv
      dflags'    <- liftIO (initializePlugins hsc_env' (hsc_dflags hsc_env'))
      all_passes <- withPlugins dflags' installCoreToDos builtin_passes
      runCorePasses all_passes guts
-- | Build the list of Core-to-Core passes selected by the optimisation level
-- and the individual optimisation flags held in 'DynFlags'.
--
-- The returned list is flat: 'flatten_todos' removes every 'CoreDoNothing'
-- and splices the contents of every 'CoreDoPasses' in place.
getCoreToDo :: DynFlags -> [CoreToDo]
getCoreToDo dflags
  = flatten_todos core_todo
  where
    opt_level       = optLevel           dflags
    phases          = simplPhases        dflags
    max_iter        = maxSimplIterations dflags
    rule_check      = ruleCheck          dflags
    call_arity      = gopt Opt_CallArity                    dflags
    exitification   = gopt Opt_Exitification                dflags
    strictness      = gopt Opt_Strictness                   dflags
    full_laziness   = gopt Opt_FullLaziness                 dflags
    do_specialise   = gopt Opt_Specialise                   dflags
    do_float_in     = gopt Opt_FloatIn                      dflags
    cse             = gopt Opt_CSE                          dflags
    spec_constr     = gopt Opt_SpecConstr                   dflags
    liberate_case   = gopt Opt_LiberateCase                 dflags
    late_dmd_anal   = gopt Opt_LateDmdAnal                  dflags
    late_specialise = gopt Opt_LateSpecialise               dflags
    static_args     = gopt Opt_StaticArgumentTransformation dflags
    rules_on        = gopt Opt_EnableRewriteRules           dflags
    eta_expand_on   = gopt Opt_DoLambdaEtaExpansion         dflags
    ww_on           = gopt Opt_WorkerWrapper                dflags
    static_ptrs     = xopt LangExt.StaticPointers           dflags

    -- A rule-check pass for the given phase, controlled by the 'ruleCheck'
    -- setting.
    maybe_rule_check phase = runMaybe rule_check (CoreDoRuleCheck phase)

    -- A strictness pass, enabled only for phases listed in 'strictnessBefore'.
    maybe_strictness_before phase
      = runWhen (phase `elem` strictnessBefore dflags) CoreDoStrictness

    -- Template simplifier mode.  'sm_phase' is deliberately a panic: every
    -- use below overrides it with a record update before it can be looked at.
    base_mode = SimplMode { sm_phase      = panic "base_mode"
                          , sm_names      = []
                          , sm_dflags     = dflags
                          , sm_rules      = rules_on
                          , sm_eta_expand = eta_expand_on
                          , sm_inline     = True
                          , sm_case_case  = True }

    -- One simplifier run in the numbered phase, bracketed by the optional
    -- strictness-before and rule-check passes for that phase.
    simpl_phase phase names iter
      = CoreDoPasses
      $   [ maybe_strictness_before phase
          , CoreDoSimplify iter
                (base_mode { sm_phase = Phase phase
                           , sm_names = names })
          , maybe_rule_check (Phase phase) ]

    -- Count down from the configured number of phases to phase 1.
    -- (The step element must be @phases - 1@; the list is empty when
    -- @phases@ is less than 1.)
    simpl_phases = CoreDoPasses [ simpl_phase phase ["main"] max_iter
                                | phase <- [phases, phases - 1 .. 1] ]

    -- Simplifier run in 'InitialPhase' with case-of-case switched off.
    simpl_gently = CoreDoSimplify max_iter
                       (base_mode { sm_phase     = InitialPhase
                                  , sm_names     = ["Gentle"]
                                  , sm_rules     = rules_on
                                  , sm_inline    = True
                                  , sm_case_case = False })

    -- Worker/wrapper follows strictness analysis only when it is enabled.
    strictness_pass = if ww_on
                        then [CoreDoStrictness,CoreDoWorkerWrapper]
                        else [CoreDoStrictness]

    demand_analyser = (CoreDoPasses (
                           strictness_pass ++
                           [simpl_phase 0 ["post-worker-wrapper"] max_iter]
                           ))

    -- Only when the StaticPointers extension is on: a gentle simplifier run
    -- followed by a float-out restricted to top level.
    static_ptrs_float_outwards =
      runWhen static_ptrs $ CoreDoPasses
        [ simpl_gently
        , CoreDoFloatOutwards FloatOutSwitches
          { floatOutLambdas     = Just 0
          , floatOutConstants   = True
          , floatOutOverSatApps = False
          , floatToTopLevelOnly = True
          }
        ]

    core_todo =
     if opt_level == 0 then
       -- No optimisation: static-pointer floating plus one simplifier pass.
       [ static_ptrs_float_outwards,
         CoreDoSimplify max_iter
             (base_mode { sm_phase = Phase 0
                        , sm_names = ["Non-opt simplification"] })
       ]

     else {- opt_level >= 1 -} [

        -- Static argument transformation, preceded by a gentle simplify.
        runWhen static_args (CoreDoPasses [ simpl_gently, CoreDoStaticArgs ]),

        -- Initial gentle simplification.
        simpl_gently,

        runWhen do_specialise CoreDoSpecialising,

        -- First float-out; without full laziness only static forms are
        -- floated (and only if StaticPointers is on).
        if full_laziness then
           CoreDoFloatOutwards FloatOutSwitches {
                                 floatOutLambdas     = Just 0,
                                 floatOutConstants   = True,
                                 floatOutOverSatApps = False,
                                 floatToTopLevelOnly = False }
        else
           static_ptrs_float_outwards,

        -- The numbered simplifier phases, counting down to 1 ...
        simpl_phases,

        -- ... then phase 0, with at least three iterations allowed.
        simpl_phase 0 ["main"] (max max_iter 3),

        runWhen do_float_in CoreDoFloatInwards,

        runWhen call_arity $ CoreDoPasses
            [ CoreDoCallArity
            , simpl_phase 0 ["post-call-arity"] max_iter
            ],

        runWhen strictness demand_analyser,

        runWhen exitification CoreDoExitify,

        -- Second float-out, this time also over saturated applications.
        runWhen full_laziness $
           CoreDoFloatOutwards FloatOutSwitches {
                                 floatOutLambdas     = floatLamArgs dflags,
                                 floatOutConstants   = True,
                                 floatOutOverSatApps = True,
                                 floatToTopLevelOnly = False },

        runWhen cse CoreCSE,

        runWhen do_float_in CoreDoFloatInwards,

        maybe_rule_check (Phase 0),

        -- LiberateCase is always followed by a simplifier run.
        runWhen liberate_case (CoreDoPasses [
            CoreLiberateCase,
            simpl_phase 0 ["post-liberate-case"] max_iter
            ]),

        runWhen spec_constr CoreDoSpecConstr,

        maybe_rule_check (Phase 0),

        runWhen late_specialise
          (CoreDoPasses [ CoreDoSpecialising
                        , simpl_phase 0 ["post-late-spec"] max_iter]),

        -- A further CSE run, only if LiberateCase or SpecConstr ran.
        runWhen ((liberate_case || spec_constr) && cse) CoreCSE,

        -- Final clean-up simplification.
        simpl_phase 0 ["final"] max_iter,

        runWhen late_dmd_anal $ CoreDoPasses (
            strictness_pass ++
            [simpl_phase 0 ["post-late-ww"] max_iter]
          ),

        -- Final strictness run, scheduled only if strictness analysis (early
        -- or late) is enabled at all.
        runWhen (strictness || late_dmd_anal) CoreDoStrictness,

        maybe_rule_check (Phase 0)
     ]

    -- Remove 'CoreDoNothing' and flatten 'CoreDoPasses'.
    flatten_todos [] = []
    flatten_todos (CoreDoNothing : rest) = flatten_todos rest
    flatten_todos (CoreDoPasses passes : rest) =
      flatten_todos passes ++ flatten_todos rest
    flatten_todos (todo : rest) = todo : flatten_todos rest
-- | Thread a 'ModGuts' through a list of passes, left to right.
--
-- 'CoreDoNothing' is skipped and 'CoreDoPasses' is run recursively; every
-- other pass is timed, run under 'lintAnnots', and finished with 'endPass'.
-- The module name shown in the timing label is taken from the guts the
-- function was originally called with.
runCorePasses :: [CoreToDo] -> ModGuts -> CoreM ModGuts
runCorePasses passes guts = foldM step guts passes
  where
    this_mod = mg_module guts

    step acc pass = case pass of
      CoreDoNothing   -> return acc
      CoreDoPasses ps -> runCorePasses ps acc
      _               ->
        withTimingD (ppr pass <+> brackets (ppr this_mod)) (const ()) $ do
          acc' <- lintAnnots (ppr pass) (doCorePass pass) acc
          endPass pass (mg_binds acc') (mg_rules acc')
          return acc'
-- | Select the 'ModGuts' transformation that implements a single 'CoreToDo'.
--
-- The constructors 'CoreDesugar', 'CoreDesugarOpt', 'CoreTidy', 'CorePrep'
-- and 'CoreOccurAnal' are not handled here: asking for them panics with the
-- pretty-printed pass.
doCorePass :: CoreToDo -> ModGuts -> CoreM ModGuts
doCorePass pass = case pass of
    CoreDoSimplify {}              -> simplifyPgm pass
    CoreCSE                        -> doPass cseProgram
    CoreLiberateCase               -> doPassD liberateCase
    CoreDoFloatInwards             -> floatInwards
    CoreDoFloatOutwards f          -> doPassDUM (floatOutwards f)
    CoreDoStaticArgs               -> doPassU doStaticArgs
    CoreDoCallArity                -> doPassD callArityAnalProgram
    CoreDoExitify                  -> doPass exitifyProgram
    CoreDoStrictness               -> doPassDFM dmdAnalProgram
    CoreDoWorkerWrapper            -> doPassDFU wwTopBinds
    CoreDoSpecialising             -> specProgram
    CoreDoSpecConstr               -> specConstrProgram
    CoreDoPrintCore                -> observe printCore
    CoreDoRuleCheck phase pat      -> ruleCheckPass phase pat
    CoreDoNothing                  -> return
    CoreDoPasses passes            -> runCorePasses passes
    CoreDoPluginPass _ plugin_pass -> plugin_pass
    CoreDesugar                    -> unsupported
    CoreDesugarOpt                 -> unsupported
    CoreTidy                       -> unsupported
    CorePrep                       -> unsupported
    CoreOccurAnal                  -> unsupported
  where
    -- Shared panic for the constructors this function does not implement.
    unsupported = pprPanic "doCorePass" (ppr pass)
-- | Unconditionally dump the given bindings under the heading "Print Core".
printCore :: DynFlags -> CoreProgram -> IO ()
printCore dflags binds =
    Err.dumpIfSet dflags always "Print Core" doc
  where
    always = True
    doc    = pprCoreBindings binds
-- | Log the result of 'ruleCheckProgram' for the given phase and pattern.
--
-- The rules consulted for a function are those found through the current
-- rule base and visible orphan modules, followed by the module's own
-- 'mg_rules'.  The guts are returned unchanged.
ruleCheckPass :: CompilerPhase -> String -> ModGuts -> CoreM ModGuts
ruleCheckPass current_phase pat guts =
    withTimingD timing_label (const ()) $ do
      rb        <- getRuleBase
      dflags    <- getDynFlags
      vis_orphs <- getVisibleOrphanMods
      let rules_for fn = getRules (RuleEnv rb vis_orphs) fn ++ mg_rules guts
          report       = ruleCheckProgram current_phase pat
                                          rules_for (mg_binds guts)
      liftIO $ putLogMsg dflags NoReason Err.SevDump noSrcSpan
                         (defaultDumpStyle dflags)
                         report
      return guts
  where
    timing_label = text "RuleCheck" <+> brackets (ppr (mg_module guts))
-- | Lift an 'IO' bindings transformation that needs 'DynFlags' and a
-- 'UniqSupply' into a 'ModGuts' pass; only 'mg_binds' is replaced.
doPassDUM :: (DynFlags -> UniqSupply -> CoreProgram -> IO CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDUM do_pass = doPassM on_binds
  where
    on_binds binds = do
      dflags <- getDynFlags
      us     <- getUniqueSupplyM
      liftIO (do_pass dflags us binds)
-- | Like 'doPassDUM', for a transformation that takes no 'UniqSupply'.
doPassDM :: (DynFlags -> CoreProgram -> IO CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDM do_pass = doPassDUM (\dflags _us -> do_pass dflags)
-- | Lift a pure bindings transformation that needs 'DynFlags' into a
-- 'ModGuts' pass.
doPassD :: (DynFlags -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPassD do_pass = doPassDM (\dflags binds -> return (do_pass dflags binds))
-- | Lift a pure bindings transformation that needs 'DynFlags' and a
-- 'UniqSupply' into a 'ModGuts' pass.
doPassDU :: (DynFlags -> UniqSupply -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDU do_pass = doPassDUM (\dflags us binds -> return (do_pass dflags us binds))
-- | Lift a pure bindings transformation that needs only a 'UniqSupply' into
-- a 'ModGuts' pass; the 'DynFlags' offered by 'doPassDU' are ignored.
doPassU :: (UniqSupply -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPassU do_pass = doPassDU (\_dflags -> do_pass)
-- | Lift an 'IO' bindings transformation that needs 'DynFlags' and the
-- family-instance environments into a 'ModGuts' pass.  The environments are
-- the package one paired with the module's own 'mg_fam_inst_env'.
doPassDFM :: (DynFlags -> FamInstEnvs -> CoreProgram -> IO CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDFM do_pass guts = do
    dflags    <- getDynFlags
    p_fam_env <- getPackageFamInstEnv
    let fam_envs       = (p_fam_env, mg_fam_inst_env guts)
        on_binds binds = liftIO (do_pass dflags fam_envs binds)
    doPassM on_binds guts
-- | Lift a pure bindings transformation that needs 'DynFlags', the
-- family-instance environments and a 'UniqSupply' into a 'ModGuts' pass.
doPassDFU :: (DynFlags -> FamInstEnvs -> UniqSupply -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDFU do_pass guts = do
    dflags    <- getDynFlags
    us        <- getUniqueSupplyM
    p_fam_env <- getPackageFamInstEnv
    let fam_envs  = (p_fam_env, mg_fam_inst_env guts)
        transform = do_pass dflags fam_envs us
    doPass transform guts
-- | Run a monadic transformation over 'mg_binds' and store the result back
-- into the guts; every other field is left untouched.
doPassM :: Monad m => (CoreProgram -> m CoreProgram) -> ModGuts -> m ModGuts
doPassM bind_f guts =
    bind_f (mg_binds guts) >>= \new_binds ->
      return (guts { mg_binds = new_binds })
-- | Apply a pure transformation to 'mg_binds', leaving the rest of the guts
-- untouched.
doPass :: (CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPass bind_f guts = return updated
  where
    updated = guts { mg_binds = bind_f (mg_binds guts) }
-- | Run an 'IO' action over the bindings purely for its effect: the action's
-- result is discarded and the bindings are put back unchanged.
observe :: (DynFlags -> CoreProgram -> IO a) -> ModGuts -> CoreM ModGuts
observe do_pass = doPassM inspect
  where
    inspect binds = do
      dflags <- getDynFlags
      liftIO (do_pass dflags binds) >> return binds
-- | Simplify a stand-alone expression.
--
-- The simplifier is run with an empty rule environment and empty
-- family-instance environments, using the environment built by
-- 'simplEnvForGHCi'.  Simplifier statistics and the simplified expression are
-- dumped when the corresponding dump flags are set.
simplifyExpr :: DynFlags
             -> CoreExpr
             -> IO CoreExpr
simplifyExpr dflags expr
  = withTiming dflags (text "Simplify [expr]") (const ()) $ do
      us <- mkSplitUniqSupply 's'
      (expr', counts) <- initSmpl dflags emptyRuleEnv emptyFamInstEnvs
                                  us (exprSize expr) simplify
      Err.dumpIfSet dflags (dopt Opt_D_dump_simpl_stats dflags)
        "Simplifier statistics" (pprSimplCount counts)
      Err.dumpIfSet_dyn dflags Opt_D_dump_simpl "Simplified expression"
        (pprCoreExpr expr')
      return expr'
  where
    simplify = simplExprGently (simplEnvForGHCi dflags) expr
-- | Simplify an expression twice, running occurrence analysis before each
-- simplifier run.
simplExprGently :: SimplEnv -> CoreExpr -> SimplM CoreExpr
simplExprGently env expr = one_round expr >>= one_round
  where
    one_round e = simplExpr env (occurAnalyseExpr e)
-- | 'CoreM' wrapper around 'simplifyPgmIO': collects the 'HscEnv', a
-- 'UniqSupply' and the rule base from the monad, then runs the 'IO'
-- simplifier via 'liftIOWithCount'.
simplifyPgm :: CoreToDo -> ModGuts -> CoreM ModGuts
simplifyPgm pass guts = do
    hsc_env <- getHscEnv
    us      <- getUniqueSupplyM
    rb      <- getRuleBase
    liftIOWithCount (simplifyPgmIO pass hsc_env us rb guts)
-- | Run the simplifier over a module, iterating until either an iteration
-- reports a zero 'SimplCount' or the iteration limit carried by the
-- 'CoreDoSimplify' pass is exceeded.
--
-- Returns the accumulated simplifier counts together with the updated guts
-- (bindings and rules replaced).  Panics if called with any pass other than
-- 'CoreDoSimplify'.
simplifyPgmIO :: CoreToDo
              -> HscEnv
              -> UniqSupply
              -> RuleBase
              -> ModGuts
              -> IO (SimplCount, ModGuts)
simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
              hsc_env us hpt_rule_base
              guts@(ModGuts { mg_module = this_mod
                            , mg_rdr_env = rdr_env
                            , mg_deps = deps
                            , mg_binds = binds, mg_rules = rules
                            , mg_fam_inst_env = fam_inst_env })
  = do { (termination_msg, it_count, counts_out, guts')
           <- do_iteration us 1 [] binds rules

        ; Err.dumpIfSet dflags (dopt Opt_D_verbose_core2core dflags &&
                                dopt Opt_D_dump_simpl_stats  dflags)
                  "Simplifier statistics for following pass"
                  (vcat [text termination_msg <+> text "after" <+> ppr it_count
                                              <+> text "iterations",
                         blankLine,
                         pprSimplCount counts_out])

        ; return (counts_out, guts')
    }
  where
    dflags       = hsc_dflags hsc_env
    print_unqual = mkPrintUnqualified dflags rdr_env
    simpl_env    = mkSimplEnv mode
    active_rule  = activeRule mode
    active_unf   = activeUnfolding mode

    do_iteration :: UniqSupply
                 -> Int          -- Number of the iteration about to begin;
                                 -- the first one is 1
                 -> [SimplCount] -- Counts from earlier iterations, most
                                 -- recent first
                 -> CoreProgram  -- Bindings to simplify
                 -> [CoreRule]   -- and the module's rules
                 -> IO (String, Int, SimplCount, ModGuts)

    do_iteration us iteration_no counts_so_far binds rules
      | iteration_no > max_iterations   -- Out of iterations: stop here
      = WARN( debugIsOn && (max_iterations > 2)
            , hang (text "Simplifier bailing out after" <+> int max_iterations
                    <+> text "iterations"
                    <+> (brackets $ hsep $ punctuate comma $
                         map (int . simplCountN) (reverse counts_so_far)))
                 2 (text "Size =" <+> ppr (coreBindsStats binds)))

        -- iteration_no names the iteration we were about to start, so the
        -- number actually completed is one less.
        return ( "Simplifier baled out", iteration_no - 1
               , totalise counts_so_far
               , guts { mg_binds = binds, mg_rules = rules } )

      -- Compute the size of the bindings and force it before going on.
      | let sz = coreBindsSize binds
      , () <- sz `seq` ()
      = do {
                -- Occurrence analysis
           let { tagged_binds =
                     occurAnalysePgm this_mod active_unf active_rule rules
                                     binds
               } ;
           Err.dumpIfSet_dyn dflags Opt_D_dump_occur_anal "Occurrence analysis"
                     (pprCoreBindings tagged_binds);

                -- Rule base for this iteration: home-package rules, the
                -- external package state's rules, and this module's rules.
           eps <- hscEPS hsc_env ;
           let  { rule_base1 = unionRuleBase hpt_rule_base (eps_rule_base eps)
                ; rule_base2 = extendRuleBaseList rule_base1 rules
                ; fam_envs = (eps_fam_inst_env eps, fam_inst_env)
                ; vis_orphs = this_mod : dep_orphs deps } ;

                -- Simplify the bindings, then the module's rules in the
                -- environment that results.
           ((binds1, rules1), counts1) <-
             initSmpl dflags (mkRuleEnv rule_base2 vis_orphs) fam_envs us1 sz $
               do { (floats, env1) <-
                                      simplTopBinds simpl_env tagged_binds
                  ; rules1 <- simplRules env1 Nothing rules Nothing
                  ; return (getTopFloatBinds floats, rules1) } ;

                -- Stop if this iteration did nothing; nothing is dumped and
                -- indirections are not shorted out in that case.
           if isZeroSimplCount counts1 then
                return ( "Simplifier reached fixed point", iteration_no
                       , totalise (counts1 : counts_so_far)
                       , guts { mg_binds = binds1, mg_rules = rules1 } )
           else do {
                -- Short out indirections
           let { binds2 = shortOutIndirections binds1 } ;

                -- Dump and lint the result of this iteration
           dump_end_iteration dflags print_unqual iteration_no counts1 binds2 rules1 ;
           lintPassResult hsc_env pass binds2 ;

                -- Loop, using the other half of the unique supply
           do_iteration us2 (iteration_no + 1) (counts1:counts_so_far) binds2 rules1
           } }
      | otherwise = panic "do_iteration"
      where
        (us1, us2) = splitUniqSupply us

        -- counts_so_far holds the most recent iteration first; folding from
        -- the right adds the counts up oldest-first.
        totalise :: [SimplCount] -> SimplCount
        totalise = foldr (\c acc -> acc `plusSimplCount` c)
                         (zeroSimplCount dflags)

simplifyPgmIO _ _ _ _ _ = panic "simplifyPgmIO"
-- | Dump the bindings, rules and simplifier counts at the end of one
-- simplifier iteration.  The dump flag handed to 'dumpPassResult' is
-- @Opt_D_dump_simpl_iterations@ when that flag is set and 'Nothing' otherwise.
dump_end_iteration :: DynFlags -> PrintUnqualified -> Int
                   -> SimplCount -> CoreProgram -> [CoreRule] -> IO ()
dump_end_iteration dflags print_unqual iteration_no counts binds rules
  = dumpPassResult dflags print_unqual mb_flag hdr pp_counts binds rules
  where
    mb_flag = if dopt Opt_D_dump_simpl_iterations dflags
                then Just Opt_D_dump_simpl_iterations
                else Nothing

    hdr = text "Simplifier iteration=" <> int iteration_no

    pp_counts = vcat
      [ text "---- Simplifier counts for" <+> hdr
      , pprSimplCount counts
      , text "---- End of simplifier counts for" <+> hdr
      ]
-- | Indirection environment.  It is keyed by a local 'Id' and maps it to the
-- exported 'Id' whose right-hand side is just that local 'Id' (possibly under
-- ticks), paired with the ticks stripped from that right-hand side.
type IndEnv = IdEnv (Id, [Tickish Var])
-- | Eliminate bindings of the form @exported = local@ found by 'makeIndEnv'.
--
-- The binding of each such exported 'Id' is dropped, and the binding of the
-- matching local 'Id' is replaced by two bindings: the exported 'Id' takes
-- over the local right-hand side (wrapped in the recorded ticks) and the
-- local 'Id' becomes an alias for the exported one.  If any of the exported
-- 'Id's involved carries rules, the whole program is flattened into a single
-- 'Rec' group.
shortOutIndirections :: CoreProgram -> CoreProgram
shortOutIndirections binds
  | isEmptyVarEnv indirections = binds
  | exported_have_no_rules     = zapped
  | otherwise                  = [Rec (flattenBinds zapped)]
  where
    indirections = makeIndEnv binds

    -- The exported Ids that are being rewired.
    exported     = map fst (nonDetEltsUFM indirections)
    exported_set = mkVarSet exported

    exported_have_no_rules =
      all (\exp_id -> null (ruleInfoRules (idSpecialisation exp_id))) exported

    zapped = concatMap zap_bind binds

    zap_bind bind = case bind of
      NonRec bndr rhs -> [ NonRec b r | (b, r) <- zap_pair (bndr, rhs) ]
      Rec pairs       -> [ Rec (concatMap zap_pair pairs) ]

    zap_pair (bndr, rhs)
      -- The old binding of an exported Id disappears ...
      | bndr `elemVarSet` exported_set
      = []

      -- ... and the local Id's binding becomes  exp_id = rhs; lcl_id = exp_id
      | Just (exp_id, ticks) <- lookupVarEnv indirections bndr
      , (exp_id', lcl_id') <- transferIdInfo exp_id bndr
      = [ (exp_id', mkTicks ticks rhs)
        , (lcl_id', Var exp_id') ]

      | otherwise
      = [(bndr, rhs)]
-- | Collect every binding whose right-hand side, once floatable ticks are
-- stripped from the top, is a bare variable that 'shortMeOut' accepts.  The
-- result maps that variable to the bound 'Id' and the stripped ticks.
makeIndEnv :: [CoreBind] -> IndEnv
makeIndEnv binds
  = foldl' add_bind emptyVarEnv binds
  where
    add_bind :: IndEnv -> CoreBind -> IndEnv
    add_bind env bind = case bind of
      NonRec exported_id rhs -> add_pair env (exported_id, rhs)
      Rec pairs              -> foldl' add_pair env pairs

    add_pair :: IndEnv -> (Id,CoreExpr) -> IndEnv
    add_pair env (exported_id, exported)
      = case stripTicksTop tickishFloatable exported of
          (ticks, Var local_id)
            | shortMeOut env exported_id local_id
            -> extendVarEnv env local_id (exported_id, ticks)
          _ -> env
-- | Decide whether the binding @exported_id = local_id@ may be shorted out.
--
-- That requires: @exported_id@ is exported; @local_id@ is a local 'Id' that
-- is not itself exported and is not already a key of the environment; and
-- @exported_id@ has shortable 'IdInfo'.  When only the last condition fails,
-- a warning is emitted through the @WARN@ macro.
shortMeOut :: IndEnv -> Id -> Id -> Bool
shortMeOut ind_env exported_id local_id
  | not is_candidate               = False
  | hasShortableIdInfo exported_id = True
  | otherwise
  = WARN( True, text "Not shorting out:" <+> ppr exported_id )
    False
  where
    is_candidate =
         isExportedId exported_id
      && isLocalId local_id
      && not (isExportedId local_id)
      && not (local_id `elemVarEnv` ind_env)
-- | True when the 'Id' has no rules, carries the default inline pragma, and
-- does not have a stable unfolding.
hasShortableIdInfo :: Id -> Bool
hasShortableIdInfo id
  | not (isEmptyRuleInfo (ruleInfo info))             = False
  | not (isDefaultInlinePragma (inlinePragInfo info)) = False
  | otherwise = not (isStableUnfolding (unfoldingInfo info))
  where
    info = idInfo id
-- | Move the interesting 'IdInfo' from the local 'Id' onto the exported one.
--
-- The exported 'Id' receives the local 'Id''s strictness, unfolding and
-- inline pragma, and its rules are extended with the local 'Id''s rules
-- (re-headed with the exported 'Id''s name).  The local 'Id' comes back with
-- its inline pragma reset to the default.
transferIdInfo :: Id -> Id -> (Id, Id)
transferIdInfo exported_id local_id
  = (exported_id', local_id')
  where
    exported_id' = modifyIdInfo transfer exported_id
    local_id'    = setInlinePragma local_id defaultInlinePragma

    local_info = idInfo local_id

    transfer exp_info =
      let with_strictness = setStrictnessInfo exp_info
                                              (strictnessInfo local_info)
          with_unfolding  = setUnfoldingInfo with_strictness
                                             (unfoldingInfo local_info)
          with_pragma     = setInlinePragInfo with_unfolding
                                              (inlinePragInfo local_info)
      in  setRuleInfo with_pragma
                      (addRuleInfo (ruleInfo exp_info) moved_rules)

    -- The local Id's rules, with their head name changed to the exported Id.
    moved_rules = setRuleInfoHead (idName exported_id)
                                  (ruleInfo local_info)