-- | Implements a selective lambda lifter, running late in the optimisation
-- pipeline.
--
-- If you are interested in the cost model that is employed to decide whether
-- to lift a binding or not, look at "GHC.Stg.Lift.Analysis".
-- "GHC.Stg.Lift.Monad" contains the transformation monad that hides away some
-- plumbing of the transformation.
module GHC.Stg.Lift
   (
    -- * Late lambda lifting in STG
    -- $note
   StgLiftConfig (..),
   stgLiftLams
   )
where

import GHC.Prelude

import GHC.Types.Basic
import GHC.Types.Id
import GHC.Stg.FVs ( annBindingFreeVars )
import GHC.Stg.Lift.Config
import GHC.Stg.Lift.Analysis
import GHC.Stg.Lift.Monad
import GHC.Stg.Syntax
import GHC.Unit.Module (Module)
import GHC.Types.Unique.Supply
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Types.Var.Set
import Control.Monad ( when )
import Data.Maybe ( isNothing )

-- Note [Late lambda lifting in STG]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- $note
-- See also the <https://gitlab.haskell.org/ghc/ghc/wikis/late-lam-lift wiki page>
-- and #9476.
--
-- The basic idea behind lambda lifting is to turn locally defined functions
-- into top-level functions. Free variables are then passed as additional
-- arguments at *call sites* instead of having a closure allocated for them at
-- *definition site*. Example:
--
-- @
--    let x = ...; y = ... in
--    let f = {x y} \a -> a + x + y in
--    let g = {f x} \b -> f b + x in
--    g 5
-- @
--
-- Lambda lifting @f@ would
--
--   1. Turn @f@'s free variables into formal parameters
--   2. Update @f@'s call site within @g@ to @f x y b@
--   3. Update @g@'s closure: Add @y@ as an additional free variable, while
--      removing @f@, because @f@ no longer allocates and can be floated to
--      top-level.
--   4. Actually float the binding of @f@ to top-level, eliminating the @let@
--      in the process.
--
-- This results in the following program (with free var annotations):
--
-- @
--    f x y a = a + x + y;
--    let x = ...; y = ... in
--    let g = {x y} \b -> f x y b + x in
--    g 5
-- @
--
-- This optimisation is all about lifting only when it is beneficial to do so.
-- The above seems like a worthwhile lift, judging from heap allocation:
-- We eliminate @f@'s closure, saving the allocation of a 2-word closure, while
-- not changing the size of @g@'s closure.
--
-- You can probably sense that there's some kind of cost model at play here.
-- And you are right! But we also employ a couple of other heuristics for the
-- lifting decision which are outlined in "GHC.Stg.Lift.Analysis#when".
--
-- The transformation is done in "GHC.Stg.Lift", which calls out to
-- 'GHC.Stg.Lift.Analysis.goodToLift' for its lifting decision.  It relies on
-- "GHC.Stg.Lift.Monad", which abstracts some subtle STG invariants into a
-- monadic substrate.
--
-- Suffice it to say: We trade heap allocation for stack allocation.
-- The additional arguments have to be passed on the stack (or in registers,
-- depending on architecture) every time we call the function to save a single
-- heap allocation when entering the let binding. Nofib suggests a mean
-- improvement of about 1% for this pass, so it seems like a worthwhile thing to
-- do. Compile-times went up by 0.6%, so all in all a very modest change.
--
-- For a concrete example, look at @spectral/atom@. There's a call to 'zipWith'
-- that is ultimately compiled to something like this
-- (modulo desugaring/lowering to actual STG):
--
-- @
--    propagate dt = ...;
--    runExperiment ... =
--      let xs = ... in
--      let ys = ... in
--      let go = {dt go} \xs ys -> case (xs, ys) of
--            ([], []) -> []
--            (x:xs', y:ys') -> propagate dt x y : go xs' ys'
--      in go xs ys
-- @
--
-- This will lambda lift @go@ to top-level, speeding up the resulting program
-- by roughly one percent:
--
-- @
--    propagate dt = ...;
--    go dt xs ys = case (xs, ys) of
--      ([], []) -> []
--      (x:xs', y:ys') -> propagate dt x y : go dt xs' ys'
--    runExperiment ... =
--      let xs = ... in
--      let ys = ... in
--      go dt xs ys
-- @



-- | Lambda lifts bindings to top-level deemed worth lifting (see 'goodToLift').
--
-- (Mostly) textbook instance of the lambda lifting transformation, selecting
-- which bindings to lambda lift by consulting 'goodToLift'.
--
-- The top-level bindings are chained into a single 'LiftM' action via
-- 'liftTopLvl'. Each binding's processing is followed by the action for the
-- remaining bindings. That action is then run with the given config and
-- unique supply.
stgLiftLams
  :: Module
  -> StgLiftConfig
  -> UniqSupply
  -> [InStgTopBinding]
  -> [OutStgTopBinding]
stgLiftLams this_mod cfg us =
  runLiftM cfg us . foldr (liftTopLvl this_mod) (pure ())

