%
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
%
\section[FloatOut]{Float bindings outwards (towards the top level)}
``Long-distance'' floating of bindings towards the top level.
\begin{code}
module FloatOut ( floatOutwards ) where
import CoreSyn
import CoreUtils
import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
import ErrUtils ( dumpIfSet_dyn )
import CostCentre ( dupifyCC, CostCentre )
import Id ( Id, idType )
import Type ( isUnLiftedType )
import SetLevels ( Level(..), LevelledExpr, LevelledBind,
setLevels, ltMajLvl, ltLvl, isTopLvl )
import UniqSupply ( UniqSupply )
import Data.List
import Outputable
import FastString
\end{code}
Overall game plan
The Big Main Idea is:
To float out subexpressions that can thereby get outside
a non-one-shot value lambda, and hence may be shared.
To achieve this we may need to do two things:
a) Let-bind the subexpression:
f (g x) ==> let lvl = f (g x) in lvl
Now we can float the binding for 'lvl'.
b) More than that, we may need to abstract wrt a type variable
\x -> ... /\a -> let v = ...a... in ....
Here the binding for v mentions 'a' but not 'x'. So we
abstract wrt 'a', to give this binding for 'v':
vp = /\a -> ...a...
v = vp a
Now the binding for vp can float out unimpeded.
I can't remember why this case seemed important enough to
deal with, but I certainly found cases where important floats
didn't happen if we did not abstract wrt tyvars.
With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level. Whether or not this
happens is controlled by a flag.
Random comments
~~~~~~~~~~~~~~~
At the moment we never float a binding out to between two adjacent
lambdas. For example:
@
\x y -> let t = x+x in ...
===>
\x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.
But there's a case I've seen where this might not be true. Consider:
@
elEm2 x ys
= elem' x ys
where
elem' _ [] = False
elem' x (y:ys) = x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
\deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated into
@
\deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe. We don't do this at the moment.
\begin{code}
-- | A binding that has been floated out of its original position,
-- paired with the level it is destined for.
type FloatBind  = (Level, CoreBind)

-- | A list of floated bindings, each tagged with its destination level.
type FloatBinds = [FloatBind]
\end{code}
%************************************************************************
%* *
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
%* *
%************************************************************************
\begin{code}
-- | Run the let-floating pass over a whole program.
--
-- Each top-level binding is first annotated with levels ('setLevels'),
-- then floated individually ('floatTopBind'); the resulting binding lists
-- are concatenated.  When the corresponding flags are on, the levelled
-- program (Opt_D_verbose_core2core) and a summary of the float statistics
-- (Opt_D_dump_simpl_stats) are dumped via 'dumpIfSet_dyn'.
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> [CoreBind] -> IO [CoreBind]
floatOutwards float_sws dflags us pgm
  = do dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                     (vcat (map ppr levelled))
       dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
                     (hcat [ int n_top,   ptext (sLit " Lets floated to top level; "),
                             int n_other, ptext (sLit " Lets floated elsewhere; from "),
                             int n_lams,  ptext (sLit " Lambda groups")])
       return (concat new_binds)
  where
    -- The program with every binder tagged with its level.
    levelled = setLevels float_sws pgm us

    -- Per-binding statistics and the floated result of each binding.
    (stats_per_bind, new_binds) = unzip (map floatTopBind levelled)

    -- Totals reported in the statistics dump.
    (n_top, n_other, n_lams) = get_stats (sum_stats stats_per_bind)
-- | Float a single top-level binding.  Returns the statistics gathered and
-- the resulting bindings with their destination levels discarded.
floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
floatTopBind bind
  = case floatBind bind of
      (stats, tagged_floats) -> (stats, floatsToBinds tagged_floats)
