{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module GHC.Stg.Lift.Monad (
decomposeStgBinding, mkStgBinding,
Env (..),
FloatLang (..), collectFloats,
LiftM, runLiftM,
startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
substOcc, isLifted, formerFreeVars, liftedIdsExpander
) where
#include "HsVersions.h"
import GHC.Prelude
import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Driver.Session
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Utils.Panic
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity
import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )
-- | Split a binding into its recursiveness flag and its list of
-- @(binder, rhs)@ pairs.
--
-- A 'StgRec' yields 'Recursive' together with all of its pairs; a
-- 'StgNonRec' yields 'NonRecursive' together with a singleton list.
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding (StgRec pairs)       = (Recursive, pairs)
decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])
-- | Build a binding from a recursiveness flag and a list of
-- @(binder, rhs)@ pairs.
--
-- * 'Recursive': all pairs are wrapped in a single 'StgRec'.
--
-- * 'NonRecursive': the /first/ pair becomes a 'StgNonRec'; any further
--   pairs are ignored. An empty list is a caller error and panics with a
--   descriptive message (instead of an anonymous @head []@ failure).
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive    pairs      = StgRec pairs
mkStgBinding NonRecursive (pair : _) = uncurry StgNonRec pair
mkStgBinding NonRecursive []         =
  panic "mkStgBinding: NonRecursive binding needs a (binder, rhs) pair"
-- | Reader-style environment of the 'LiftM' monad. All fields are strict.
data Env
  = Env
  { e_dflags     :: !DynFlags
  -- ^ The flags handed to 'runLiftM'; returned by 'getDynFlags'.
  , e_subst      :: !Subst
  -- ^ Current substitution. Extended locally by 'withSubstBndr' and
  -- 'withLiftedBndr', consulted by 'substOcc' and 'liftedIdsExpander'.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ Maps each binder registered through 'withLiftedBndr' to the set of
  -- identifiers it was abstracted over. Consulted by 'isLifted',
  -- 'formerFreeVars' and 'liftedIdsExpander'.
  }
-- | Initial environment: the given flags, an empty substitution and no
-- recorded expansions.
emptyEnv :: DynFlags -> Env
emptyEnv dflags = Env dflags emptySubst emptyVarEnv
-- | The items written to the 'LiftM' output stream. 'collectFloats' turns a
-- list of these into top-level bindings.
data FloatLang
  = StartBindingGroup
  -- ^ Opens a (possibly nested) binding group.
  | EndBindingGroup
  -- ^ Closes the innermost open binding group.
  | PlainTopBinding OutStgTopBinding
  -- ^ A top-level binding that is emitted as is; only allowed outside of
  -- any binding group (see 'collectFloats').
  | LiftedBinding OutStgBinding
  -- ^ A binding to be emitted at top level, either on its own or merged
  -- into the enclosing binding group.
-- | Compact debug rendering: group delimiters print as parentheses, string
-- literals as @<str>@, and every other binding as @r@ (recursive) or @n@
-- (non-recursive) followed by the list of its binders.
instance Outputable FloatLang where
  ppr StartBindingGroup = char '('
  ppr EndBindingGroup = char ')'
  ppr (PlainTopBinding StgTopStringLit{}) = text "<str>"
  -- A lifted top-level binding is shown exactly like a 'LiftedBinding'.
  ppr (PlainTopBinding (StgTopLifted b)) = ppr (LiftedBinding b)
  ppr (LiftedBinding bind) = rec_tag <+> ppr (map fst pairs)
    where
      (rec_flag, pairs) = decomposeStgBinding bind
      rec_tag
        | isRec rec_flag = char 'r'
        | otherwise      = char 'n'
-- | Flatten a stream of 'FloatLang' items into top-level STG bindings.
--
-- The traversal tracks the current group nesting depth @n@ and the lifted
-- bindings collected since the outermost group was opened:
--
-- * Outside of any group (@n == 0@), a 'PlainTopBinding' is passed through
--   unchanged and a 'LiftedBinding' becomes its own 'StgTopLifted'.
--
-- * Inside a group, every 'LiftedBinding' is accumulated. When the
--   /outermost/ group is closed, all accumulated bindings are merged into a
--   single 'StgRec'. Bindings are accumulated by prepending, so they appear
--   in the merged group in reverse order of emission.
--
-- In both cases the right-hand sides are passed through 'removeRhsCCCS'.
--
-- Panics if the input ends while a group is still open, if a group is closed
-- that was never opened, or if a 'PlainTopBinding' occurs inside a group.
-- Additionally @ASSERT@s that every merged group contains at least one
-- 'StgRec' binding.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    go 0 [] [] = []
    go _ _ [] = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n+1) binds rest
      EndBindingGroup
        | n == 0 -> pprPanic "collectFloats" (text "no group to end")
        -- Closing the outermost group: emit everything collected so far.
        | n == 1 -> StgTopLifted (merge_binds binds) : go 0 [] rest
        | otherwise -> go (n-1) binds rest
      PlainTopBinding top_bind
        | n == 0 -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | n == 0 -> StgTopLifted (rm_cccs bind) : go n binds rest
        | otherwise -> go n (bind:binds) rest

    -- Apply a function to every right-hand side of a binding.
    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding

    rm_cccs = map_rhss removeRhsCCCS

    -- Concatenate the pairs of all collected bindings into one recursive group.
    merge_binds binds = ASSERT( any is_rec binds )
                        StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)

    is_rec StgRec{} = True
    is_rec _ = False
-- | Replace the cost-centre stack of a right-hand side by 'dontCareCCS' if it
-- satisfies 'isCurrentCCS'. Right-hand sides with any other cost-centre stack
-- are returned unchanged.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body)
  | isCurrentCCS ccs
  = StgRhsClosure ext dontCareCCS upd bndrs body
