module FloatOut ( floatOutwards ) where
import CoreSyn
import CoreUtils
import MkCore
import CoreArity ( etaExpand )
import CoreMonad ( FloatOutSwitches(..) )
import DynFlags
import ErrUtils ( dumpIfSet_dyn )
import Id ( Id, idArity, isBottomingId )
import Var ( Var )
import SetLevels
import UniqSupply ( UniqSupply )
import Bag
import Util
import Maybes
import Outputable
import FastString
import qualified Data.IntMap as M
#include "HsVersions.h"
-- | Entry point of the float-out pass.
--
-- Annotates the program with levels ('setLevels'), floats every top-level
-- binding, optionally dumps the levelled program and the collected
-- statistics, and returns the resulting list of bindings.
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
floatOutwards float_sws dflags us pgm = do
    let levelled_pgm          = setLevels float_sws pgm us
        (stats_s, binds_s)    = unzip (map floatTopBind levelled_pgm)
        -- Lazily bound: only evaluated if the stats dump is enabled.
        (tlets, ntlets, lams) = get_stats (sum_stats stats_s)
    dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                  (vcat (map ppr levelled_pgm))
    dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
                  (hcat [ int tlets,  ptext (sLit " Lets floated to top level; ")
                        , int ntlets, ptext (sLit " Lets floated elsewhere; from ")
                        , int lams,   ptext (sLit " Lambda groups") ])
    return (bagToList (unionManyBags binds_s))
-- | Float out of a single top-level binding.
--
-- 'flattenTopFloats' asserts that every float produced here is destined
-- for top level.  For a recursive group those floats are merged into the
-- group itself; for a non-recursive binding they are emitted in front of it.
floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
floatTopBind bind
  = case floatBind bind of
      (stats, floats, Rec prs) ->
          (stats, unitBag (Rec (addTopFloatPairs (flattenTopFloats floats) prs)))
      (stats, floats, nonrec@NonRec {}) ->
          (stats, flattenTopFloats floats `snocBag` nonrec)
-- | Float out of the right-hand side(s) of a binding.
--
-- Returns the statistics, the floats escaping the binding, and the
-- rebuilt binding.
floatBind :: LevelledBind -> (FloatStats, FloatBinds, CoreBind)
floatBind (NonRec (TB var _) rhs)
  = case floatExpr rhs of
      (fs, rhs_floats, rhs') -> (fs, rhs_floats, NonRec var (eta_if_bottoming rhs'))
  where
    -- A bottoming binder's residual RHS is eta-expanded to the binder's arity.
    eta_if_bottoming e
      | isBottomingId var = etaExpand (idArity var) e
      | otherwise         = e

