%
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
%
\section[FloatOut]{Float bindings outwards (towards the top level)}

``Long-distance'' floating of bindings towards the top level.

\begin{code}
module FloatOut ( floatOutwards ) where

import CoreSyn
import CoreUtils

import DynFlags	( DynFlags, DynFlag(..), FloatOutSwitches(..) )
import ErrUtils		( dumpIfSet_dyn )
import CostCentre	( dupifyCC, CostCentre )
import Id		( Id, idType )
import Type		( isUnLiftedType )
import SetLevels	( Level(..), LevelledExpr, LevelledBind,
			  setLevels, ltMajLvl, ltLvl, isTopLvl )
import UniqSupply       ( UniqSupply )
import Data.List
import Outputable
import FastString
\end{code}

	-----------------
	Overall game plan
	-----------------

The Big Main Idea is:

  	To float out sub-expressions that can thereby get outside
	a non-one-shot value lambda, and hence may be shared.


To achieve this we may need to do two things:

   a) Let-bind the sub-expression:

	f (g x)  ==>  let lvl = f (g x) in lvl

      Now we can float the binding for 'lvl'.  

   b) More than that, we may need to abstract wrt a type variable

	\x -> ... /\a -> let v = ...a... in ....

      Here the binding for v mentions 'a' but not 'x'.  So we
      abstract wrt 'a', to give this binding for 'v':

	    vp = /\a -> ...a...
	    v  = vp a

      Now the binding for vp can float out unimpeded.
      I can't remember why this case seemed important enough to
      deal with, but I certainly found cases where important floats
      didn't happen if we did not abstract wrt tyvars.

With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level.  Whether or not this
happens is controlled by a flag.


Random comments
~~~~~~~~~~~~~~~

At the moment we never float a binding out to between two adjacent
lambdas.  For example:

@
	\x y -> let t = x+x in ...
===>
	\x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.

But there's a case I've seen where this might not be true.  Consider:
@
elEm2 x ys
  = elem' x ys
  where
    elem' _ []	= False
    elem' x (y:ys)	= x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
	\deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated to
@
	\deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe.  We don't do this at the moment.

\begin{code}
-- A floated binding, paired with the level it is destined for.
type FloatBind     = (Level, CoreBind)	-- INVARIANT: a FloatBind is always lifted
-- A sequence of floated bindings.  'install' wraps them around an
-- expression with the first element outermost.
type FloatBinds    = [FloatBind]	
\end{code}

%************************************************************************
%*									*
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
%*									*
%************************************************************************

\begin{code}
-- Interface function of the pass.  The program is first annotated with
-- levels by setLevels; each top-level binding is then floated on its own,
-- and the resulting binding groups are concatenated.  Two dumps are
-- emitted via dumpIfSet_dyn: the level-annotated program and the
-- accumulated float statistics.
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> [CoreBind] -> IO [CoreBind]

floatOutwards float_sws dflags us pgm = do
    dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                  (vcat (map ppr levelled_pgm))
    dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:" stats_doc
    return (concat floated_binds_s)
  where
    -- The input program with a Level attached to every binder
    levelled_pgm = setLevels float_sws pgm us

    -- One (stats, bindings) pair per original top-level binding
    (stats_s, floated_binds_s) = unzip (map floatTopBind levelled_pgm)

    (n_top, n_elsewhere, n_lam_groups) = get_stats (sum_stats stats_s)

    stats_doc
      = hcat [ int n_top,        ptext (sLit " Lets floated to top level; ")
             , int n_elsewhere,  ptext (sLit " Lets floated elsewhere; from ")
             , int n_lam_groups, ptext (sLit " Lambda groups") ]

-- Float one top-level binding.  Everything that floats out of it has
-- nowhere further to go, so the floats are simply turned back into a
-- list of ordinary bindings (the levels are dropped).
floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
floatTopBind bind =
  case floatBind bind of
    (stats, floats) -> (stats, floatsToBinds floats)
\end{code}

%************************************************************************
%*									*
\subsection[FloatOut-Bind]{Floating in a binding (the business end)}
%*									*
%************************************************************************


