{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module GHC.Stg.Lift.Monad (
decomposeStgBinding, mkStgBinding,
Env (..),
FloatLang (..), collectFloats,
LiftM, runLiftM,
getConfig,
startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
substOcc, isLifted, formerFreeVars, liftedIdsExpander
) where
import GHC.Prelude
import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Lift.Config
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity
import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )
-- | Split a binding into its recursiveness flag and its list of
-- binder/right-hand-side pairs.
--
-- A 'StgNonRec' binding yields a singleton list tagged 'NonRecursive';
-- a 'StgRec' binding yields its pairs unchanged, tagged 'Recursive'.
-- 'mkStgBinding' is the inverse on such results.
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding (StgRec pairs) = (Recursive, pairs)
decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])
-- | Build a binding from a recursiveness flag and binder/right-hand-side
-- pairs.
--
-- * 'Recursive': all pairs become one 'StgRec' group.
-- * 'NonRecursive': the first pair becomes a 'StgNonRec' binding; any
--   further pairs are ignored. An empty list is a caller bug and panics
--   with a message naming this function (rather than an anonymous failure
--   from a partial list function).
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive pairs = StgRec pairs
mkStgBinding NonRecursive pairs = case pairs of
  (bndr, rhs) : _ -> StgNonRec bndr rhs
  [] -> pprPanic "mkStgBinding" (text "non-recursive binding without any binder")
-- | Read-only environment carried by 'LiftM' (the reader component of its
-- 'RWST' stack). It is only ever modified locally via 'RWS.local'.
data Env
  = Env
  { e_config :: StgLiftConfig
  -- ^ Lambda-lifting configuration, returned by 'getConfig'.
  , e_subst :: !Subst
  -- ^ Substitution applied to binder occurrences ('substOcc'); extended by
  -- 'withSubstBndr' and 'withLiftedBndr'.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ For every binder registered through 'withLiftedBndr', the set of
  -- variables it was abstracted over. Queried by 'isLifted',
  -- 'formerFreeVars' and 'liftedIdsExpander'.
  }
-- | The initial environment: the given configuration, an empty
-- substitution and no recorded expansions.
emptyEnv :: StgLiftConfig -> Env
emptyEnv cfg = Env
  { e_config = cfg
  , e_subst = emptySubst
  , e_expansions = emptyVarEnv
  }
-- | The items written to the output log of 'LiftM'. A list of these is
-- turned into top-level bindings by 'collectFloats'.
data FloatLang
-- Opens a binding group; groups may nest.
= StartBindingGroup
-- Closes the innermost open binding group.
| EndBindingGroup
-- A top-level binding emitted as-is; 'collectFloats' only accepts it
-- outside of any group.
| PlainTopBinding OutStgTopBinding
-- A binding to be placed at top level, either on its own (outside a group)
-- or merged into the enclosing group.
| LiftedBinding OutStgBinding
-- | Compact debug rendering: group delimiters print as parentheses, string
-- literals as @<str>@, and bindings as @r@ (recursive) or @n@
-- (non-recursive) followed by the list of bound binders.
instance Outputable FloatLang where
  ppr StartBindingGroup = char '('
  ppr EndBindingGroup = char ')'
  ppr (PlainTopBinding StgTopStringLit{}) = text "<str>"
  -- A lifted top-level binding prints exactly like a 'LiftedBinding'.
  ppr (PlainTopBinding (StgTopLifted b)) = ppr (LiftedBinding b)
  ppr (LiftedBinding bind) = tag <+> ppr (map fst pairs)
    where
      (rec_flag, pairs) = decomposeStgBinding bind
      tag
        | isRec rec_flag = char 'r'
        | otherwise = char 'n'