floatBind (Rec pairs)
  = case floatList float_pair pairs of
      (fs, rhs_floats, new_pairs) -> (fs, rhs_floats, Rec (concat new_pairs))
  where
    -- Each pair may expand into several pairs: let-floats that stop at the
    -- pair's own level join the recursive group.
    float_pair (TB name spec, rhs)
      | isTopLvl dest_lvl
      -- Top-level group: everything floated from the RHS joins the group.
      = case floatExpr rhs of
          (fs, rhs_floats, rhs') ->
            (fs, emptyFloats, addTopFloatPairs (flattenTopFloats rhs_floats) [(name, rhs')])
      | otherwise
      -- Floats not strictly below the pair's level stay here: leading let
      -- floats become extra group members, the rest go under the lambdas.
      = case floatExpr rhs of
          (fs, rhs_floats, rhs') ->
            case partitionByLevel dest_lvl rhs_floats of
              (escaping, heres) ->
                case splitRecFloats heres of
                  (let_pairs, case_heres) ->
                    (fs, escaping, (name, installUnderLambdas case_heres rhs') : let_pairs)
      where
        dest_lvl = floatSpecLevel spec
-- | Split a bag of floats into the pairs of its leading let floats and the
-- remainder, which starts at the first non-let float.
--
-- The pairs are accumulated in reverse float order; the pairs of a single
-- recursive let keep their internal order.
splitRecFloats :: Bag FloatBind -> ([(Id,CoreExpr)], Bag FloatBind)
splitRecFloats floats
  = collect [] (bagToList floats)
  where
    collect acc (FloatLet (NonRec b r) : rest) = collect ((b, r) : acc) rest
    collect acc (FloatLet (Rec prs)    : rest) = collect (prs ++ acc) rest
    collect acc rest                           = (acc, listToBag rest)
-- | Install the given floats around the body found beneath all leading
-- lambdas of an expression.  The expression is returned unchanged when
-- there is nothing to install.
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = mkLams bndrs (install floats body)
  where
    (bndrs, body) = collectBinders e
-- | Apply a floating function to every element of a list, summing the
-- statistics and concatenating the floats in list order.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f = foldr step (zeroStats, emptyFloats, [])
  where
    -- Force the element's result before the rest of the list's result.
    step a rest
      = case f a of
          (fs_a, binds_a, b) ->
            case rest of
              (fs_rest, binds_rest, bs) ->
                (fs_a `add_stats` fs_rest, binds_a `plusFloats` binds_rest, b : bs)
-- | Float out of an expression that sits directly under a binding site at
-- level @lvl@ (recursive RHSs, lambda bodies, case alternatives).
--
-- Floats strictly below @lvl@ escape; the rest are installed around the
-- residual expression.
floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatBody lvl arg
  = case floatExpr arg of
      (stats, floats, arg') ->
        case partitionByLevel lvl floats of
          (escaping, heres) -> (stats, escaping, install heres arg')
-- | Float bindings outwards from a levelled expression.
--
-- Returns the statistics gathered, the floats that escape this expression
-- (installed further out, e.g. by 'floatBody' or 'floatTopBind'), and the
-- residual expression with those floats removed.
floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
-- Leaves: nothing to float.
floatExpr (Var v)   = (zeroStats, emptyFloats, Var v)
floatExpr (Type ty) = (zeroStats, emptyFloats, Type ty)
floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co)
floatExpr (Lit lit) = (zeroStats, emptyFloats, Lit lit)

-- Applications: the function's floats are placed before the argument's.
floatExpr (App e a)
  = case (floatExpr e) of { (fse, floats_e, e') ->
    case (floatExpr a) of { (fsa, floats_a, a') ->
    (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}

-- Lambdas: the whole group of leading binders is handled at once.  Floats
-- whose level is not strictly below the group's level are installed around
-- the body (inside the lambdas); the others escape and are counted by
-- 'add_to_stats'.
floatExpr lam@(Lam (TB _ lam_spec) _)
  = let (bndrs_w_lvls, body) = collectBinders lam
        bndrs                = [b | TB b _ <- bndrs_w_lvls]
        -- Only the first binder's spec is consulted.  NOTE(review): this
        -- assumes every binder of the group carries the same level.
        bndr_lvl             = floatSpecLevel lam_spec
    in
    case (floatBody bndr_lvl body) of { (fs, floats, body') ->
    (add_to_stats fs floats, floats, mkLams bndrs body') }

-- Ticks: soft-scoped ticks let floats pass unchanged.  Non-counting or
-- splittable ticks let floats pass, wrapping each escaping float with
-- @mkNoCount tickish@ via 'wrapTick'.  Any other tick causes a panic.
floatExpr (Tick tickish expr)
  | tickish `tickishScopesLike` SoftScope
  = case (floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Tick tickish expr') }

  | not (tickishCounts tickish) || tickishCanSplit tickish
  = case (floatExpr expr) of { (fs, floating_defns, expr') ->
    let
        annotated_defns = wrapTick (mkNoCount tickish) floating_defns
    in
    (fs, annotated_defns, Tick tickish expr') }

  | otherwise
  = pprPanic "floatExpr tick" (ppr tickish)

-- Casts: floats pass through unchanged.
floatExpr (Cast expr co)
  = case (floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Cast expr' co) }

floatExpr (Let bind body)
  = case bind_spec of
      -- The binding itself becomes a float at dest_lvl, ordered after the
      -- floats from its RHS and before those from the body; the body
      -- replaces the whole let.
      FloatMe dest_lvl
        -> case (floatBind bind) of { (fsb, bind_floats, bind') ->
           case (floatExpr body) of { (fse, body_floats, body') ->
           ( add_stats fsb fse
           , bind_floats `plusFloats` unitLetFloat dest_lvl bind'
                         `plusFloats` body_floats
           , body') }}

      -- The let stays; body floats not strictly below bind_lvl are
      -- installed around the body (see 'floatBody').
      StayPut bind_lvl
        -> case (floatBind bind) of { (fsb, bind_floats, bind') ->
           case (floatBody bind_lvl body) of { (fse, body_floats, body') ->
           ( add_stats fsb fse
           , bind_floats `plusFloats` body_floats
           , Let bind' body') }}
  where
    -- The spec of the (first) binder decides for the whole binding.
    bind_spec = case bind of
                 NonRec (TB _ s) _     -> s
                 Rec ((TB _ s, _) : _) -> s
                 Rec []                -> panic "floatExpr:rec"

floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      -- Only a single DataAlt alternative can be floated: the case becomes
      -- a case float at dest_lvl and the alternative's RHS replaces it.
      -- Float order: scrutinee floats, the case float, then RHS floats.
      FloatMe dest_lvl
        | [(con@(DataAlt {}), bndrs, rhs)] <- alts
        -> case floatExpr scrut of { (fse, fde, scrut') ->
           case floatExpr rhs of { (fsb, fdb, rhs') ->
           let
             float = unitCaseFloat dest_lvl scrut'
                          case_bndr con [b | TB b _ <- bndrs]
           in
           (add_stats fse fsb, fde `plusFloats` float `plusFloats` fdb, rhs') }}
        | otherwise
        -> pprPanic "Floating multi-case" (ppr alts)

      -- The case stays; each alternative is processed by 'floatBody' at
      -- bind_lvl.  Note that alternative floats (fda) are combined before
      -- the scrutinee floats (fde).
      StayPut bind_lvl
        -> case floatExpr scrut of { (fse, fde, scrut') ->
           case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts') ->
           (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
           }}
  where
    float_alt bind_lvl (con, bs, rhs)
        = case (floatBody bind_lvl rhs) of { (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
-- | Counters reported by the "FloatOut stats:" dump of 'floatOutwards'.
-- 'add_to_stats' bumps them once per lambda group that floats escape, so a
-- float escaping several lambda groups is counted several times.  The
-- fields are lazy, so the counts are only computed when demanded.
data FloatStats
  = FlS Int  -- ^ top-level floats, counted per lambda group escaped
        Int  -- ^ non-top-level floats, counted per lambda group escaped
        Int  -- ^ number of lambda groups seen
-- | Unpack the three counters of a 'FloatStats'.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats stats = case stats of
  FlS top_floats other_floats lam_groups -> (top_floats, other_floats, lam_groups)
-- | Statistics with every counter at zero.
zeroStats :: FloatStats
zeroStats = FlS none none none
  where
    none = 0
-- | Component-wise sum of a list of statistics.
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats
-- | Component-wise sum of two statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats s1 s2 = case (s1, s2) of
  (FlS tops1 others1 lams1, FlS tops2 others2 lams2) ->
    FlS (tops1 + tops2) (others1 + others2) (lams1 + lams2)
-- | Record one more lambda group, together with the floats escaping it.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS tops others lams) (FB top_floats non_top_floats)
  = FlS (tops + n_top) (others + n_other) (lams + 1)
  where
    n_top   = lengthBag top_floats
    n_other = lengthBag (flattenMajor non_top_floats)
-- | A floated let binding.
type FloatLet = CoreBind
-- | Non-top-level floats, keyed by the major component of their
-- destination 'Level'.
type MajorEnv = M.IntMap MinorEnv
-- | Floats sharing a major level, keyed by the minor component of their
-- destination 'Level'.
type MinorEnv = M.IntMap (Bag FloatBind)
-- | Floats collected while traversing an expression.
data FloatBinds = FB !(Bag FloatLet)  -- ^ let floats whose destination is top level
                     !MajorEnv        -- ^ every other float (including all case
                                      --   floats), grouped by destination level
-- | Debug rendering: @FB {tops = ..., non-tops = ...}@.
instance Outputable FloatBinds where
  ppr (FB top_floats other_floats)
    = ptext (sLit "FB") <+> braces contents
    where
      contents = vcat [ ptext (sLit "tops =")     <+> ppr top_floats
                      , ptext (sLit "non-tops =") <+> ppr other_floats ]
-- | Extract the top-level let floats.  The ASSERT2 check (from
-- HsVersions.h) requires that no non-top-level floats remain.
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB top_binds nested)
  = ASSERT2( isEmptyBag (flattenMajor nested), ppr nested )
    top_binds
-- | Prepend the (binder, rhs) pairs of every binding in the bag, in bag
-- order, to an existing list of pairs.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs
  = concatMap bind_pairs (bagToList float_bag) ++ prs
  where
    bind_pairs (NonRec b r) = [(b, r)]
    bind_pairs (Rec prs1)   = prs1
-- | Collect every float held in a 'MajorEnv' into a single bag.
--
-- Uses 'M.foldr' rather than 'M.fold', which is deprecated and removed in
-- newer releases of containers; the two have the same semantics.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor = M.foldr (unionBags . flattenMinor) emptyBag
-- | Collect every float held in a 'MinorEnv' into a single bag.
--
-- Uses 'M.foldr' rather than the deprecated 'M.fold' (same semantics).
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = M.foldr unionBags emptyBag
-- | No floats at all.
emptyFloats :: FloatBinds
emptyFloats = FB no_top_floats no_other_floats
  where
    no_top_floats   = emptyBag
    no_other_floats = M.empty
-- | A single case float destined for the given level.  Case floats are
-- always stored in the non-top-level map.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor) scrut bndr con bndrs
  = FB emptyBag (M.singleton major minor_env)
  where
    minor_env = M.singleton minor (unitBag (FloatCase scrut bndr con bndrs))
-- | A single let float destined for the given level: top-level floats go to
-- the top-level bag, all others into the level-keyed map.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor) bind
  = if isTopLvl lvl
      then FB (unitBag bind) M.empty
      else FB emptyBag (M.singleton major minor_env)
  where
    minor_env = M.singleton minor (unitBag (FloatLet bind))
-- | Combine two sets of floats; the left operand's floats come first.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB tops1 env1) (FB tops2 env2)
  = FB tops env
  where
    tops = tops1 `unionBags` tops2
    env  = env1 `plusMajor` env2
-- | Merge two major-level maps, merging minor maps that share a key.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor env1 env2 = M.unionWith plusMinor env1 env2
-- | Merge two minor-level maps, concatenating bags that share a key.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor env1 env2 = M.unionWith unionBags env1 env2
-- | Wrap an expression in a bag of floats using 'wrapFloat'.  Because of the
-- right fold, the first float of the bag is applied outermost.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install floats body = foldrBag wrapFloat body floats
-- | Split floats at a given partitioning level.
--
-- The first component keeps all top-level floats plus every other float
-- whose (major, minor) level is strictly below the partitioning level;
-- these keep floating outwards.  The second component holds the rest, in
-- the order: exactly this level, same major with deeper minor, deeper major.
partitionByLevel
  :: Level                -- ^ partitioning level
  -> FloatBinds           -- ^ floats to divide
  -> (FloatBinds,         -- floats strictly below the partitioning level
      Bag FloatBind)      -- floats to install at this level
partitionByLevel (Level major minor) (FB tops defns)
  = (FB tops outer, heres)
  where
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
    (outer_min, mb_here_min, inner_min) = case mb_here_maj of
        Nothing        -> (M.empty, Nothing, M.empty)
        Just min_defns -> M.splitLookup minor min_defns
    here_min = mb_here_min `orElse` emptyBag

    -- Every key of outer_maj is strictly below major, so inserting cannot
    -- clash.  An empty minor map is not recorded, so no empty entries are
    -- added to the MajorEnv.
    outer | M.null outer_min = outer_maj
          | otherwise        = M.insert major outer_min outer_maj

    heres = here_min `unionBags` flattenMinor inner_min
                     `unionBags` flattenMajor inner_maj
-- | Wrap a tick around every floated binding RHS and case-float scrutinee.
-- Expressions satisfying 'exprIsHNF' get 'tickHNFArgs' instead of a
-- plain 'mkTick'.
wrapTick :: Tickish Id -> FloatBinds -> FloatBinds
wrapTick t (FB tops defns)
  = FB (mapBag tick_bind tops) (M.map (M.map (mapBag tick_float)) defns)
  where
    tick_float (FloatLet bind)      = FloatLet (tick_bind bind)
    tick_float (FloatCase e b c bs) = FloatCase (tick_expr e) b c bs

    tick_bind (NonRec bndr rhs) = NonRec bndr (tick_expr rhs)
    tick_bind (Rec prs)         = Rec [ (bndr, tick_expr rhs) | (bndr, rhs) <- prs ]

    tick_expr e
      | exprIsHNF e = tickHNFArgs t e
      | otherwise   = mkTick t e