{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM'
-- monad.
module GHC.Stg.Lift.Monad (
    decomposeStgBinding, mkStgBinding,
    Env (..),
    -- * #floats# Handling floats
    -- $floats
    FloatLang (..), collectFloats, -- Exported just for the docs
    -- * Transformation monad
    LiftM, runLiftM,
    -- ** Adding bindings
    startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
    -- ** Substitution and binders
    withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
    -- ** Occurrences
    substOcc, isLifted, formerFreeVars, liftedIdsExpander
  ) where

#include "HsVersions.h"

import GHC.Prelude

import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Driver.Session
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity

import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )

-- | Splits a binding into its 'RecFlag' and its list of binder\/RHS pairs.
-- A non-recursive binding yields a singleton list.
--
-- @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding (StgRec pairs) = (Recursive, pairs)
decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])

-- | Rebuilds a binding from a 'RecFlag' and binder\/RHS pairs, i.e. the
-- counterpart of 'decomposeStgBinding'.
--
-- For 'NonRecursive', the binding is built from the first pair; any further
-- pairs are ignored. Passing 'NonRecursive' together with an empty list is a
-- caller bug and panics with a descriptive message (instead of the opaque
-- error of a partial @head@).
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive    pairs = StgRec pairs
mkStgBinding NonRecursive pairs = case pairs of
  (bndr, rhs) : _ -> StgNonRec bndr rhs
  []              -> pprPanic "mkStgBinding"
                       (text "non-recursive binding without a binder/RHS pair")

-- | Environment threaded around in a scoped, @Reader@-like fashion.
data Env
  = Env
  { e_dflags     :: !DynFlags
  -- ^ Read-only.
  , e_subst      :: !Subst
  -- ^ We need to track the renamings of local 'InId's to their lifted 'OutId',
  -- because shadowing might make a closure's free variables unavailable at its
  -- call sites. Consider:
  -- @
  --    let f y = x + y in let x = 4 in f x
  -- @
  -- Here, @f@ can't be lifted to top-level, because its free variable @x@ isn't
  -- available at its call site.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ Lifted 'Id's don't occur as free variables in any closure anymore, because
  -- they are bound at the top-level. Every occurrence must supply the formerly
  -- free variables of the lifted 'Id', so they in turn become free variables of
  -- the call sites. This environment tracks this expansion from lifted 'Id's to
  -- their free variables.
  --
  -- 'InId's to 'OutId's.
  --
  -- Invariant: 'Id's not present in this map won't be substituted.
  }

-- | The initial environment: the given 'DynFlags', an empty substitution and
-- no expansions.
emptyEnv :: DynFlags -> Env
emptyEnv dflags = Env dflags emptySubst emptyVarEnv

-- Note [Handling floats]
-- ~~~~~~~~~~~~~~~~~~~~~~
-- $floats
-- Consider the following expression:
--
-- @
--     f x =
--       let g y = ... f y ...
--       in g x
-- @
--
-- What happens when we want to lift @g@? Normally, we'd put the lifted @l_g@
-- binding above the binding for @f@:
--
-- @
--     g f y = ... f y ...
--     f x = g f x
-- @
--
-- But this very unnecessarily turns a known call to @f@ into an unknown one, in
-- addition to complicating matters for the analysis.
-- Instead, we'd really like to put both functions in the same recursive group,
-- thereby preserving the known call:
--
-- @
--     Rec {
--       g y = ... f y ...
--       f x = g x
--     }
-- @
--
-- But we don't want this to happen for just /any/ binding. That would create
-- possibly huge recursive groups in the process, calling for an occurrence
-- analyser on STG.
-- So, we need to track when we lift a binding out of a recursive RHS and add
-- the binding to the same recursive group as the enclosing recursive binding
-- (which must have either already been at the top-level or decided to be
-- lifted itself in order to preserve the known call).
--
-- This is done by expressing this kind of nesting structure as a 'Writer' over
-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to
-- 'collectFloats'.
-- API-wise, the analysis will not need to know about the whole 'FloatLang'
-- business and will just manipulate it indirectly through actions in 'LiftM'.

-- | We need to detect when we are lifting something out of the RHS of a
-- recursive binding (c.f. "GHC.Stg.Lift.Monad#floats"), in which case that
-- binding needs to be added to the same top-level recursive group. This
-- requires we detect a certain nesting structure, which is encoded by
-- 'StartBindingGroup' and 'EndBindingGroup'.
--
-- Note that 'collectFloats' only ever cares whether the binding currently
-- being lifted (through 'LiftedBinding') occurs inside such a binding group
-- or not, i.e. it doesn't care about the nesting level as long as it's
-- greater than 0.
data FloatLang
  = StartBindingGroup
  | EndBindingGroup
  | PlainTopBinding OutStgTopBinding
  | LiftedBinding OutStgBinding