\end{code}
%************************************************************************
%* *
\subsection[FloatOutBind]{Floating in a binding (the business end)}
%* *
%************************************************************************
\begin{code}
-- | Float out of a (levelled) binding.  The result is the statistics
-- gathered and a list of bindings, each tagged with its destination level.
-- The binding itself is always part of that list; no binding is left in
-- place by this function.
floatBind :: LevelledBind -> (FloatStats, FloatBinds)
floatBind (NonRec (TB name level) rhs)
  = case floatRhs level rhs of
      (stats, rhs_floats, rhs') ->
        -- The floats from the rhs come before the binding itself.
        (stats, rhs_floats ++ [(level, NonRec name rhs')])

floatBind bind@(Rec pairs)
  = case unzip3 (map float_pair pairs) of
      (stats_s, floats_s, pairs')
        | isTopLvl dest_lvl
          -- At top level, every float from the right-hand sides is merged
          -- into the recursive group itself (after the original pairs).
        -> (total, [(dest_lvl, Rec (pairs' ++ floatsToBindPairs all_floats))])
        | otherwise
          -- Elsewhere, floats that go further out than the group (by
          -- partitionByMajorLevel) are emitted ahead of it; the remaining
          -- floats are merged into the group (before the original pairs).
        -> case partitionByMajorLevel dest_lvl all_floats of
             (further, here) ->
               (total, further ++ [(dest_lvl, Rec (floatsToBindPairs here ++ pairs'))])
        where
          all_floats = concat floats_s
          total      = sum_stats stats_s
  where
    -- The group's level is taken from its first binder (see getBindLevel).
    dest_lvl = getBindLevel bind

    -- Float one (binder, rhs) pair of the group at the binder's own level.
    float_pair (TB name level, rhs)
      = case floatRhs level rhs of
          (stats, rhs_floats, rhs') -> (stats, rhs_floats, (name, rhs'))
\end{code}
%************************************************************************
\subsection[FloatOutExpr]{Floating in expressions}
%* *
%************************************************************************
\begin{code}
-- All three take the level of the surrounding context and a levelled
-- expression, and return (statistics, floated bindings tagged with their
-- destination levels, residual expression).
floatExpr, floatRhs, floatCaseAlt
  :: Level
  -> LevelledExpr
  -> (FloatStats, FloatBinds, CoreExpr)

-- floatCaseAlt lvl e: float out of e, then keep here (as lets wrapped
-- around the residual expression) every float that partitionByMajorLevel
-- does not send further out than lvl.  Used for case alternatives and for
-- the body of an unlifted let.
floatCaseAlt lvl arg
  = case floatExpr lvl arg of
      (stats, floats, arg') ->
        case partitionByMajorLevel lvl floats of
          (further, here) -> (stats, further, install here arg')

-- floatRhs lvl e: like floatCaseAlt, except that when the residual
-- expression is cheap (exprIsCheap) nothing is installed locally and all
-- floats are passed outwards unchanged.
floatRhs lvl arg
  = case floatExpr lvl arg of
      (stats, floats, arg')
        | exprIsCheap arg' -> (stats, floats, arg')
        | otherwise ->
            case partitionByMajorLevel lvl floats of
              (further, here) -> (stats, further, install here arg')
-- floatExpr lvl e: float let-bindings out of e, which sits in a context at
-- level lvl.  Returns the statistics gathered, the bindings floated out of
-- e (each paired with its destination level), and the residual expression.

-- Atoms: nothing to float.
floatExpr _ (Var v)   = (zeroStats, [], Var v)
floatExpr _ (Type ty) = (zeroStats, [], Type ty)
floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)

-- Application: the function is floated as an expression, the argument as
-- an rhs (see floatRhs).
floatExpr lvl (App e a)
  = case floatExpr lvl e of { (fse, floats_e, e') ->
    case floatRhs  lvl a of { (fsa, floats_a, a') ->
    (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}

-- Lambda group: the body is floated at the innermost binder's level, and
-- the floats are then partitioned against the outermost binder's level.
-- Floats that stay are installed around the body, inside the lambdas.
floatExpr _ lam@(Lam _ _)
  = let
        (bndrs_w_lvls, body) = collectBinders lam
        bndrs = [b | TB b _ <- bndrs_w_lvls]
        lvls  = [l | TB _ l <- bndrs_w_lvls]

        -- A group consisting only of type lambdas partitions with
        -- partitionByLevel; a group containing any value lambda uses
        -- partitionByMajorLevel.
        partition_fn | all isTyVar bndrs = partitionByLevel
                     | otherwise         = partitionByMajorLevel
    in
    -- 'last' and 'head' assume collectBinders returns at least one binder,
    -- which should hold because the expression is a Lam.
    case floatExpr (last lvls) body of { (fs, floats, body') ->
    case partition_fn (head lvls) floats of { (floats', heres) ->
    (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
    }}

-- SCC note: the floated bindings end up outside this note, so each floated
-- rhs is wrapped in an SCC for the dupified cost centre.
floatExpr lvl (Note note@(SCC cc) expr)
  = case floatExpr lvl expr of { (fs, floating_defns, expr') ->
    let
        annotated_defns = annotate (dupifyCC cc) floating_defns
    in
    (fs, annotated_defns, Note note expr') }
  where
    annotate :: CostCentre -> FloatBinds -> FloatBinds
    annotate dupd_cc defn_groups
      = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
      where
        ann_bind (NonRec binder rhs)
          = NonRec binder (mkSCC dupd_cc rhs)
        ann_bind (Rec pairs)
          = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]

-- InlineMe note: nothing is floated out; the body is kept as-is apart from
-- dropping the level tags.
floatExpr _ (Note InlineMe expr)
  = (zeroStats, [], Note InlineMe (unTag expr))

-- Any other note is transparent to floating.
floatExpr lvl (Note note expr)
  = case floatExpr lvl expr of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Note note expr') }

-- Casts are transparent to floating.
floatExpr lvl (Cast expr co)
  = case floatExpr lvl expr of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Cast expr' co) }

-- Unlifted let: treat it like a case.  The rhs is floated with floatExpr,
-- the body with floatCaseAlt at the binder's level, and the let stays put.
-- Statistics from both the rhs and the body are kept (previously the rhs's
-- statistics were discarded, undercounting lambda groups inside it).
floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
  | isUnLiftedType (idType bndr)
  = case floatExpr lvl rhs of { (fsr, rhs_floats, rhs') ->
    case floatCaseAlt bndr_lvl body of { (fsb, body_floats, body') ->
    (fsr `add_stats` fsb, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}

-- Ordinary let: the binding is handed to floatBind and always becomes a
-- float; only the body remains as the residual expression.
floatExpr lvl (Let bind body)
  = case floatBind bind of { (fsb, bind_floats) ->
    case floatExpr lvl body of { (fse, body_floats, body') ->
    (add_stats fsb fse,
     bind_floats ++ body_floats,
     body') }}

-- Case: the scrutinee is floated as an expression; each alternative is
-- floated with floatCaseAlt at the case binder's level.
floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
  = case floatExpr lvl scrut of { (fse, fde, scrut') ->
    case floatList float_alt alts of { (fsa, fda, alts') ->
    (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
    }}
  where
    -- Level tags on the pattern binders are dropped.
    float_alt (con, bs, rhs)
      = case floatCaseAlt case_lvl rhs of { (fs, rhs_floats, rhs') ->
        (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
-- | Apply a floating function to each element of a list, summing the
-- statistics, concatenating the floats in list order, and collecting the
-- residual results.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f = foldr step (zeroStats, [], [])
  where
    -- Force the element's result first, then the rest of the list.
    step x rest
      = case f x of
          (fs_x, binds_x, y) ->
            case rest of
              (fs_rest, binds_rest, ys) ->
                (fs_x `add_stats` fs_rest, binds_x ++ binds_rest, y : ys)
-- | Strip the tag from a tagged binder.
unTagBndr :: TaggedBndr tag -> CoreBndr
unTagBndr (TB bndr _tag) = bndr
-- | Strip the binder tags from a tagged expression, leaving its structure
-- otherwise unchanged.
unTag :: TaggedExpr tag -> CoreExpr
unTag expr = case expr of
    Var v                -> Var v
    Lit l                -> Lit l
    Type ty              -> Type ty
    App fun arg          -> App (unTag fun) (unTag arg)
    Lam b body           -> Lam (unTagBndr b) (unTag body)
    Let bind body        -> Let (untag_bind bind) (unTag body)
    Case scrut b ty alts -> Case (unTag scrut) (unTagBndr b) ty (map untag_alt alts)
    Cast e co            -> Cast (unTag e) co
    Note n e             -> Note n (unTag e)
  where
    untag_bind (NonRec b rhs) = NonRec (unTagBndr b) (unTag rhs)
    untag_bind (Rec prs)      = Rec [ (unTagBndr b, unTag rhs) | (b, rhs) <- prs ]

    untag_alt (con, bs, rhs)  = (con, map unTagBndr bs, unTag rhs)
\end{code}
%************************************************************************
%* *
\subsection{Utility bits for floating stats}
%* *
%************************************************************************
I didn't implement this with unboxed numbers. I don't want to be too
strict in this stuff, as it is rarely turned on. (WDP 95/09)
\begin{code}
-- | Counters gathered by the pass (deliberately lazy; see the note above).
-- The fields are, in order: the number of lets floated to top level, the
-- number of lets floated elsewhere, and the number of lambda groups they
-- were floated out of.
data FloatStats = FlS Int Int Int
-- | Unpack the counters as (top-level lets, other lets, lambda groups).
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS n_top n_other n_lams) = (n_top, n_other, n_lams)
-- | All counters at zero; the unit of 'add_stats'.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0
-- | Add up a list of statistics (zero for the empty list).
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats
-- | Add two sets of statistics field by field.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS tops1 others1 lams1) (FlS tops2 others2 lams2)
  = FlS (tops1 + tops2) (others1 + others2) (lams1 + lams2)
-- | Record one lambda group and the bindings floated out of it.  Floats
-- whose level satisfies 'isTopLvl' count as floated to top level, the rest
-- as floated elsewhere, and the lambda-group counter goes up by one.
add_to_stats :: FloatStats -> [(Level, Bind CoreBndr)] -> FloatStats
add_to_stats (FlS tops others lams) floats
  = FlS (tops + n_top) (others + n_other) (lams + 1)
  where
    n_top   = length (filter (isTopLvl . fst) floats)
    n_other = length floats - n_top
\end{code}
%************************************************************************
%* *
\subsection{Utility bits for floating}
%* *
%************************************************************************
\begin{code}
-- | The level of a binding.  For a recursive group this is the level of its
-- first binder (presumably all binders in a group share one level); an
-- empty group is a compiler panic.
getBindLevel :: Bind (TaggedBndr Level) -> Level
getBindLevel bind = case bind of
    NonRec (TB _ lvl) _     -> lvl
    Rec ((TB _ lvl, _) : _) -> lvl
    Rec []                  -> panic "getBindLevel Rec []"
\end{code}
\begin{code}
-- | Split floats into (those that float further out than the context
-- level, those that stay at the context), preserving their order.
--
-- 'partitionByMajorLevel' sends a float further when its level is
-- 'ltMajLvl' the context level or is top level; 'partitionByLevel' sends
-- it further when its level is 'ltLvl' the context level.
partitionByMajorLevel, partitionByLevel
  :: Level
  -> FloatBinds
  -> (FloatBinds,
      FloatBinds)

partitionByMajorLevel ctxt_lvl = partition escapes
  where
    escapes (bind_lvl, _) = bind_lvl `ltMajLvl` ctxt_lvl || isTopLvl bind_lvl

partitionByLevel ctxt_lvl = partition escapes
  where
    escapes (bind_lvl, _) = bind_lvl `ltLvl` ctxt_lvl
\end{code}
\begin{code}
-- | Discard the destination levels, keeping the bindings in order.
floatsToBinds :: FloatBinds -> [CoreBind]
floatsToBinds = map snd
-- | Flatten floated bindings into one list of (binder, rhs) pairs,
-- discarding their levels.  Used to merge floats into a recursive group.
floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
floatsToBindPairs floats = concatMap mk_pairs floats
  where
    mk_pairs (_, Rec pairs)         = pairs
    mk_pairs (_, NonRec binder rhs) = [(binder, rhs)]
-- | Wrap an expression in one Let per floated binding, ignoring levels.
-- The first binding in the list becomes the outermost Let.
install :: FloatBinds -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldr (\(_, defns) body -> Let defns body) expr defn_groups
\end{code}