-- | Processes a single top-level binding and then runs @rest@, the action
-- for the remaining top-level bindings.
--
-- * A string literal only has its binder substituted and is re-emitted via
--   'addTopStringLit'.
-- * An ordinary binding goes through several steps. First it is annotated
--   with free variables ('annBindingFreeVars') and tagged
--   ('tagSkeletonTopBind'). It is then run through 'withLiftedBind'. Finally
--   the result is recorded with 'addLiftedBinding'.
--
-- A recursive top-level group is bracketed by 'startBindingGroup' and
-- 'endBindingGroup'.
liftTopLvl :: Module -> InStgTopBinding -> LiftM () -> LiftM ()
liftTopLvl _ (StgTopStringLit bndr lit) rest = withSubstBndr bndr $ \bndr' -> do
  addTopStringLit bndr' lit
  -- @rest@ runs inside the continuation, presumably so that the substitution
  -- for @bndr@ stays in scope for later top-level bindings.
  rest
liftTopLvl this_mod (StgTopLifted bind) rest = do
  let is_rec = isRec $ fst $ decomposeStgBinding bind
  when is_rec startBindingGroup
  let bind_w_fvs = annBindingFreeVars this_mod bind
  withLiftedBind TopLevel (tagSkeletonTopBind bind_w_fvs) NilSk $ \mb_bind' -> do
    -- We signal lifting of a binding through returning Nothing.
    -- Should never happen for a top-level binding, though, since we are already
    -- at top-level.
    case mb_bind' of
      Nothing -> pprPanic "StgLiftLams" (text "Lifted top-level binding")
      Just bind' -> addLiftedBinding bind'
    when is_rec endBindingGroup
    rest

-- | Runs a whole binding through 'withLiftedBindPairs'.
--
-- The binding is split into its 'RecFlag' and binder/RHS pairs. Any pairs
-- that come back are reassembled into a binding with the same 'RecFlag'
-- before being handed to the continuation.
--
-- The continuation receives 'Nothing' when the binding was lifted. In that
-- case it has already been recorded as a lifted binding. Otherwise it
-- receives the transformed binding.
withLiftedBind
  :: TopLevelFlag
  -> LlStgBinding
  -> Skeleton
  -> (Maybe OutStgBinding -> LiftM a)
  -> LiftM a
withLiftedBind top_lvl bind scope k
  = withLiftedBindPairs top_lvl rec pairs scope (k . fmap (mkStgBinding rec))
  where
    (rec, pairs) = decomposeStgBinding bind

-- | Consults 'goodToLift' to decide whether the given binder/RHS pairs should
-- be lambda lifted, then transforms them accordingly.
--
-- * When lifting is worthwhile, 'goodToLift' returns @abs_ids@. The binders
--   are noted as lifted (abstracting over @abs_ids@). Their RHSs are
--   transformed with @abs_ids@ as extra lambda binders. The resulting binding
--   is recorded via 'addLiftedBinding' and the continuation receives 'Nothing'.
-- * Otherwise the binders are merely substituted. The RHSs are still
--   transformed, which may lift bindings nested inside them. The continuation
--   receives the rewritten pairs.
withLiftedBindPairs
  :: TopLevelFlag
  -> RecFlag
  -> [(BinderInfo, LlStgRhs)]
  -> Skeleton
  -> (Maybe [(Id, OutStgRhs)] -> LiftM a)
  -> LiftM a