-- | Compact debug rendering: group delimiters print as parentheses, string
-- literals as @\<str\>@, and bindings as @r@ (recursive) or @n@
-- (non-recursive) followed by their binders.
instance Outputable FloatLang where
  ppr float = case float of
    StartBindingGroup                   -> char '('
    EndBindingGroup                     -> char ')'
    PlainTopBinding StgTopStringLit{}   -> text "<str>"
    -- A lifted top-level binding is shown exactly like a 'LiftedBinding'.
    PlainTopBinding (StgTopLifted bind) -> ppr (LiftedBinding bind)
    LiftedBinding bind ->
      let (rec_flag, pairs) = decomposeStgBinding bind
          tag | isRec rec_flag = char 'r'
              | otherwise      = char 'n'
      in tag <+> ppr (map fst pairs)

-- | Flattens an expression in @['FloatLang']@ into an STG program, see "GHC.Stg.Lift.Monad#floats".
-- Important pre-conditions: The nesting of opening 'StartBindingGroup's and
-- closing 'EndBindingGroup's is balanced. Also, it is crucial that every binding
-- group has at least one recursive binding inside.
-- Otherwise there's no point
-- in announcing the binding group in the first place and an @ASSERT@ will
-- trigger.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = walk (0 :: Int) []
  where
    -- @walk depth pending floats@: @depth@ is the number of binding groups
    -- currently open and @pending@ accumulates (most recent first) the lifted
    -- bindings seen while at least one group is open.
    walk :: Int -> [OutStgBinding] -> [FloatLang] -> [OutStgTopBinding]
    walk 0 [] [] = []
    walk _ _  [] = pprPanic "collectFloats" (text "unterminated group")
    walk depth pending (StartBindingGroup : more)
      = walk (depth + 1) pending more
    walk depth pending (EndBindingGroup : more)
      | depth == 0 = pprPanic "collectFloats" (text "no group to end")
      -- Closing the outermost group: emit everything collected as one group.
      | depth == 1 = StgTopLifted (mergeGroup pending) : walk 0 [] more
      | otherwise  = walk (depth - 1) pending more
    walk depth pending (PlainTopBinding top_bind : more)
      | depth == 0 = top_bind : walk depth pending more
      | otherwise  = pprPanic "collectFloats" (text "plain top binding inside group")
    walk depth pending (LiftedBinding bind : more)
      -- Outside of any group a lifted binding becomes its own top-level binding.
      | depth == 0 = StgTopLifted (stripCCCS bind) : walk depth pending more
      | otherwise  = walk depth (bind : pending) more

    -- Applies 'removeRhsCCCS' to every right-hand side of a binding.
    stripCCCS :: OutStgBinding -> OutStgBinding
    stripCCCS bind =
      let (rec_flag, pairs) = decomposeStgBinding bind
      in mkStgBinding rec_flag (map (second removeRhsCCCS) pairs)

    -- Fuses all bindings of a group into a single recursive binding.
    mergeGroup :: [OutStgBinding] -> OutStgBinding
    mergeGroup group_binds =
      ASSERT( any isRecBinding group_binds )
      StgRec (concatMap (snd . decomposeStgBinding . stripCCCS) group_binds)

    isRecBinding :: OutStgBinding -> Bool
    isRecBinding StgRec{} = True
    isRecBinding _        = False

-- | Replaces a "current" cost centre stack annotation on a right-hand side by
-- 'dontCareCCS'; any other right-hand side is returned unchanged.
--
-- Omitting this makes for strange closure allocation schemes that crash the
-- GC.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS rhs = case rhs of
  StgRhsClosure ext ccs upd bndrs body
    | isCurrentCCS ccs -> StgRhsClosure ext dontCareCCS upd bndrs body
  StgRhsCon ccs con args
    | isCurrentCCS ccs -> StgRhsCon dontCareCCS con args
  _ -> rhs

-- | The analysis monad consists of the following 'RWST' components:
--
--     * 'Env': Reader-like context. Contains a substitution, info about
--       how lifted identifiers are to be expanded into applications and details
--       such as 'DynFlags'.
--
--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
--
--     * No pure state component
--
--     * But wrapping around 'UniqSM' for generating fresh lifted binders.
--       (The @uniqAway@ approach could give the same name to two different
--       lifted binders, so this is necessary.)
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)

-- | The 'DynFlags' are read from the 'Env'.
instance HasDynFlags LiftM where
  getDynFlags = LiftM $ RWS.asks e_dflags

-- | Unique generation is delegated to the underlying 'UniqSM'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM $ lift getUniqueSupplyM
  getUniqueM       = LiftM $ lift getUniqueM
  getUniquesM      = LiftM $ lift getUniquesM