-- | Flatten a log of 'FloatLang' items into top-level bindings.
--
-- The log is scanned left to right while tracking the nesting depth of
-- 'StartBindingGroup' / 'EndBindingGroup' markers:
--
-- * Outside any group (depth 0), a 'PlainTopBinding' is emitted unchanged
--   and a 'LiftedBinding' is emitted as its own top-level binding.
-- * Inside a group, 'LiftedBinding's are accumulated; when the outermost
--   group closes they are merged into a single 'StgRec'.
--
-- Every emitted or merged lifted binding has its right-hand sides passed
-- through 'removeRhsCCCS'.
--
-- Panics if a group is left open at the end of input, if a group is closed
-- that was never opened, or if a 'PlainTopBinding' occurs inside a group.
-- Merging asserts that at least one accumulated binding is a 'StgRec'.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    -- go depth pending input:
    --   depth   = number of currently open binding groups
    --   pending = bindings collected for the open group, most recent first
    go 0 [] [] = []
    go _ _ [] = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n + 1) binds rest
      EndBindingGroup
        | n == 0 -> pprPanic "collectFloats" (text "no group to end")
        -- Closing the outermost group: emit everything collected so far as
        -- one recursive binding and start over with an empty accumulator.
        | n == 1 -> StgTopLifted (merge_binds binds) : go 0 [] rest
        | otherwise -> go (n - 1) binds rest
      PlainTopBinding top_bind
        | n == 0 -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | n == 0 -> StgTopLifted (rm_cccs bind) : go n binds rest
        | otherwise -> go n (bind : binds) rest

    -- Apply a function to every right-hand side of a binding.
    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding

    -- Apply 'removeRhsCCCS' to every right-hand side of a binding.
    rm_cccs = map_rhss removeRhsCCCS

    -- Combine all collected bindings into one recursive group.
    merge_binds binds = assert (any is_rec binds) $
      StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)

    is_rec StgRec{} = True
    is_rec _ = False
-- | Replace a right-hand side's cost-centre stack by 'dontCareCCS' when it
-- is the current cost-centre stack (as decided by 'isCurrentCCS'). All
-- other components are kept, and right-hand sides with any other
-- cost-centre stack are returned unchanged.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body)
  | isCurrentCCS ccs
  = StgRhsClosure ext dontCareCCS upd bndrs body
removeRhsCCCS (StgRhsCon ccs con mu ts args)
  | isCurrentCCS ccs
  = StgRhsCon dontCareCCS con mu ts args
removeRhsCCCS rhs = rhs
-- | The lambda-lifting monad: an 'RWST' stack over 'UniqSM' with
--
-- * reader: the scoped 'Env' (config, substitution, expansions),
-- * writer: an 'OrdList' of 'FloatLang' items, later flattened by
--   'collectFloats',
-- * state: unused (@()@).
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)
-- | Unique generation is delegated to the underlying 'UniqSM'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
  getUniqueM = LiftM (lift getUniqueM)
  getUniquesM = LiftM (lift getUniquesM)
-- | Run a 'LiftM' action starting from 'emptyEnv' with the given
-- configuration and unique supply, and flatten everything it wrote to the
-- output log into top-level bindings via 'collectFloats'.
runLiftM :: StgLiftConfig -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM cfg us (LiftM m) = collectFloats (fromOL floats)
  where
    -- Result and final state are both unit; only the written log matters.
    (_, _, floats) = initUs_ us (runRWST m (emptyEnv cfg) ())
-- | Retrieve the lambda-lifting configuration from the environment.
getConfig :: LiftM StgLiftConfig
getConfig = LiftM (RWS.asks e_config)
-- | Write a top-level string literal binding ('StgTopStringLit') for the
-- given binder to the output log as a 'PlainTopBinding'.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit bndr = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit bndr
-- | Write a 'StartBindingGroup' marker to the output log. Must be balanced
-- by a later 'endBindingGroup'; see 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))
-- | Write an 'EndBindingGroup' marker to the output log, closing the group
-- opened by the matching 'startBindingGroup'; see 'collectFloats'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))
-- | Write a binding to the output log as a 'LiftedBinding'. Whether it ends
-- up as its own top-level binding or is merged into an enclosing group is
-- decided by 'collectFloats', based on the surrounding group markers.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding
-- | Run 'substBndr' on the binder against the current substitution, then
-- run the continuation on the resulting binder in an environment that
-- carries the updated substitution.
--
-- The environment change is scoped with 'RWS.local': it is only visible
-- inside the continuation.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  subst <- RWS.asks e_subst
  let (bndr', subst') = substBndr bndr subst
  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))