\begin{code}
-- Float a single binding group.  The result is the statistics gathered
-- while processing the right-hand side(s), plus a list of floats whose
-- last element is the (processed) binding itself, tagged with its
-- destination level.
floatBind :: LevelledBind -> (FloatStats, FloatBinds)

floatBind (NonRec (TB bndr dest_lvl) rhs) =
  case floatRhs dest_lvl rhs of
    (stats, floats, rhs') ->
      (stats, floats ++ [(dest_lvl, NonRec bndr rhs')])

floatBind bind@(Rec pairs) =
  case unzip3 (map float_pair pairs) of
    (stats_s, floats_s, pairs') ->
      let all_stats  = sum_stats stats_s
          all_floats = concat floats_s
      in
      if isTopLvl dest_lvl
        then
          -- A recursive group *destined for* the top level: the floats
          -- coming out of the right-hand sides may mention the binders
          -- of the group.  For example
          --      f = ...(let v = ...f... in b) ...
          -- might get floated to
          --      v = ...f...
          --      f = ... b ...
          -- so, pessimistically, every float is made part of one big
          -- Rec with the group; later dependency analysis unravels it.
          --
          -- This only arises for top-level destinations, because only
          -- then does partitionByMajorLevel let through a binding that
          -- differs merely in its minor level.
          (all_stats, [(dest_lvl, Rec (pairs' ++ floatsToBindPairs all_floats))])
        else
          -- Separate the floats that get at least one lambda beyond
          -- this group from those that do not.  The former cannot
          -- mention the group's binders (they could not be escaping a
          -- major level if they did).  The latter simply join the
          -- letrec; they may not really be mutually recursive, but the
          -- occurrence analyser will discover that.
          case partitionByMajorLevel dest_lvl all_floats of
            (escaping, staying) ->
              ( all_stats
              , escaping ++ [(dest_lvl, Rec (floatsToBindPairs staying ++ pairs'))] )
  where
    dest_lvl = getBindLevel bind

    -- Process one right-hand side, stripping the level tag off its binder
    float_pair (TB bndr rhs_lvl, rhs) =
      case floatRhs rhs_lvl rhs of
        (stats, floats, rhs') -> (stats, floats, (bndr, rhs'))
\end{code}

%************************************************************************
%*									*
\subsection[FloatOut-Expr]{Floating in expressions}
%*									*
%************************************************************************

\begin{code}
-- All three take the level of the enclosing context and a levelled
-- expression, and return the statistics, the bindings floating out of
-- the expression, and the expression that is left behind.
floatExpr, floatRhs, floatCaseAlt
  :: Level
  -> LevelledExpr
  -> (FloatStats, FloatBinds, CoreExpr)

-- floatCaseAlt is used for rec rhss and case-alternative rhss.
-- Floats that are not going to escape from a lambda are dumped right
-- here, wrapped around the expression; in particular this must happen
-- to anything bound by the rec or by the case alternative.
floatCaseAlt lvl arg =
  case floatExpr lvl arg of
    (stats, floats, arg') ->
      case partitionByMajorLevel lvl floats of
        (escaping, staying) -> (stats, escaping, install staying arg')

-- floatRhs is used for nested non-rec rhss, and for function arguments.
-- See Note [Floating out of RHSs]
--
-- If what is left behind is cheap, every float is passed on outwards;
-- otherwise the floats that do not escape a lambda are re-installed
-- around the expression.
floatRhs lvl arg =
  case floatExpr lvl arg of
    (stats, floats, arg')
      | exprIsCheap arg' -> (stats, floats, arg')
      | otherwise ->
          case partitionByMajorLevel lvl floats of
            (escaping, staying) -> (stats, escaping, install staying arg')

-- Note [Floating out of RHSs]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- Dump bindings that aren't going to escape from a lambda
-- This isn't a scoping issue (the binder isn't in scope in the RHS 
--	of a non-rec binding)
-- Rather, it is to avoid floating the x binding out of
--	f (let x = e in b)
-- unnecessarily.  But we first test for values or trivial rhss,
-- because (in particular) we don't want to insert new bindings between
-- the "=" and the "\".  E.g.
--	f = \x -> let <bind> in <body>
-- We do not want
--	f = let <bind> in \x -> <body>
-- (a) The simplifier will immediately float it further out, so we may
--	as well do so right now; in general, keeping rhss as manifest 
--	values is good
-- (b) If a float-in pass follows immediately, it might add yet more
--	bindings just after the '='.  And some of them might (correctly)
--	be strict even though the 'let f' is lazy, because f, being a value,
--	gets its demand-info zapped by the simplifier.
--
-- We use exprIsCheap because that is also what's used by the simplifier
-- to decide whether to float a let out of a let

-- Result for an expression that has nothing inside it to float:
-- empty statistics, no floats, and the given expression.
atomResult :: CoreExpr -> (FloatStats, FloatBinds, CoreExpr)
atomResult e = (zeroStats, [], e)

-- Variables, types and literals.  The leaf is rebuilt rather than
-- returned as-is because the argument is a LevelledExpr whereas the
-- result must be a CoreExpr.
floatExpr _ (Var v)   = atomResult (Var v)
floatExpr _ (Type ty) = atomResult (Type ty)
floatExpr _ (Lit lit) = atomResult (Lit lit)
	  
-- Application: the function part is floated as an ordinary expression,
-- the argument is treated as a right-hand side (floatRhs).  Floats from
-- the function come before those from the argument.
floatExpr lvl (App fun arg) =
  case floatExpr lvl fun of
    (fun_stats, fun_floats, fun') ->
      case floatRhs lvl arg of
        (arg_stats, arg_floats, arg') ->
          ( add_stats fun_stats arg_stats
          , fun_floats ++ arg_floats
          , App fun' arg' )

-- A group of adjacent lambdas is handled in one go.  The body is
-- floated at the level of the innermost binder; the resulting floats
-- are then split against the level of the outermost binder, and those
-- that absolutely cannot go any further are dumped inside the lambdas.
floatExpr _ lam@(Lam _ _) =
  case floatExpr innermost_lvl body of
    (body_stats, body_floats, body') ->
      case split_floats outermost_lvl body_floats of
        (escaping, staying) ->
          ( add_to_stats body_stats escaping
          , escaping
          , mkLams bndrs (install staying body') )
  where
    (tagged_bndrs, body) = collectBinders lam
    (bndrs, lvls)        = unzip [ (b, l) | TB b l <- tagged_bndrs ]

    -- The pattern guarantees at least one binder, so neither list is empty
    innermost_lvl = last lvls
    outermost_lvl = head lvls

    -- When every binder is a type variable we are prepared to pull the
    -- lets out (the float-out-of-big-lambda transform); otherwise only
    -- bindings that will escape a value lambda are floated.  In
    -- particular nothing is floated out of one-shot lambdas, since
    -- there is no saving in doing so.
    split_floats
      | all isTyVar bndrs = partitionByLevel
      | otherwise         = partitionByMajorLevel

-- An scc annotation.  Every binding that floats outwards past the scc
-- has its right-hand side(s) wrapped (using mkSCC) with the cost centre,
-- which is first marked as "duplicated".
floatExpr lvl (Note note@(SCC cc) expr) =
  case floatExpr lvl expr of
    (stats, floats, expr') ->
      ( stats
      , [ (float_lvl, tag_bind b) | (float_lvl, b) <- floats ]
      , Note note expr' )
  where
    dup_cc :: CostCentre
    dup_cc = dupifyCC cc

    add_scc rhs = mkSCC dup_cc rhs

    tag_bind (NonRec bndr rhs) = NonRec bndr (add_scc rhs)
    tag_bind (Rec prs)         = Rec [ (bndr, add_scc rhs) | (bndr, rhs) <- prs ]

-- INLINE notes: no floating at all happens inside them; the level tags
-- are merely stripped off.  The SetLevels pass did not clone the
-- bindings, so any floating would be unsafe, even if the results were
-- dumped back inside the Note (which is what we used to do).
floatExpr _ (Note InlineMe expr) = (zeroStats, [], untouched)
  where
    untouched = Note InlineMe (unTag expr)

-- Any remaining kind of Note (SCC and InlineMe are handled by earlier
-- equations): float the body and put the note back around what is left.
floatExpr lvl (Note note expr) =
  case floatExpr lvl expr of
    (stats, floats, expr') -> (stats, floats, Note note expr')

-- Casts are transparent to floating: float the body and re-apply the
-- same coercion to what is left.
floatExpr lvl (Cast expr co) =
  case floatExpr lvl expr of
    (stats, floats, expr') -> (stats, floats, Cast expr' co)

-- A non-recursive let whose binder has an unlifted type is treated just
-- like a case: the binding itself stays where it is, the right-hand
-- side is floated with floatExpr and the body with floatCaseAlt (at the
-- binder's level).  If the guard fails we fall through to the general
-- Let equation.
--
-- The statistics of the right-hand side and of the body are both
-- reported, consistently with every other equation of floatExpr;
-- previously those of the right-hand side were silently discarded.
floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
  | isUnLiftedType (idType bndr)
  = case floatExpr lvl rhs           of { (fs_rhs, rhs_floats, rhs') ->
    case floatCaseAlt bndr_lvl body  of { (fs_body, body_floats, body') ->
    ( add_stats fs_rhs fs_body
    , rhs_floats ++ body_floats
    , Let (NonRec bndr rhs') body') }}

-- The general let.  The Let node is not rebuilt here: floatBind returns
-- the binding as the last of its floats, tagged with its level, and
-- only the body is returned as the remaining expression.
floatExpr lvl (Let bind body) =
  case floatBind bind of
    (bind_stats, bind_floats) ->
      case floatExpr lvl body of
        (body_stats, body_floats, body') ->
          ( add_stats bind_stats body_stats
          , bind_floats ++ body_floats
          , body' )

-- Case expressions.  The scrutinee is floated at the context level;
-- each alternative is floated with floatCaseAlt at the level of the
-- case binder.  In the result, the floats from the alternatives come
-- before those from the scrutinee.
floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts) =
  case floatExpr lvl scrut of
    (scrut_stats, scrut_floats, scrut') ->
      case floatList float_alt alts of
        (alts_stats, alts_floats, alts') ->
          ( add_stats scrut_stats alts_stats
          , alts_floats ++ scrut_floats
          , Case scrut' case_bndr ty alts' )
  where
    -- floatCaseAlt is used for the alternatives so that bindings are
    -- not gratuitously floated out of their right-hand sides.  The
    -- level tags are stripped from the pattern binders.
    float_alt (con, tagged_bs, rhs) =
      case floatCaseAlt case_lvl rhs of
        (rhs_stats, rhs_floats, rhs') ->
          (rhs_stats, rhs_floats, (con, [ b | TB b _ <- tagged_bs ], rhs'))


-- Apply a floating function to each element of a list, left to right,
-- adding up the statistics, concatenating the floats in list order and
-- collecting the per-element results.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f xs = foldr step (zeroStats, [], []) xs
  where
    -- The current element is processed before the rest of the list
    step x rest =
      case f x of
        (fs_x, binds_x, y) ->
          case rest of
            (fs_rest, binds_rest, ys) ->
              (fs_x `add_stats` fs_rest, binds_x ++ binds_rest, y : ys)

-- Discard the tag of a tagged binder, keeping only the binder.
unTagBndr :: TaggedBndr tag -> CoreBndr
unTagBndr tagged = case tagged of
  TB bndr _ -> bndr

-- Strip the tags from every binder of a tagged expression, leaving the
-- structure of the expression otherwise unchanged.
unTag :: TaggedExpr tag -> CoreExpr
unTag expr = case expr of
    Var v                   -> Var v
    Lit l                   -> Lit l
    Type ty                 -> Type ty
    Note n body             -> Note n (unTag body)
    App fun arg             -> App (unTag fun) (unTag arg)
    Lam b body              -> Lam (unTagBndr b) (unTag body)
    Cast body co            -> Cast (unTag body) co
    Let (Rec prs) body      -> Let (Rec [ untag_pair pr | pr <- prs ]) (unTag body)
    Let (NonRec b rhs) body -> Let (NonRec (unTagBndr b) (unTag rhs)) (unTag body)
    Case scrut b ty alts    -> Case (unTag scrut) (unTagBndr b) ty
                                    [ (con, map unTagBndr bs, unTag rhs)
                                    | (con, bs, rhs) <- alts ]
  where
    -- One (binder, rhs) pair of a recursive group
    untag_pair (b, rhs) = (unTagBndr b, unTag rhs)
\end{code}

%************************************************************************
%*									*
\subsection{Utility bits for floating stats}
%*									*
%************************************************************************

I didn't implement this with unboxed numbers.  I don't want to be too
strict in this stuff, as it is rarely turned on.  (WDP 95/09)

\begin{code}
-- Counters collected while floating.  The fields are deliberately left
-- lazy.
data FloatStats
  = FlS Int   -- Number of top-floats * lambda groups they've been past
        Int   -- Number of non-top-floats * lambda groups they've been past
        Int   -- Number of lambda (groups) seen

-- The three counters as a tuple, in field order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats stats = case stats of
  FlS tops others lam_groups -> (tops, others, lam_groups)

-- Statistics with every counter at zero; the unit of add_stats.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- Add up a list of statistics (zeroStats for the empty list).
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats

-- Field-wise sum of two sets of statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats lhs rhs =
  case lhs of
    FlS tops1 others1 lams1 ->
      case rhs of
        FlS tops2 others2 lams2 ->
          FlS (tops1 + tops2) (others1 + others2) (lams1 + lams2)

-- Account for one lambda group and the floats that escape it: floats
-- whose level is the top level bump the first counter, all the others
-- bump the second, and the lambda-group counter goes up by one.
add_to_stats :: FloatStats -> [(Level, Bind CoreBndr)] -> FloatStats
add_to_stats (FlS a b c) floats
  = FlS (a + n_top) (b + n_other) (c + 1)
  where
    goes_to_top (float_lvl, _) = isTopLvl float_lvl

    n_top   = length (filter goes_to_top floats)
    n_other = length (filter (not . goes_to_top) floats)
\end{code}


%************************************************************************
%*									*
\subsection{Utility bits for floating}
%*									*
%************************************************************************

\begin{code}
-- The level attached to a binding group: the level of the binder for a
-- non-recursive binding, and the level of the first binder for a
-- recursive group.  Panics on an empty recursive group.
getBindLevel :: Bind (TaggedBndr Level) -> Level
getBindLevel bind = case bind of
  NonRec (TB _ lvl) _       -> lvl
  Rec ((TB _ lvl, _) : _)   -> lvl
  Rec []                    -> panic "getBindLevel Rec []"
\end{code}

\begin{code}
partitionByMajorLevel, partitionByLevel
  :: Level                      -- Partitioning level (that of the context)
  -> FloatBinds                 -- Defns to be divided into 2 piles...
  -> (FloatBinds, FloatBinds)   -- (defns that float further out, the rest)

-- A defn floats further if it escapes a value lambda (its level is
-- ltMajLvl the context's), or if it can get to the top level.
-- Because of the top-level test,
--      x = f e
-- transforms to
--    lvl = e
--    x = f lvl
-- which is as it should be.
partitionByMajorLevel ctxt_lvl defns
  = partition floats_further defns
  where
    floats_further (float_lvl, _)
      | float_lvl `ltMajLvl` ctxt_lvl = True
      | otherwise                     = isTopLvl float_lvl

-- A defn floats further exactly when its level is ltLvl the context
-- level; there is no special case for the top level here.
partitionByLevel ctxt_lvl defns
  = partition ((`ltLvl` ctxt_lvl) . fst) defns
\end{code}

\begin{code}
-- Forget the levels, keeping the bindings in their original order.
floatsToBinds :: FloatBinds -> [CoreBind]
floatsToBinds = map snd

-- Flatten a list of floats into (binder, rhs) pairs, ready to be put
-- into a single Rec.  Levels are dropped; the order is preserved.
floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
floatsToBindPairs floats = concatMap (pairs_of . snd) floats
  where
    pairs_of (Rec prs)         = prs
    pairs_of (NonRec bndr rhs) = [(bndr, rhs)]

-- Wrap an expression in a Let for each float.  The first float in the
-- list ends up outermost; the levels are ignored.
install :: FloatBinds -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldr (\(_, bind) body -> Let bind body) expr defn_groups
\end{code}