% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
\section[FloatOut]{Float bindings outwards (towards the top level)}

``Long-distance'' floating of bindings towards the top level.

module FloatOut ( floatOutwards ) where

import CoreSyn
import CoreUtils

import DynFlags	( DynFlags, DynFlag(..), FloatOutSwitches(..) )
import ErrUtils		( dumpIfSet_dyn )
import CostCentre	( dupifyCC, CostCentre )
import Id		( Id, idType )
import Type		( isUnLiftedType )
import SetLevels	( Level(..), LevelledExpr, LevelledBind,
			  setLevels, isTopLvl, tOP_LEVEL )
import UniqSupply       ( UniqSupply )
import Bag
import Util
import Maybes
import UniqFM
import Outputable
import FastString

	Overall game plan

The Big Main Idea is:

  	To float out sub-expressions that can thereby get outside
	a non-one-shot value lambda, and hence may be shared.

To achieve this we may need to do two things:

   a) Let-bind the sub-expression:

	f (g x)  ==>  let lvl = f (g x) in lvl

      Now we can float the binding for 'lvl'.  

   b) More than that, we may need to abstract wrt a type variable

	\x -> ... /\a -> let v = ...a... in ....

      Here the binding for v mentions 'a' but not 'x'.  So we
      abstract wrt 'a', to give this binding for 'v':

	    vp = /\a -> ...a...
	    v  = vp a

      Now the binding for vp can float out unimpeded.
      I can't remember why this case seemed important enough to
      deal with, but I certainly found cases where important floats
      didn't happen if we did not abstract wrt tyvars.

With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level.  Whether or not this
happens is controlled by a flag.

Random comments

At the moment we never float a binding out to between two adjacent
lambdas.  For example:

	\x y -> let t = x+x in ...
===>
	\x -> let t = x+x in \y -> ...
Reason: this is less efficient in the case where the original lambda
is never partially applied.

But there's a case I've seen where this might not be true.  Consider:
elEm2 x ys
  = elem' x ys
  where
    elem' _ []	= False
    elem' x (y:ys)	= x==y || elem' x ys
It turns out that this generates a subexpression of the form
	\deq x ys -> let eq = eqFromEqDict deq in ...
which might usefully be separated to
	\deq -> let eq = eqFromEqDict deq in \x ys -> ...
Well, maybe.  We don't do this at the moment.

%*									*
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
%*									*

-- Entry point of the float-out pass.
--
-- The program is first annotated with levels by 'setLevels'; each
-- annotated top-level binding is then processed by 'floatTopBind', and
-- the resulting binding groups are concatenated to form the new program.
--
-- Two optional dumps are produced along the way:
--   * the level-annotated program, under Opt_D_verbose_core2core
--   * the accumulated float statistics, under Opt_D_dump_simpl_stats
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> [CoreBind] -> IO [CoreBind]

floatOutwards float_sws dflags us pgm
  = do {
        let { annotated_w_levels = setLevels float_sws pgm us ;
              (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
            } ;

        dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                  (vcat (map ppr annotated_w_levels));

        let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };

        dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
                (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
                        int ntlets, ptext (sLit " Lets floated elsewhere; from "),
                        int lams,   ptext (sLit " Lambda groups")]);

        return (concat binds_s')
    }

-- Float within one top-level binding.  Everything that 'floatBind'
-- produces (the floats plus the binding itself) is flattened into a
-- plain list of bindings; the statistics are passed through unchanged.
floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
floatTopBind bind
  = case floatBind bind of
      (fs, floats) -> (fs, bagToList (flattenFloats floats))

%*									*
\subsection[FloatOut-Bind]{Floating in a binding (the business end)}
%*									*

-- Float the bindings found inside the right-hand side(s) of a binding.
-- The result contains both the floats that came out of the right-hand
-- sides and the (rebuilt) binding itself, filed under its destination
-- level.
floatBind :: LevelledBind -> (FloatStats, FloatBinds)