-- | Runs a 'LiftM' action in an initially empty 'Env' with the given unique
-- supply and flattens the floats it wrote with 'collectFloats'.
runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us (LiftM m) =
  case initUs_ us (runRWST m (emptyEnv dflags) ()) of
    (_, _, floats) -> collectFloats (fromOL floats)

-- | Writes a plain 'StgTopStringLit' to the output.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit id bytes =
  LiftM (RWS.tell (unitOL (PlainTopBinding (StgTopStringLit id bytes))))

-- | Starts a recursive binding group. See "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))

-- | Ends a recursive binding group. See "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))

-- | Lifts a binding to top-level. Depending on whether it's declared inside
-- a recursive RHS (see "GHC.Stg.Lift.Monad#floats" and 'collectFloats'), this might be added to
-- an existing recursive top-level binding group.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding bind = LiftM (RWS.tell (unitOL (LiftedBinding bind)))

-- | Takes a binder and a continuation which is called with the substituted
-- binder. The continuation will be evaluated in a 'LiftM' context in which that
-- binder is deemed in scope. Think of it as a 'RWS.local' computation: After
-- the continuation finishes, the new binding won't be in scope anymore.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $
  RWS.asks e_subst >>= \subst ->
    let (bndr', subst') = substBndr bndr subst
        -- Only the substitution changes for the continuation.
        with_bndr env = env { e_subst = subst' }
    in RWS.local with_bndr (unwrapLiftM (inner bndr'))

-- | See 'withSubstBndr'.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs bndrs = runContT (traverse (ContT . withSubstBndr) bndrs)

-- | Similarly to 'withSubstBndr', this function takes a set of variables to
-- abstract over, the binder to lift (and generate a fresh, substituted name
-- for) and a continuation in which that fresh, lifted binder is in scope.
--
-- It takes care of all the details involved with copying and adjusting the
-- binder and fresh name generation.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let abs_vars  = dVarSetElems abs_ids
      lifted_fs = mkFastString ("$l" ++ occNameString (getOccName bndr))
      -- The lifted binder additionally takes the abstracted variables.
      lifted_ty = mkLamTypes abs_vars (idType bndr)
      fresh_id  = mkSysLocal lifted_fs uniq Many lifted_ty
      -- See Note [transferPolyIdInfo] in GHC.Types.Id. We need to do this at least
      -- for arity information.
      bndr'     = transferPolyIdInfo bndr abs_vars fresh_id
      -- Rename @bndr@ to @bndr'@ and record the variables it abstracts over.
      with_lifted env = env
        { e_subst      = extendSubst bndr bndr' (extendInScope bndr' (e_subst env))
        , e_expansions = extendVarEnv (e_expansions env) bndr abs_ids
        }
  LiftM (RWS.local with_lifted (unwrapLiftM (inner bndr')))

-- | See 'withLiftedBndr'.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids bndrs =
  runContT (traverse (ContT . withLiftedBndr abs_ids) bndrs)

-- | Substitutes a binder /occurrence/, which was brought in scope earlier by
-- 'withSubstBndr' \/ 'withLiftedBndr'.
substOcc :: Id -> LiftM Id
substOcc id = LiftM $ lookupIdSubst id <$> RWS.asks e_subst

-- | Whether the given binding was decided to be lambda lifted.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure (bndr `elemVarEnv` expansions)

-- | Returns an empty list for a binding that was not lifted and the list of all
-- local variables the binding abstracts over (so, exactly the additional
-- arguments at adjusted call sites) otherwise.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM (RWS.asks (abstracted . e_expansions))
  where
    -- No entry in the expansions means @f@ was not lifted.
    abstracted expansions = maybe [] dVarSetElems (lookupVarEnv expansions f)

-- | Creates an /expander function/ for the current set of lifted binders.
-- This expander function will replace any 'InId' by their corresponding 'OutId'
-- and, in addition, will expand any lifted binders by the former free variables
-- it abstracts over.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst      <- RWS.asks e_subst
  -- We use @noWarnLookupIdSubst@ here in order to suppress "not in scope"
  -- warnings generated by 'lookupIdSubst' due to local bindings within RHS.
  -- These are not in the InScopeSet of @subst@ and extending the InScopeSet in
  -- @goodToLift@/@closureGrowth@ before passing it on to @expander@ is too much
  -- trouble.
  let expandOne acc fv =
        maybe
          (extendDVarSet acc (noWarnLookupIdSubst fv subst)) -- Not lifted
          (unionDVarSet acc)                                 -- Lifted: add its former free variables
          (lookupVarEnv expansions fv)
  pure (foldl' expandOne emptyDVarSet . dVarSetElems)