%
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
%
\section[FloatOut]{Float bindings outwards (towards the top level)}
``Long-distance'' floating of bindings towards the top level.
\begin{code}
module FloatOut ( floatOutwards ) where
import CoreSyn
import CoreUtils
import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
import ErrUtils ( dumpIfSet_dyn )
import CostCentre ( dupifyCC, CostCentre )
import Id ( Id, idType )
import Type ( isUnLiftedType )
import SetLevels ( Level(..), LevelledExpr, LevelledBind,
setLevels, isTopLvl, tOP_LEVEL )
import UniqSupply ( UniqSupply )
import Bag
import Util
import Maybes
import UniqFM
import Outputable
import FastString
\end{code}
Overall game plan
The Big Main Idea is:
To float out sub-expressions that can thereby get outside
a non-one-shot value lambda, and hence may be shared.
To achieve this we may need to do two things:
a) Let-bind the sub-expression:
f (g x) ==> let lvl = f (g x) in lvl
Now we can float the binding for 'lvl'.
b) More than that, we may need to abstract wrt a type variable
\x -> ... /\a -> let v = ...a... in ....
Here the binding for v mentions 'a' but not 'x'. So we
abstract wrt 'a', to give this binding for 'v':
vp = /\a -> ...a...
v = vp a
Now the binding for vp can float out unimpeded.
I can't remember why this case seemed important enough to
deal with, but I certainly found cases where important floats
didn't happen if we did not abstract wrt tyvars.
With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level. Whether or not this
happens is controlled by a flag.
Random comments
~~~~~~~~~~~~~~~
At the moment we never float a binding out to between two adjacent
lambdas. For example:
@
\x y -> let t = x+x in ...
===>
\x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.
But there's a case I've seen where this might not be true. Consider:
@
elEm2 x ys
= elem' x ys
where
elem' _ [] = False
elem' x (y:ys) = x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
\deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated to
@
\deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe. We don't do this at the moment.
%************************************************************************
%* *
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
%* *
%************************************************************************
\begin{code}
-- Entry point of the float-out pass.
--
-- The program is first annotated with levels by 'setLevels'; each top-level
-- binding is then floated independently and the resulting groups of
-- bindings are concatenated.  Two optional debug dumps are produced: the
-- level-annotated program, and a one-line summary of the float statistics.
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> [CoreBind] -> IO [CoreBind]
floatOutwards float_sws dflags us pgm = do
    dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                  (vcat (map ppr levelled_pgm))
    dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:" stats_doc
    return (concat floated_binds_s)
  where
    -- The program with a destination level attached to every binder
    levelled_pgm = setLevels float_sws pgm us

    -- One (statistics, floated bindings) pair per top-level binding
    (stats_s, floated_binds_s) = unzip (map floatTopBind levelled_pgm)

    (n_top_lets, n_other_lets, n_lam_groups) = get_stats (sum_stats stats_s)

    stats_doc = hcat
        [ int n_top_lets,   ptext (sLit " Lets floated to top level; ")
        , int n_other_lets, ptext (sLit " Lets floated elsewhere; from ")
        , int n_lam_groups, ptext (sLit " Lambda groups")
        ]
-- Float one top-level binding.  Everything that floats out of it (together
-- with the binding itself) is flattened into a plain list of bindings.
floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
floatTopBind bind =
    case floatBind bind of
      (stats, floats) -> (stats, bagToList (flattenFloats floats))
\end{code}
%************************************************************************
%* *
\subsection[FloatOutBind]{Floating in a binding (the business end)}
%* *
%************************************************************************
\begin{code}
-- Float a single (levelled) binding, returning the statistics gathered and
-- the floating bindings, which include the binding itself.
floatBind :: LevelledBind -> (FloatStats, FloatBinds)