-- | 'withSubstBndr' for a whole traversable container of binders.
--
-- Each 'withSubstBndr' call is wrapped in 'ContT' so that 'traverse' nests
-- them: every later binder is substituted inside the scope of the earlier
-- ones, and the final continuation receives all substituted binders in the
-- original container shape.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs = runContT . traverse (ContT . withSubstBndr)
-- | Create the lifted replacement for a binder and run the continuation
-- with it in scope.
--
-- Given the set of variables to abstract over (@abs_ids@) and the binder:
--
-- 1. A fresh system-local 'Id' is made, named @$l@ followed by the
--    original occurrence name, with multiplicity 'ManyTy' and the original
--    type abstracted over @abs_ids@ ('mkLamTypes').
-- 2. 'transferPolyIdInfo' is applied to carry information from the
--    original binder over to the new one.
-- 3. The continuation runs in a locally extended environment in which the
--    original binder is substituted by the new one (which is also added to
--    the in-scope set) and @abs_ids@ is recorded as its expansion.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let abs_vars = dVarSetElems abs_ids
      str = fsLit "$l" `appendFS` occNameFS (getOccName bndr)
      ty = mkLamTypes abs_vars (idType bndr)
      bndr' = transferPolyIdInfo bndr abs_vars (mkSysLocal str uniq ManyTy ty)
  LiftM $ RWS.local
    (\e -> e
      { e_subst = extendSubst bndr bndr' $ extendInScope bndr' $ e_subst e
      , e_expansions = extendVarEnv (e_expansions e) bndr abs_ids
      })
    (unwrapLiftM (inner bndr'))
-- | 'withLiftedBndr' for a whole traversable container of binders, all
-- abstracted over the same set of variables.
--
-- The calls are nested through 'ContT', so the continuation runs with
-- every lifted binder in scope and receives them in the original container
-- shape.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids = runContT . traverse (ContT . withLiftedBndr abs_ids)
-- | Look up a binder occurrence in the current substitution
-- ('lookupIdSubst' on 'e_subst') and return the result.
substOcc :: Id -> LiftM Id
substOcc v = LiftM (RWS.asks (lookupIdSubst v . e_subst))
-- | Whether the binder has an entry in 'e_expansions', i.e. whether it was
-- registered by an enclosing 'withLiftedBndr'.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM (RWS.asks (elemVarEnv bndr . e_expansions))
-- | The variables recorded in 'e_expansions' for the given binder, as a
-- list; the empty list if the binder has no entry (it was not lifted).
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure $ maybe [] dVarSetElems (lookupVarEnv expansions f)
-- | Produce an expander function for variable sets, closed over the
-- current environment.
--
-- The returned function rebuilds a set element by element:
--
-- * an element with an entry in 'e_expansions' (a lifted binder) is
--   replaced by the set recorded for it;
-- * any other element is replaced by its image under the current
--   substitution, looked up with 'noWarnLookupIdSubst'.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst <- RWS.asks e_subst
  let go set fv = case lookupVarEnv expansions fv of
        -- Not lifted: keep the (substituted) variable itself.
        Nothing -> extendDVarSet set (noWarnLookupIdSubst fv subst)
        -- Lifted: contribute the variables it abstracts over instead.
        Just fvs' -> unionDVarSet set fvs'
      expander fvs = foldl' go emptyDVarSet (dVarSetElems fvs)
  pure expander