removeRhsCCCS (StgRhsCon ccs con mu ts args)
  | isCurrentCCS ccs
  = StgRhsCon dontCareCCS con mu ts args
removeRhsCCCS rhs = rhs
-- | The lambda-lifting monad: an 'RWST' stack over 'UniqSM' with
--
-- * 'Env' as the reader component,
--
-- * @'OrdList' 'FloatLang'@ as the writer output, and
--
-- * @()@ as the state component, i.e. no state.
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)
-- | The flags are read from the 'e_dflags' field of the environment.
instance HasDynFlags LiftM where
  getDynFlags = LiftM (RWS.asks e_dflags)
-- | All operations are delegated to the underlying 'UniqSM'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
  getUniqueM = LiftM (lift getUniqueM)
  getUniquesM = LiftM (lift getUniquesM)
-- | Run a 'LiftM' computation, starting from 'emptyEnv' for the given flags
-- and drawing uniques from the given supply. The written 'FloatLang' output
-- is turned into top-level bindings by 'collectFloats'.
runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us (LiftM m) = collectFloats (fromOL floats)
  where
    -- Result and (unit) state are of no interest, only the writer output is.
    (_, _, floats) = initUs_ us (runRWST m (emptyEnv dflags) ())
-- | Write a 'StgTopStringLit' for the given identifier and bytes to the
-- output, wrapped as a 'PlainTopBinding'.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit id = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit id
-- | Write a 'StartBindingGroup' marker to the output. See 'collectFloats' for
-- how the markers are interpreted.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM $ RWS.tell $ unitOL StartBindingGroup
-- | Write an 'EndBindingGroup' marker to the output. See 'collectFloats' for
-- how the markers are interpreted.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM $ RWS.tell $ unitOL EndBindingGroup
-- | Write a binding to the output as a 'LiftedBinding'. Whether it ends up as
-- a top-level binding of its own or as part of a merged recursive group is
-- decided by 'collectFloats', depending on the surrounding group markers.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding
-- | Run 'substBndr' on the binder against the current substitution, then run
-- the continuation on the resulting binder with the resulting substitution
-- installed.
--
-- The substitution is changed via 'RWS.local', so the change is only visible
-- inside the continuation.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  subst <- RWS.asks e_subst
  let (bndr', subst') = substBndr bndr subst
  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))
-- | 'withSubstBndr' for a whole traversable container of binders. The
-- individual continuations are nested through 'ContT', so each binder is
-- substituted in the scope established by the ones before it.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs = runContT . traverse (ContT . withSubstBndr)
-- | Create a fresh binder for the lifted version of @bndr@ and run the
-- continuation with it.
--
-- The fresh binder
--
-- * is a system-local 'Id' with a fresh unique, named after @bndr@ with a
--   @$l@ prefix,
--
-- * has the type of @bndr@ wrapped by 'mkLamTypes' over the elements of
--   @abs_ids@, and
--
-- * gets its 'Id' info from @bndr@ through 'transferPolyIdInfo'.
--
-- Inside the continuation, occurrences of @bndr@ are substituted by the fresh
-- binder and @bndr@ is recorded in 'e_expansions' as expanding to @abs_ids@.
-- Both changes are made via 'RWS.local' and hence scoped to the continuation.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let abs_id_list = dVarSetElems abs_ids
  let str = "$l" ++ occNameString (getOccName bndr)
  let ty = mkLamTypes abs_id_list (idType bndr)
  let bndr'
        = transferPolyIdInfo bndr abs_id_list
        . mkSysLocal (mkFastString str) uniq Many
        $ ty
  LiftM $ RWS.local
    (\e -> e
      { e_subst = extendSubst bndr bndr' $ extendInScope bndr' $ e_subst e
      , e_expansions = extendVarEnv (e_expansions e) bndr abs_ids
      })
    (unwrapLiftM (inner bndr'))
-- | 'withLiftedBndr' for a whole traversable container of binders, all
-- abstracted over the same set of identifiers. The individual continuations
-- are nested through 'ContT'.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids = runContT . traverse (ContT . withLiftedBndr abs_ids)
-- | Look up an identifier occurrence in the current substitution
-- ('e_subst') using 'lookupIdSubst'.
substOcc :: Id -> LiftM Id
substOcc id = LiftM (RWS.asks (lookupIdSubst id . e_subst))
-- | Whether the binder has an entry in 'e_expansions', i.e. whether it was
-- registered through 'withLiftedBndr' in the current scope.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM (RWS.asks (elemVarEnv bndr . e_expansions))
-- | The elements of the set recorded for the binder in 'e_expansions', or the
-- empty list if the binder has no entry there.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure $ maybe [] dVarSetElems (lookupVarEnv expansions f)
-- | Build a set transformer from the /current/ environment. Applied to a set
-- of identifiers, it handles each element as follows:
--
-- * An identifier with an entry in 'e_expansions' is replaced by the set
--   recorded for it.
--
-- * Any other identifier is replaced by the result of looking it up in
--   'e_subst' with 'noWarnLookupIdSubst'.
--
-- The returned function captures the environment as it was when
-- 'liftedIdsExpander' ran; later changes to the environment do not affect it.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst <- RWS.asks e_subst
  let go set fv = case lookupVarEnv expansions fv of
        Nothing -> extendDVarSet set (noWarnLookupIdSubst fv subst) -- no expansion recorded
        Just fvs' -> unionDVarSet set fvs'
  let expander fvs = foldl' go emptyDVarSet (dVarSetElems fvs)
  pure expander