-- Non-recursive: float the right-hand side, then add the binding itself
-- as a float at its own level.
floatBind (NonRec (TB name level) rhs) =
    case floatRhs level rhs of
      (stats, floats, rhs') ->
        (stats, floats `plusFloats` unitFloat level (NonRec name rhs'))

-- Recursive: float every right-hand side, then build one Rec group at the
-- destination level of the group (the level of its first binder).
floatBind bind@(Rec pairs) =
    case unzip3 (map float_pair pairs) of
      (stats_s, floats_s, pairs')
        | isTopLvl dest_lvl ->
            -- At top level everything that floated out of the right-hand
            -- sides is absorbed into the Rec group.
            (sum_stats stats_s,
             unitFloat tOP_LEVEL
                       (Rec (floatsToBindPairs (flattenFloats all_floats) pairs')))
        | otherwise ->
            -- Otherwise only the floats that stop here join the group;
            -- the rest keep floating outwards.
            case partitionByMajorLevel dest_lvl all_floats of
              (outer_floats, here_binds) ->
                (sum_stats stats_s,
                 outer_floats `plusFloats`
                     unitFloat dest_lvl
                               (Rec (floatsToBindPairs here_binds pairs')))
        where
          all_floats = foldr1 plusFloats floats_s
  where
    dest_lvl = getBindLevel bind

    float_pair (TB name level, rhs) =
        case floatRhs level rhs of
          (stats, floats, rhs') -> (stats, floats, (name, rhs'))
\end{code}
%************************************************************************
\subsection[FloatOutExpr]{Floating in expressions}
%* *
%************************************************************************
\begin{code}
floatExpr, floatRhs, floatCaseAlt
         :: Level
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)

-- Float a case alternative (or similar body) at level 'lvl'.  The floats
-- that 'partitionByMajorLevel' keeps at 'lvl' are installed as lets around
-- the result; only the remaining floats are passed outwards.
floatCaseAlt lvl arg =
    case floatExpr lvl arg of
      (stats, floats, arg') ->
        case partitionByMajorLevel lvl floats of
          (outer_floats, here_binds) ->
            (stats, outer_floats, install here_binds arg')
-- Float a right-hand side (or function argument) at level 'lvl'.
-- If the floated result is cheap (per 'exprIsCheap') all floats are passed
-- outwards untouched; otherwise the floats belonging to 'lvl' are installed
-- around the result, exactly as 'floatCaseAlt' does.
floatRhs lvl arg =
    case floatExpr lvl arg of
      (stats, floats, arg')
        | exprIsCheap arg' -> (stats, floats, arg')
        | otherwise ->
            case partitionByMajorLevel lvl floats of
              (outer_floats, here_binds) ->
                (stats, outer_floats, install here_binds arg')
-- Float bindings out of a levelled expression.
--
-- The first argument is the level of the enclosing context.  The result is
--   * the statistics gathered while traversing the expression,
--   * the bindings that are still floating outwards, and
--   * the expression with levels stripped and those bindings removed.

-- Atoms: nothing to float.
floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)

-- Application: the argument is treated as a right-hand side.
floatExpr lvl (App e a)
  = case (floatExpr lvl e) of { (fse, floats_e, e') ->
    case (floatRhs lvl a)  of { (fsa, floats_a, a') ->
    (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}

-- Lambda group: the body is floated at the level of the innermost binder,
-- and the floats are partitioned at the level of the outermost binder.
-- A group consisting only of type variables uses the finer-grained
-- partitionByLevel; otherwise partitionByMajorLevel is used.
floatExpr _ lam@(Lam _ _)
  = let
        (bndrs_w_lvls, body) = collectBinders lam
        bndrs                = [b | TB b _ <- bndrs_w_lvls]
        lvls                 = [l | TB _ l <- bndrs_w_lvls]
        -- lvls is non-empty here because lam is a Lam, so head/last are safe

        partition_fn | all isTyVar bndrs = partitionByLevel
                     | otherwise         = partitionByMajorLevel
    in
    case (floatExpr (last lvls) body)      of { (fs, floats, body') ->
    case (partition_fn (head lvls) floats) of { (floats', heres) ->
    (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
    }}

-- SCC note: every definition floating out of the scope of the cost centre
-- has its right-hand side wrapped with a dupified copy of that cost centre.
floatExpr lvl (Note note@(SCC cc) expr)
  = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
    let
        annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
    in
    (fs, annotated_defns, Note note expr') }

-- InlineMe note: nothing is floated out; the body only has its level
-- annotations removed.
floatExpr _ (Note InlineMe expr)
  = (zeroStats, emptyFloats, Note InlineMe (unTag expr))

floatExpr lvl (Note note expr)
  = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Note note expr') }

floatExpr lvl (Cast expr co)
  = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Cast expr' co) }

-- Unlifted let: the binding itself stays in place.  The right-hand side is
-- floated with floatExpr and the body with floatCaseAlt.
-- The statistics of BOTH sub-expressions are reported; previously those of
-- the right-hand side were silently dropped, so the "FloatOut stats" dump
-- undercounted.
floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
  | isUnLiftedType (idType bndr)
  = case floatExpr lvl rhs          of { (fs_rhs, rhs_floats, rhs') ->
    case floatCaseAlt bndr_lvl body of { (fs_body, body_floats, body') ->
    (fs_rhs `add_stats` fs_body,
     rhs_floats `plusFloats` body_floats,
     Let (NonRec bndr rhs') body') }}

-- Ordinary let: the binding is removed from the expression and joins the
-- floating bindings; only the (floated) body is returned.
floatExpr lvl (Let bind body)
  = case (floatBind bind)      of { (fsb, bind_floats) ->
    case (floatExpr lvl body)  of { (fse, body_floats, body') ->
    (add_stats fsb fse,
     bind_floats `plusFloats` body_floats,
     body') }}

-- Case: the scrutinee is floated at the current level, each alternative at
-- the level attached to the case binder.
floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
  = case floatExpr lvl scrut      of { (fse, fde, scrut') ->
    case floatList float_alt alts of { (fsa, fda, alts') ->
    (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
    }}
  where
    float_alt (con, bs, rhs)
      = case (floatCaseAlt case_lvl rhs) of { (fs, rhs_floats, rhs') ->
        (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
-- Apply a floating function to every element of a list, summing the
-- statistics and combining the floats of all elements.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f = foldr float_one (zeroStats, emptyFloats, [])
  where
    -- The element is processed before the (lazily passed) result for the
    -- remainder of the list is inspected.
    float_one x rest =
        case f x of
          (fs_x, binds_x, y) ->
            case rest of
              (fs_rest, binds_rest, ys) ->
                (fs_x `add_stats` fs_rest, binds_x `plusFloats` binds_rest, y : ys)
-- The level attached to a binding: that of the binder for a NonRec, and
-- that of the first binder for a Rec group.  Panics on an empty Rec.
getBindLevel :: Bind (TaggedBndr Level) -> Level
getBindLevel bind =
    case bind of
      NonRec (TB _ lvl) _     -> lvl
      Rec ((TB _ lvl, _) : _) -> lvl
      Rec []                  -> panic "getBindLevel Rec []"
-- Discard the tag of a tagged binder, keeping only the binder.
unTagBndr :: TaggedBndr tag -> CoreBndr
unTagBndr tagged = case tagged of
                     TB bndr _ -> bndr
-- Strip the tags from every binder in a tagged expression, leaving the
-- structure of the expression unchanged.
unTag :: TaggedExpr tag -> CoreExpr
unTag expr =
    case expr of
      Var v                   -> Var v
      Lit l                   -> Lit l
      Type ty                 -> Type ty
      Note n e                -> Note n (unTag e)
      App fun arg             -> App (unTag fun) (unTag arg)
      Lam b e                 -> Lam (unTagBndr b) (unTag e)
      Cast e co               -> Cast (unTag e) co
      Let (Rec prs) body      -> Let (Rec (map untag_pair prs)) (unTag body)
      Let (NonRec b r) body   -> Let (NonRec (unTagBndr b) (unTag r)) (unTag body)
      Case scrut b ty alts    -> Case (unTag scrut) (unTagBndr b) ty
                                      (map untag_alt alts)
  where
    untag_pair (b, r)        = (unTagBndr b, unTag r)
    untag_alt  (con, bs, r)  = (con, map unTagBndr bs, unTag r)
\end{code}
%************************************************************************
%* *
\subsection{Utility bits for floating stats}
%* *
%************************************************************************
I didn't implement this with unboxed numbers. I don't want to be too
strict in this stuff, as it is rarely turned on. (WDP 95/09)
\begin{code}
-- Counters reported in the "FloatOut stats" dump.  The fields are
-- deliberately lazy: the statistics are rarely asked for.
data FloatStats
  = FlS Int  -- Lets floated to top level
        Int  -- Lets floated elsewhere
        Int  -- Lambda groups seen

-- Unpack the three counters, in declaration order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats stats = case stats of
                    FlS tops others lams -> (tops, others, lams)

-- All counters zero.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- Field-wise sum of a list of statistics.
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats

-- Field-wise sum of two statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS tops_l others_l lams_l) (FlS tops_r others_r lams_r)
  = FlS tops others lams
  where
    tops   = tops_l   + tops_r
    others = others_l + others_r
    lams   = lams_l   + lams_r
-- Record one lambda group, together with the bindings that floated past
-- it: top-level floats and all other floats are counted separately.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS n_top n_other n_lams) (FB tops others)
  = FlS n_top' n_other' (n_lams + 1)
  where
    n_top'   = n_top   + lengthBag tops
    n_other' = n_other + lengthBag (flattenMajor others)
\end{code}
%************************************************************************
%* *
\subsection{Utility bits for floating}
%* *
%************************************************************************
Note [Representation of FloatBinds]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The FloatBinds type is somewhat important. We can get very large numbers
of floating bindings, often all destined for the top level. A typical example
is x = [4,2,5,2,5, .... ]
Then we get lots of small expressions like (fromInteger 4), which all get
lifted to top level.
The trouble is that
(a) we partition these floating bindings *at every binding site*
(b) SetLevels introduces a new binding site for every float
So we had better not look at each binding at each binding site!
That is why MajorEnv is represented as a finite map.
We keep the bindings destined for the *top* level separate, because
we float them out even if they don't escape a *value* lambda; see
partitionByMajorLevel.
\begin{code}
-- A floating binding is an ordinary Core binding.
type FloatBind = CoreBind

-- The collection of bindings currently floating outwards.
data FloatBinds  = FB !(Bag FloatBind)  -- Bindings destined for the top level
                      !MajorEnv         -- All other floating bindings

-- Non-top-level floats are indexed first by the major and then by the
-- minor component of their destination level (this is how unitFloat
-- builds the maps).
type MajorEnv = UniqFM MinorEnv
type MinorEnv = UniqFM (Bag FloatBind)
-- All floating bindings in one bag: the top-level ones first, followed by
-- the contents of the level-indexed environment.
flattenFloats :: FloatBinds -> Bag FloatBind
flattenFloats floats =
    case floats of
      FB tops others -> unionBags tops (flattenMajor others)
-- Collapse a major-level environment into a single bag of bindings.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor major_env = foldUFM add_minor emptyBag major_env
  where
    add_minor minor_env acc = flattenMinor minor_env `unionBags` acc
-- Collapse a minor-level environment into a single bag of bindings.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor minor_env = foldUFM unionBags emptyBag minor_env
-- No floating bindings at all.
emptyFloats :: FloatBinds
emptyFloats = FB no_tops no_others
  where
    no_tops   = emptyBag
    no_others = emptyUFM
-- A single floating binding with the given destination level.  Bindings
-- for InlineCtxt or for the top level go into the top-level bag; all
-- others are filed under their major and minor level.
unitFloat :: Level -> FloatBind -> FloatBinds
unitFloat lvl b =
    case lvl of
      InlineCtxt        -> top_float
      Level major minor
        | isTopLvl lvl  -> top_float
        | otherwise     -> FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
  where
    top_float = FB (unitBag b) emptyUFM
-- Combine two collections of floating bindings component-wise.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB tops_l others_l) (FB tops_r others_r)
  = FB all_tops all_others
  where
    all_tops   = unionBags tops_l tops_r
    all_others = plusMajor others_l others_r
-- Union of two major-level environments; entries under the same major
-- level are merged with plusMinor.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor env_l env_r = plusUFM_C plusMinor env_l env_r
-- Union of two minor-level environments; bags under the same minor level
-- are merged with unionBags.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor env_l env_r = plusUFM_C unionBags env_l env_r
-- Prepend the (binder, rhs) pairs of every floating binding to the given
-- list of pairs; Rec groups contribute all of their pairs.
floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
floatsToBindPairs floats binds = foldrBag prepend binds floats
  where
    prepend float rest =
        case float of
          Rec pairs         -> pairs ++ rest
          NonRec binder rhs -> (binder, rhs) : rest
-- Wrap an expression in one Let per floating binding (right fold over
-- the bag, so the first binding in the bag ends up outermost).
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr = foldrBag Let expr defn_groups
-- Both partitioning functions divide the floats into two piles:
-- the bindings that keep floating outwards, and a bag of bindings that
-- stop at the given level.  Top-level floats always keep floating.
partitionByMajorLevel, partitionByLevel
        :: Level
        -> FloatBinds
        -> (FloatBinds,
            Bag FloatBind)

-- partitionByMajorLevel compares major levels only.
-- NOTE(review): splitUFM presumably returns the entries with keys below,
-- equal to and above the given key, in that order -- confirm in UniqFM.
partitionByMajorLevel lvl (FB tops defns) =
    case lvl of
      InlineCtxt ->
        -- Every non-top-level float stops here
        (FB tops emptyUFM, flattenMajor defns)
      Level major _ ->
        (FB tops outer, heres `unionBags` flattenMajor inner)
        where
          (outer, mb_heres, inner) = splitUFM defns major
          heres                    = maybe emptyBag flattenMinor mb_heres
-- partitionByLevel compares both the major and the minor level: within the
-- entry for the given major level, the minor environment is split again.
-- NOTE(review): relies on splitUFM returning (below, at, above) the key --
-- confirm in UniqFM.
partitionByLevel lvl (FB tops defns) =
    case lvl of
      InlineCtxt ->
        -- Every non-top-level float stops here
        (FB tops emptyUFM, flattenMajor defns)
      Level major minor ->
        (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
         here_min `unionBags` flattenMinor inner_min
                  `unionBags` flattenMajor inner_maj)
        where
          (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major

          (outer_min, mb_here_min, inner_min) =
              maybe (emptyUFM, Nothing, emptyUFM)
                    (\min_defns -> splitUFM min_defns minor)
                    mb_here_maj

          here_min = mb_here_min `orElse` emptyBag
-- Wrap the right-hand side of every floating binding (top-level or not)
-- in the given cost centre, using mkSCC.
wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
wrapCostCentre cc (FB tops defns)
  = FB (mapBag wrap_bind tops)
       (mapUFM (mapUFM (mapBag wrap_bind)) defns)
  where
    add_scc = mkSCC cc

    wrap_bind bind =
        case bind of
          NonRec binder rhs -> NonRec binder (add_scc rhs)
          Rec pairs         -> Rec (mapSnd add_scc pairs)
\end{code}