withLiftedBindPairs top rec pairs scope k = do
  let (infos, rhss) = unzip pairs
  let bndrs = map binderInfoBndr infos
  expander <- liftedIdsExpander
  cfg <- getConfig
  case goodToLift cfg top rec expander pairs scope of
    -- @abs_ids@ is the set of all variables that need to become parameters.
    Just abs_ids -> withLiftedBndrs abs_ids bndrs $ \bndrs' -> do
      -- Within this block, all binders in @bndrs@ will be noted as lifted, so
      -- that the return value of @liftedIdsExpander@ in this context will also
      -- expand the bindings in @bndrs@ to their free variables.
      -- Now we can recurse into the RHSs and see if we can lift any further
      -- bindings. We pass the set of expanded free variables (thus OutIds) on
      -- to @liftRhs@ so that it can add them as parameter binders.
      when (isRec rec) startBindingGroup
      rhss' <- traverse (liftRhs (Just abs_ids)) rhss
      let pairs' = zip bndrs' rhss'
      addLiftedBinding (mkStgBinding rec pairs')
      when (isRec rec) endBindingGroup
      -- The binding now lives among the lifted bindings, so there is nothing
      -- left to bind locally.
      k Nothing
    Nothing -> withSubstBndrs bndrs $ \bndrs' -> do
      -- Don't lift the current binding, but possibly some bindings in their
      -- RHSs.
      rhss' <- traverse (liftRhs Nothing) rhss
      let pairs' = zip bndrs' rhss'
      k (Just pairs')

-- | Transforms the right-hand side of a binding, recursing into closure
-- bodies.
--
-- * Constructor RHSs are never lifted (this is asserted); only their
--   arguments are transformed with 'liftArgs'.
-- * For closures, the lambda binders are substituted and the body is
--   transformed with 'liftExpr'. When the RHS belongs to a lifted binding, the
--   former free variables are prepended as additional lambda binders.
liftRhs
  :: Maybe DIdSet
  -- ^ @Just former_fvs@ <=> this RHS was lifted and we have to add @former_fvs@
  -- as lambda binders, discarding all free vars.
  -> LlStgRhs
  -> LiftM OutStgRhs
liftRhs mb_former_fvs rhs@(StgRhsCon ccs con mn ts args typ)
  = assertPpr (isNothing mb_former_fvs)
              (text "Should never lift a constructor"
               $$ pprStgRhs panicStgPprOpts rhs) $
    StgRhsCon ccs con mn ts <$> traverse liftArgs args <*> pure typ
liftRhs Nothing (StgRhsClosure _ ccs upd infos body typ) =
  -- This RHS wasn't lifted.
  withSubstBndrs (map binderInfoBndr infos) $ \bndrs' ->
    StgRhsClosure noExtFieldSilent ccs upd bndrs' <$> liftExpr body <*> pure typ
liftRhs (Just former_fvs) (StgRhsClosure _ ccs upd infos body typ) =
  -- This RHS was lifted. Insert extra binders for @former_fvs@.
  withSubstBndrs (map binderInfoBndr infos) $ \bndrs' -> do
    -- The former free variables come first, matching the StgApp case of
    -- 'liftExpr', which puts them before the original arguments at call sites.
    -- Assumes 'formerFreeVars' lists them in 'dVarSetElems' order; confirm in
    -- "GHC.Stg.Lift.Monad".
    let bndrs'' = dVarSetElems former_fvs ++ bndrs'
    StgRhsClosure noExtFieldSilent ccs upd bndrs'' <$> liftExpr body <*> pure typ

-- | Applies the current substitution to an STG argument.
--
-- Literal arguments are returned unchanged. For a variable argument, the code
-- asserts that it does not refer to a lifted binder. Call sites in
-- 'liftExpr' receive the extra former-free-variable arguments, but an argument
-- occurrence is only substituted, so it would miss them.
liftArgs :: InStgArg -> LiftM OutStgArg
liftArgs a@(StgLitArg _) = pure a
liftArgs (StgVarArg occ) = do
  assertPprM (not <$> isLifted occ)
             (text "StgArgs should never be lifted" $$ ppr occ)
  StgVarArg <$> substOcc occ

-- | Transforms an STG expression, lifting eligible local bindings.
--
-- * Applications of a lifted function receive its former free variables as
--   extra leading arguments.
-- * @let@ bindings are run through 'withLiftedBind'. A lifted binding
--   disappears from the expression, because it was already recorded via
--   'addLiftedBinding'. A non-lifted binding is rebuilt around the
--   transformed body.
-- * Let-no-escape bindings must never be lifted; doing so panics.
liftExpr :: LlStgExpr -> LiftM OutStgExpr
liftExpr (StgLit lit) = pure (StgLit lit)
liftExpr (StgTick t e) = StgTick t <$> liftExpr e
liftExpr (StgApp f args) = do
  f' <- substOcc f
  args' <- traverse liftArgs args
  -- Extra arguments for a lifted @f@ (presumably empty when @f@ was not lifted).
  fvs' <- formerFreeVars f
  let top_lvl_args = map StgVarArg fvs' ++ args'
  pure (StgApp f' top_lvl_args)
liftExpr (StgConApp con mn args tys) =
  StgConApp con mn <$> traverse liftArgs args <*> pure tys
liftExpr (StgOpApp op args ty) =
  StgOpApp op <$> traverse liftArgs args <*> pure ty
liftExpr (StgCase scrut info ty alts) = do
  scrut' <- liftExpr scrut
  withSubstBndr (binderInfoBndr info) $ \bndr' -> do
    alts' <- traverse liftAlt alts
    pure (StgCase scrut' bndr' ty alts')
liftExpr (StgLet scope bind body)
  = withLiftedBind NotTopLevel bind scope $ \mb_bind' -> do
      body' <- liftExpr body
      case mb_bind' of
        Nothing -> pure body' -- withLiftedBindPairs decided to lift it and already added floats
        Just bind' -> pure (StgLet noExtFieldSilent bind' body')
liftExpr (StgLetNoEscape scope bind body)
  = withLiftedBind NotTopLevel bind scope $ \mb_bind' -> do
      body' <- liftExpr body
      case mb_bind' of
        Nothing -> pprPanic "stgLiftLams" (text "Should never decide to lift LNEs")
        Just bind' -> pure (StgLetNoEscape noExtFieldSilent bind' body')

-- | Transforms a case alternative.
--
-- The pattern binders are substituted and the right-hand side is transformed
-- with 'liftExpr'. The alternative's constructor is kept. Both the new RHS and
-- the rebuilt alternative are forced before being returned.
liftAlt :: LlStgAlt -> LiftM OutStgAlt
liftAlt alt@GenStgAlt{alt_con=_, alt_bndrs=infos, alt_rhs=rhs} =
  withSubstBndrs (map binderInfoBndr infos) $ \bndrs' ->
    do !rhs' <- liftExpr rhs
       return $! alt {alt_bndrs = bndrs', alt_rhs = rhs'}