floatBind (NonRec (TB name level) rhs)
  = case floatRhs level rhs of { (fs, rhs_floats, rhs') ->
    (fs, rhs_floats `plusFloats` unitFloat level (NonRec name rhs')) }

floatBind bind@(Rec pairs)
  = case unzip3 (map do_pair pairs) of { (fss, rhss_floats, new_pairs) ->
        -- Fold from 'emptyFloats' rather than using foldr1, so that the
        -- fold itself is total.
    let rhs_floats = foldr plusFloats emptyFloats rhss_floats in

    if not (isTopLvl bind_dest_lvl) then
        -- Find which bindings float out at least one lambda beyond this one
        -- These ones can't mention the binders, because they couldn't
        -- be escaping a major level if so.
        -- The ones that are not going further can join the letrec;
        -- they may not be mutually recursive but the occurrence analyser will
        -- find that out.
        case partitionByMajorLevel bind_dest_lvl rhs_floats of { (floats', heres) ->
        (sum_stats fss,
         floats' `plusFloats` unitFloat bind_dest_lvl
                               (Rec (floatsToBindPairs heres new_pairs))) }
    else
        -- In a recursive binding, *destined for* the top level
        -- (only), the rhs floats may contain references to the
        -- bound things.  For example
        --      f = ...(let v = ...f... in b) ...
        --  might get floated to
        --      v = ...f...
        --      f = ... b ...
        -- and hence we must (pessimistically) make all the floats recursive
        -- with the top binding.  Later dependency analysis will unravel it.
        -- This can only happen for bindings destined for the top level,
        -- because only then will partitionByMajorLevel allow through a binding
        -- that only differs in its minor level
        (sum_stats fss, unitFloat tOP_LEVEL
                          (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))
    }
  where
        -- Level that the whole recursive group is destined for
    bind_dest_lvl = getBindLevel bind

        -- Float inside one right-hand side of the group
    do_pair (TB name level, rhs)
      = case floatRhs level rhs of { (fs, rhs_floats, rhs') ->
        (fs, rhs_floats, (name, rhs')) }


\subsection[FloatOut-Expr]{Floating in expressions}
%*									*

-- All three take the level of the enclosing context and a levelled
-- expression, and return the accumulated statistics, the bindings still
-- floating outwards, and the rebuilt (untagged) expression.
floatExpr, floatRhs, floatCaseAlt
         :: Level
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)

-- Used for case-alternative rhss and for the body of an unlifted let.
-- Bindings that do not escape a major level relative to 'lvl' are
-- installed right here, around the result; in particular this keeps
-- anything mentioning the case-bound variables inside the alternative.
floatCaseAlt lvl arg
  = case floatExpr lvl arg of { (stats, all_floats, arg') ->
    case partitionByMajorLevel lvl all_floats of { (escaping, stay_here) ->
    (stats, escaping, install stay_here arg') }}

floatRhs lvl arg        -- Used for the rhss of bindings, and for fn args
                        -- See Note [Floating out of RHSs]
  = case floatExpr lvl arg of { (fsa, floats, arg') ->
    if exprIsCheap arg' then
        -- Cheap result: do not re-install anything here; every float
        -- carries on outwards
        (fsa, floats, arg')
    else
        -- Otherwise dump the bindings that do not escape a major level
        -- around the expression itself
    case partitionByMajorLevel lvl floats of { (floats', heres) ->
    (fsa, floats', install heres arg') }}

-- Note [Floating out of RHSs]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- Dump bindings that aren't going to escape from a lambda
-- This isn't a scoping issue (the binder isn't in scope in the RHS 
--	of a non-rec binding)
-- Rather, it is to avoid floating the x binding out of
--	f (let x = e in b)
-- unnecessarily.  But we first test for values or trivial rhss,
-- because (in particular) we don't want to insert new bindings between
-- the "=" and the "\".  E.g.
--	f = \x -> let <bind> in <body>
-- We do not want
--	f = let <bind> in \x -> <body>
-- (a) The simplifier will immediately float it further out, so we may
--	as well do so right now; in general, keeping rhss as manifest 
--	values is good
-- (b) If a float-in pass follows immediately, it might add yet more
--	bindings just after the '='.  And some of them might (correctly)
--	be strict even though the 'let f' is lazy, because f, being a value,
--	gets its demand-info zapped by the simplifier.
-- We use exprIsCheap because that is also what's used by the simplifier
-- to decide whether to float a let out of a let

-- Variables, types and literals: nothing to float, no statistics.
-- (The node is rebuilt because the input and output expression types
-- differ in their binder type.)
floatExpr _ (Var var)      = (zeroStats, emptyFloats, Var var)
floatExpr _ (Type the_ty)  = (zeroStats, emptyFloats, Type the_ty)
floatExpr _ (Lit the_lit)  = (zeroStats, emptyFloats, Lit the_lit)

-- Application: the function part is floated as an ordinary expression,
-- the argument is treated like a right-hand side.
floatExpr lvl (App fun arg)
  = case floatExpr lvl fun of { (fun_stats, fun_floats, fun') ->
    case floatRhs  lvl arg of { (arg_stats, arg_floats, arg') ->
    ( add_stats fun_stats arg_stats
    , plusFloats fun_floats arg_floats
    , App fun' arg' ) }}

-- A group of lambdas.  The body is floated at the level of the innermost
-- binder, and the resulting floats are partitioned against the level of
-- the outermost binder.  This is the only place where the statistics are
-- incremented (via add_to_stats).
floatExpr _ lam@(Lam _ _)
  = let
        (bndrs_w_lvls, body) = collectBinders lam
        bndrs                = [b | TB b _ <- bndrs_w_lvls]
        -- 'lvls' is consumed with head/last below; this clause only
        -- matches a Lam, so collectBinders is expected to return at
        -- least one binder.
        lvls                 = [l | TB _ l <- bndrs_w_lvls]

        -- For the all-tyvar case we are prepared to pull
        -- the lets out, to implement the float-out-of-big-lambda
        -- transform; but otherwise we only float bindings that are
        -- going to escape a value lambda.
        -- In particular, for one-shot lambdas we don't float things
        -- out; we get no saving by so doing.
        partition_fn | all isTyVar bndrs = partitionByLevel
                     | otherwise         = partitionByMajorLevel
    in
    case floatExpr (last lvls) body of { (fs, floats, body') ->

        -- Dump any bindings which absolutely cannot go any further
    case partition_fn (head lvls) floats of { (floats', heres) ->

    (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
    }}

-- An SCC note: float inside the annotated expression, then wrap every
-- binding that floats out past the note in the cost centre.
floatExpr lvl (Note note@(SCC cc) expr)
  = case floatExpr lvl expr of { (fs, floating_defns, expr') ->
    let
        -- Annotate bindings floated outwards past an scc expression
        -- with the cc.  We mark that cc as "duplicated", though.
        annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
    in
    (fs, annotated_defns, Note note expr') }

-- Do no floating at all inside INLINE.
-- The SetLevels pass did not clone the bindings, so it's
-- unsafe to do any floating, even if we dump the results
-- inside the Note (which is what we used to do).
-- The body is merely stripped of its level tags.
floatExpr _ (Note InlineMe inl_body)
  = (zeroStats, emptyFloats, Note InlineMe (unTag inl_body))

-- Any remaining kind of note (SCC and InlineMe are matched by earlier
-- clauses): float inside and put the note back around the result.
floatExpr lvl (Note other_note inner)
  = case floatExpr lvl inner of
      (stats, floats, inner') -> (stats, floats, Note other_note inner')

-- Casts are transparent to floating: float inside, re-apply the coercion.
floatExpr lvl (Cast inner co)
  = case floatExpr lvl inner of
      (stats, floats, inner') -> (stats, floats, Cast inner' co)

floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
  | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
                                -- I.e. floatExpr for rhs, floatCaseAlt for body
                                -- The binding itself stays where it is.
  = case floatExpr lvl rhs          of { (fs_rhs,  rhs_floats,  rhs')  ->
    case floatCaseAlt bndr_lvl body of { (fs_body, body_floats, body') ->
        -- Combine the statistics of both parts, as the Case clause does
        -- for scrutinee and alternatives, instead of dropping the
        -- statistics gathered in the rhs.
    (add_stats fs_rhs fs_body,
     rhs_floats `plusFloats` body_floats,
     Let (NonRec bndr rhs') body') }}

-- Ordinary let: the binding is handed to floatBind, which adds it to the
-- floats; hence only the (floated) body is returned as the expression.
floatExpr lvl (Let the_bind let_body)
  = case floatBind the_bind of
      (bind_stats, floats_from_bind) ->
        case floatExpr lvl let_body of
          (body_stats, floats_from_body, let_body') ->
            ( bind_stats `add_stats` body_stats
            , plusFloats floats_from_bind floats_from_body
            , let_body' )

-- Case expression: float inside the scrutinee and inside every
-- alternative, then rebuild the case with untagged binders.
floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
  = case floatExpr lvl scrut      of { (fse, fde, scrut') ->
    case floatList float_alt alts of { (fsa, fda, alts')  ->
    (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
    }}
  where
        -- Use floatCaseAlt for the alternatives, so that we
        -- don't gratuitously float bindings out of the RHSs
    float_alt (con, bs, rhs)
        = case floatCaseAlt case_lvl rhs of { (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }

-- Apply a floating function to every element of a list, summing the
-- statistics and merging the floats of the individual results.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f items = foldr step (zeroStats, emptyFloats, []) items
  where
    -- The element is processed before the result for the rest of the
    -- list is inspected.
    step item rest
      = case f item of
          (fs_here, binds_here, out) ->
            case rest of
              (fs_rest, binds_rest, outs) ->
                ( fs_here `add_stats` fs_rest
                , binds_here `plusFloats` binds_rest
                , out : outs )

-- The level tag of a binding: that of its binder, or of the first binder
-- of a recursive group.  Panics on an empty recursive group.
getBindLevel :: Bind (TaggedBndr Level) -> Level
getBindLevel the_bind
  = case the_bind of
      NonRec (TB _ lvl) _     -> lvl
      Rec ((TB _ lvl, _) : _) -> lvl
      Rec []                  -> panic "getBindLevel Rec []"

-- Drop the tag from a tagged binder, keeping only the binder itself.
unTagBndr :: TaggedBndr tag -> CoreBndr
unTagBndr tagged = case tagged of
                     TB bndr _ -> bndr

-- Strip the tags from every binder in an expression, leaving the
-- structure of the expression otherwise unchanged.
unTag :: TaggedExpr tag -> CoreExpr
unTag expr
  = case expr of
      Var v                   -> Var v
      Lit l                   -> Lit l
      Type ty                 -> Type ty
      Note n body             -> Note n (unTag body)
      App fun arg             -> App (unTag fun) (unTag arg)
      Lam bndr body           -> Lam (unTagBndr bndr) (unTag body)
      Cast body co            -> Cast (unTag body) co
      Let bind body           -> Let (untag_bind bind) (unTag body)
      Case scrut bndr ty alts -> Case (unTag scrut) (unTagBndr bndr) ty
                                      (map untag_alt alts)
  where
    -- Binding groups: untag each binder and each right-hand side
    untag_bind (NonRec b rhs) = NonRec (unTagBndr b) (unTag rhs)
    untag_bind (Rec prs)      = Rec [(unTagBndr b, unTag rhs) | (b, rhs) <- prs]

    -- Case alternatives: untag the pattern binders and the rhs
    untag_alt (con, bs, rhs)  = (con, map unTagBndr bs, unTag rhs)

%*									*
\subsection{Utility bits for floating stats}
%*									*

I didn't implement this with unboxed numbers.  I don't want to be too
strict in this stuff, as it is rarely turned on.  (WDP 95/09)

-- Counters reported by the "FloatOut stats:" dump.
-- The counters are only ever incremented by add_to_stats, i.e. once per
-- lambda group, so a binding that floats past several lambda groups is
-- counted once for each of them.
data FloatStats
  = FlS	Int  -- Number of top-floats * lambda groups they've been past
	Int  -- Number of non-top-floats * lambda groups they've been past
	Int  -- Number of lambda (groups) seen

-- Unpack the three counters as a tuple:
-- (top-level floats, other floats, lambda groups).
get_stats :: FloatStats -> (Int, Int, Int)
get_stats stats
  = case stats of
      FlS top_floats other_floats lam_groups ->
        (top_floats, other_floats, lam_groups)

-- All counters zero; the starting value used by sum_stats and the result
-- for expressions in which nothing floats.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- Add up a list of statistics, counter by counter.
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats

-- Pointwise sum of two sets of counters.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS tops1 others1 lams1) (FlS tops2 others2 lams2)
  = FlS tops others lams
  where
    tops   = tops1   + tops2
    others = others1 + others2
    lams   = lams1   + lams2

-- Record one lambda group: count the bindings currently floating past it
-- (top-level ones and others separately) and bump the lambda-group count.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS tops_so_far others_so_far lams_so_far) (FB top_floats other_floats)
  = FlS tops' others' lams'
  where
    tops'   = tops_so_far   + lengthBag top_floats
    others' = others_so_far + lengthBag (flattenMajor other_floats)
    lams'   = lams_so_far   + 1

%*									*
\subsection{Utility bits for floating}
%*									*

Note [Representation of FloatBinds]
The FloatBinds type is somewhat important.  We can get very large numbers
of floating bindings, often all destined for the top level.  A typical example
is     x = [4,2,5,2,5, .... ]
Then we get lots of small expressions like (fromInteger 4), which all get
lifted to top level.

The trouble is that
  (a) we partition these floating bindings *at every binding site*
  (b) SetLevels introduces a new binding site for every float
So we had better not look at each binding at each binding site!

That is why MajorEnv is represented as a finite map.

We keep the bindings destined for the *top* level separate, because
we float them out even if they don't escape a *value* lambda; see
the comments on partitionByMajorLevel.

-- A single binding group that is in the process of floating outwards.
type FloatBind = CoreBind      -- INVARIANT: a FloatBind is always lifted

-- The collection of bindings currently floating outwards: those headed
-- for the top level are kept in their own bag, all others are indexed by
-- the major and then the minor component of their destination level.
data FloatBinds  = FB !(Bag FloatBind)         -- Destined for top level
                     !MajorEnv                 -- Levels other than top
     -- See Note [Representation of FloatBinds]

-- Two-level finite map from destination level to the bindings going there.
type MajorEnv = UniqFM MinorEnv                        -- Keyed by major level
type MinorEnv = UniqFM (Bag FloatBind)         -- Keyed by minor level

-- Forget the destination levels: the top-level bindings followed by the
-- bindings of all other levels, in one bag.
flattenFloats :: FloatBinds -> Bag FloatBind
flattenFloats fbs
  = case fbs of
      FB top_binds nested -> unionBags top_binds (flattenMajor nested)

-- Collect every binding stored anywhere in a major-level map.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor major_env = foldUFM add_minor emptyBag major_env
  where
    add_minor minor_env acc = flattenMinor minor_env `unionBags` acc

-- Collect every binding stored anywhere in a minor-level map.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor minor_env
  = foldUFM (\binds acc -> binds `unionBags` acc) emptyBag minor_env

-- No floating bindings at all: an empty top-level bag and an empty map
-- for the other levels.
emptyFloats :: FloatBinds
emptyFloats = FB emptyBag emptyUFM

-- A single floating binding, filed under its destination level.
-- InlineCtxt and top-level destinations both go into the top-level bag;
-- anything else is stored under its major and minor level.
unitFloat :: Level -> FloatBind -> FloatBinds
unitFloat lvl b
  = case lvl of
      InlineCtxt -> top_float
      Level major minor
        | isTopLvl lvl -> top_float
        | otherwise    -> FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
  where
    top_float = FB (unitBag b) emptyUFM

-- Merge two collections of floating bindings, component by component.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB tops_l nested_l) (FB tops_r nested_r)
  = FB (unionBags tops_l tops_r) (plusMajor nested_l nested_r)

-- Merge two major-level maps; entries under the same major level have
-- their minor-level maps merged with plusMinor.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor env_l env_r = plusUFM_C plusMinor env_l env_r

-- Merge two minor-level maps; entries under the same minor level have
-- their bags of bindings unioned.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor env_l env_r = plusUFM_C unionBags env_l env_r

-- Turn a bag of floated binding groups into (binder, rhs) pairs and put
-- them in front of the given pairs, e.g. to make them all part of a
-- single recursive group.
floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
floatsToBindPairs floats binds = foldrBag add binds floats
  where
    -- 'acc' is the list of pairs built so far
    add (Rec pairs)         acc = pairs ++ acc
    add (NonRec binder rhs) acc = (binder,rhs) : acc

-- Wrap an expression in one 'Let' per binding group in the bag; this is
-- how bindings that float no further are put back around an expression.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldrBag install_group expr defn_groups
  where
    install_group defns body = Let defns body

partitionByMajorLevel, partitionByLevel
       :: Level                -- Partitioning level
       -> FloatBinds           -- Defns to be divided into 2 piles...
       -> (FloatBinds,         -- Defns  with level strictly < partition level,
           Bag FloatBind)      -- The rest

--      ---- partitionByMajorLevel ----
-- Float it if we escape a value lambda, *or* if we get to the top level
-- If we can get to the top level, say "yes" anyway. This means that
--     x = f e
-- transforms to
--    lvl = e
--    x = f lvl
-- which is as it should be

-- In an inline context only the top-level bindings carry on floating;
-- everything else stays here.
partitionByMajorLevel InlineCtxt (FB tops defns)
  = (FB tops emptyUFM, flattenMajor defns)

-- Only the major level is compared; the minor level is ignored.
-- Top-level bindings always carry on floating.
partitionByMajorLevel (Level major _) (FB tops defns)
  = (FB tops outer, heres `unionBags` flattenMajor inner)
  where
    -- NOTE(review): splitUFM is assumed to return the entries with keys
    -- below 'major', the entry at 'major' (if any), and the entries with
    -- keys above it, in that order -- confirm against UniqFM.
    (outer, mb_heres, inner) = splitUFM defns major
    heres = case mb_heres of
               Nothing -> emptyBag
               Just h  -> flattenMinor h

--      ---- partitionByLevel ----
-- Like partitionByMajorLevel, but for bindings filed under the same
-- major level the minor level is compared as well.

-- In an inline context only the top-level bindings carry on floating;
-- everything else stays here.
partitionByLevel InlineCtxt (FB tops defns)
  = (FB tops emptyUFM, flattenMajor defns)

partitionByLevel (Level major minor) (FB tops defns)
  = (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
     here_min `unionBags` flattenMinor inner_min
              `unionBags` flattenMajor inner_maj)
  where
    -- First split on the major level ...
    (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
    -- ... then split the bindings at exactly this major level on the
    -- minor level
    (outer_min, mb_here_min, inner_min) = case mb_here_maj of
                                            Nothing -> (emptyUFM, Nothing, emptyUFM)
                                            Just min_defns -> splitUFM min_defns minor
    here_min = mb_here_min `orElse` emptyBag

-- Wrap the right-hand side of every floating binding (at every level) in
-- the given cost centre, using mkSCC.
wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
wrapCostCentre cc (FB tops defns)
  = FB (wrap_defns tops) (mapUFM (mapUFM wrap_defns) defns)
  where
    wrap_defns = mapBag wrap_one
    wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
    wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)