module GHC.Core.Opt.Simplify.Env (
setMode, getMode, updMode, seDynFlags, seUnfoldingOpts, seLogger,
SimplEnv(..), pprSimplEnv,
mkSimplEnv, extendIdSubst,
extendTvSubst, extendCvSubst,
zapSubstEnv, setSubstEnv, bumpCaseDepth,
getInScope, setInScopeFromE, setInScopeFromF,
setInScopeSet, modifyInScope, addNewInScopeIds,
getSimplRules, enterRecGroupRHSs,
SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
simplBinder, simplBinders,
substTy, substTyVar, getTCvSubst,
substCo, substCoVar,
SimplFloats(..), emptyFloats, isEmptyFloats, mkRecFloats,
mkFloatBind, addLetFloats, addJoinFloats, addFloats,
extendFloats, wrapFloats,
doFloatFromRhs, getTopFloatBinds,
LetFloats, letFloatBinds, emptyLetFloats, unitLetFloat,
addLetFlts, mapLetFloats,
JoinFloat, JoinFloats, emptyJoinFloats,
wrapJoinFloats, wrapJoinFloatsX, unitJoinFloat, addJoinFlts
) where
import GHC.Prelude
import GHC.Core.Opt.Simplify.Monad
import GHC.Core.Opt.Monad ( SimplMode(..) )
import GHC.Core
import GHC.Core.Utils
import GHC.Core.Multiplicity ( scaleScaled )
import GHC.Core.Unfold
import GHC.Types.Var
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Data.OrdList
import GHC.Data.Graph.UnVar
import GHC.Types.Id as Id
import GHC.Core.Make ( mkWildValBinder )
import GHC.Driver.Session ( DynFlags )
import GHC.Builtin.Types
import GHC.Core.TyCo.Rep ( TyCoBinder(..) )
import qualified GHC.Core.Type as Type
import GHC.Core.Type hiding ( substTy, substTyVar, substTyVarBndr, extendTvSubst, extendCvSubst )
import qualified GHC.Core.Coercion as Coercion
import GHC.Core.Coercion hiding ( substCo, substCoVar, substCoVarBndr )
import GHC.Types.Basic
import GHC.Utils.Monad
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Utils.Panic.Plain
import GHC.Utils.Misc
import GHC.Utils.Logger
import GHC.Types.Unique.FM ( pprUniqFM )
import Data.List (mapAccumL)
-- | The environment carried around by the simplifier.
data SimplEnv
  = SimplEnv
      { -- The simplifier mode.  'seDynFlags', 'seLogger' and
        -- 'seUnfoldingOpts' all read their result out of this field.
        seMode      :: !SimplMode

        -- The three substitution environments.  They are extended by
        -- 'extendTvSubst', 'extendCvSubst' and 'extendIdSubst', reset
        -- together by 'zapSubstEnv' and replaced together by 'setSubstEnv'.
      , seTvSubst   :: TvSubstEnv
      , seCvSubst   :: CvSubstEnv
      , seIdSubst   :: SimplIdSubst

        -- Set of binders of the recursive groups currently being processed:
        -- 'enterRecGroupRHSs' extends it for the duration of its continuation
        -- and restores the previous value afterwards.
      , seRecIds    :: !UnVarSet

        -- The current set of in-scope variables.
      , seInScope   :: !InScopeSet

        -- Starts at 0 in 'mkSimplEnv'; incremented by 'bumpCaseDepth'.
      , seCaseDepth :: !Int
      }
-- | Bindings floated out of an expression, together with the in-scope set
-- that accounts for their binders.
data SimplFloats
  = SimplFloats
      { -- Ordinary (non-join-point) bindings; 'unitLetFloat' asserts that
        -- none of the binders is a join point.
        sfLetFloats  :: LetFloats

        -- Join-point bindings; 'unitJoinFloat' asserts that every binder
        -- is a join point.
      , sfJoinFloats :: JoinFloats

        -- In-scope set, extended with the binders of the floats as they
        -- are added (see 'extendFloats', 'addLetFloats', 'addJoinFloats').
      , sfInScope    :: InScopeSet
      }
-- | Debug printer: shows the let-floats, the join-floats and the in-scope
-- set, one per line, inside braces.
instance Outputable SimplFloats where
  ppr (SimplFloats { sfLetFloats = lf, sfJoinFloats = jf, sfInScope = is })
    = text "SimplFloats"
      <+> braces (vcat [ text "lets: " <+> ppr lf
                       , text "joins:" <+> ppr jf
                       , text "in_scope:" <+> ppr is ])
-- | A 'SimplFloats' with no let-floats and no join-floats, whose in-scope
-- set is copied from the given environment.
emptyFloats :: SimplEnv -> SimplFloats
emptyFloats env
  = SimplFloats { sfLetFloats  = emptyLetFloats
                , sfJoinFloats = emptyJoinFloats
                , sfInScope    = seInScope env }
-- | True when there are no let-floats.
--
-- Precondition (checked with 'assertPpr'): the join-floats are empty; only
-- the let-floats are actually inspected for the result.
isEmptyFloats :: SimplFloats -> Bool
isEmptyFloats (SimplFloats { sfLetFloats  = LetFloats fs _
                           , sfJoinFloats = js })
  = assertPpr (isNilOL js) (ppr js) $
    isNilOL fs
-- | Debug printer for a 'SimplEnv': shows the three substitutions and the
-- in-scope variables.  The mode, 'seRecIds' and 'seCaseDepth' are not shown.
pprSimplEnv :: SimplEnv -> SDoc
pprSimplEnv env
  = vcat [ text "TvSubst:" <+> ppr (seTvSubst env)
         , text "CvSubst:" <+> ppr (seCvSubst env)
         , text "IdSubst:" <+> id_subst_doc
         , text "InScope:" <+> in_scope_vars_doc
         ]
  where
    id_subst_doc      = pprUniqFM ppr (seIdSubst env)
    in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
                                  (vcat . map ppr_one)

    -- Ids are printed together with their unfolding; any other variable
    -- is printed on its own.
    ppr_one v
      | isId v    = ppr v <+> ppr (idUnfolding v)
      | otherwise = ppr v
-- | The Id substitution: maps Ids to a 'SimplSR'.
type SimplIdSubst = IdEnv SimplSR

-- | The result of looking an Id up in the Id substitution (see 'substId').
data SimplSR
  = DoneEx OutExpr (Maybe JoinArity)
      -- ^ An expression, optionally tagged with a join arity.
      -- NOTE(review): presumably @Just n@ marks a join point of arity @n@
      -- -- confirm against the use sites.
  | DoneId OutId
      -- ^ A single Id; 'substId' refines it from the in-scope set before
      -- returning it.
  | ContEx TvSubstEnv
           CvSubstEnv
           SimplIdSubst
           InExpr
      -- ^ An expression paired with the three substitution environments
      -- captured alongside it (built by 'mkContEx').
-- | Debug printer.  For 'DoneEx' the join arity, when present, is printed in
-- parentheses directly after the constructor name.  For 'ContEx' only the
-- expression is printed; the captured environments are omitted.
instance Outputable SimplSR where
  ppr (DoneId v)    = text "DoneId" <+> ppr v
  ppr (DoneEx e mj) = text "DoneEx" <> pp_mj <+> ppr e
    where
      pp_mj = case mj of
                Nothing -> empty
                Just n  -> parens (int n)
  ppr (ContEx _tv _cv _id e) = vcat [text "ContEx" <+> ppr e]
-- | Build the initial environment for the given mode: all substitutions
-- empty, no recursive binders recorded, case depth 0, and the in-scope set
-- initialised to 'init_in_scope'.
mkSimplEnv :: SimplMode -> SimplEnv
mkSimplEnv mode
  = SimplEnv { seMode      = mode
             , seInScope   = init_in_scope
             , seTvSubst   = emptyVarEnv
             , seCvSubst   = emptyVarEnv
             , seIdSubst   = emptyVarEnv
             , seRecIds    = emptyUnVarSet
             , seCaseDepth = 0 }
-- | The in-scope set a fresh 'SimplEnv' starts with: it contains exactly one
-- variable, the wildcard binder built by 'mkWildValBinder' at multiplicity
-- 'Many' and type 'unitTy'.
init_in_scope :: InScopeSet
init_in_scope = mkInScopeSet (unitVarSet (mkWildValBinder Many unitTy))
-- | The simplifier mode stored in the environment.
getMode :: SimplEnv -> SimplMode
getMode env = seMode env
-- | The 'DynFlags' held in the environment's mode ('sm_dflags').
seDynFlags :: SimplEnv -> DynFlags
seDynFlags env = sm_dflags (seMode env)
-- | The 'Logger' held in the environment's mode ('sm_logger').
seLogger :: SimplEnv -> Logger
seLogger env = sm_logger (seMode env)
-- | The 'UnfoldingOpts' held in the environment's mode ('sm_uf_opts').
seUnfoldingOpts :: SimplEnv -> UnfoldingOpts
seUnfoldingOpts env = sm_uf_opts (seMode env)
-- | Replace the simplifier mode, leaving the rest of the environment alone.
setMode :: SimplMode -> SimplEnv -> SimplEnv
setMode mode env = env { seMode = mode }
-- | Apply a function to the simplifier mode stored in the environment.
--
-- The current mode is forced (with '$!') before the update function is
-- applied to it.
updMode :: (SimplMode -> SimplMode) -> SimplEnv -> SimplEnv
updMode upd env
  = let mode = upd $! (seMode env)
    in env { seMode = mode }
-- | Increment 'seCaseDepth' by one.
bumpCaseDepth :: SimplEnv -> SimplEnv
bumpCaseDepth env = env { seCaseDepth = seCaseDepth env + 1 }
-- | Add (or overwrite) a mapping in the Id substitution.
--
-- Asserts that the key is an Id and is not a coercion variable; coercion
-- variables belong in 'extendCvSubst'.
extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
  = assertPpr (isId var && not (isCoVar var)) (ppr var) $
    env { seIdSubst = extendVarEnv subst var res }
-- | Add (or overwrite) a mapping in the type-variable substitution.
--
-- Asserts that the key is a type variable; on failure the variable and the
-- type it was to be mapped to are both printed.
extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
  = assertPpr (isTyVar var) (ppr var $$ ppr res) $
    env { seTvSubst = extendVarEnv tsubst var res }
-- | Add (or overwrite) a mapping in the coercion-variable substitution.
--
-- Asserts that the key is a coercion variable.
extendCvSubst :: SimplEnv -> CoVar -> Coercion -> SimplEnv
extendCvSubst env@(SimplEnv {seCvSubst = csubst}) var co
  = assert (isCoVar var) $
    env { seCvSubst = extendVarEnv csubst var co }
-- | The in-scope set stored in the environment.
getInScope :: SimplEnv -> InScopeSet
getInScope env = seInScope env
-- | Replace the in-scope set of the environment with the given one.
setInScopeSet :: SimplEnv -> InScopeSet -> SimplEnv
setInScopeSet env in_scope = env { seInScope = in_scope }
-- | @setInScopeFromE rhs_env here_env@ is @rhs_env@ with its in-scope set
-- replaced by the one from @here_env@; everything else comes from @rhs_env@.
setInScopeFromE :: SimplEnv -> SimplEnv -> SimplEnv
setInScopeFromE rhs_env here_env = rhs_env { seInScope = seInScope here_env }
-- | Replace the environment's in-scope set with the one carried by the
-- given floats.
setInScopeFromF :: SimplEnv -> SimplFloats -> SimplEnv
setInScopeFromF env floats = env { seInScope = sfInScope floats }
-- | Bring the given binders into scope.
--
-- Besides extending the in-scope set, the binders are deleted from the Id
-- substitution, so any existing mapping for them no longer applies.  Both
-- new components are forced (bang patterns) before the record is rebuilt.
addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
addNewInScopeIds env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }) vs
  = let !in_scope1 = in_scope `extendInScopeSetList` vs
        !id_subst1 = id_subst `delVarEnvList` vs
    in env { seInScope = in_scope1
           , seIdSubst = id_subst1 }
-- | Extend the in-scope set with a single variable.  Unlike
-- 'addNewInScopeIds', the Id substitution is left untouched.
modifyInScope :: SimplEnv -> CoreBndr -> SimplEnv
modifyInScope env@(SimplEnv {seInScope = in_scope}) v
  = env { seInScope = extendInScopeSet in_scope v }
-- | Run the continuation with the given binders added to 'seRecIds'.
--
-- The continuation returns a result and an environment; that environment is
-- returned with 'seRecIds' reset to the value it had in the environment
-- originally passed in, so the extension is visible only inside the
-- continuation.
enterRecGroupRHSs :: SimplEnv -> [OutBndr] -> (SimplEnv -> SimplM (r, SimplEnv))
                  -> SimplM (r, SimplEnv)
enterRecGroupRHSs env bndrs k = do
  (r, env'') <- k env { seRecIds = extendUnVarSetList bndrs (seRecIds env) }
  return (r, env'' { seRecIds = seRecIds env })
-- | Reset the type-variable, coercion-variable and Id substitutions to
-- empty.  The rest of the environment is unchanged.
zapSubstEnv :: SimplEnv -> SimplEnv
zapSubstEnv env
  = env { seTvSubst = emptyVarEnv
        , seCvSubst = emptyVarEnv
        , seIdSubst = emptyVarEnv }
-- | Replace all three substitution environments at once.
setSubstEnv :: SimplEnv -> TvSubstEnv -> CvSubstEnv -> SimplIdSubst -> SimplEnv
setSubstEnv env tvs cvs ids
  = env { seTvSubst = tvs
        , seCvSubst = cvs
        , seIdSubst = ids }
-- | Package an expression with the environment's three current
-- substitutions as a 'ContEx'.
mkContEx :: SimplEnv -> InExpr -> SimplSR
mkContEx (SimplEnv { seTvSubst = tvs, seCvSubst = cvs, seIdSubst = ids }) e
  = ContEx tvs cvs ids e
-- | A sequence of floated bindings together with a 'FloatFlag' summarising
-- them; 'addLetFlts' appends the sequences and combines the flags with
-- 'andFF'.
data LetFloats = LetFloats (OrdList OutBind) FloatFlag
-- | A floated join-point binding; 'unitJoinFloat' asserts that every binder
-- of such a binding is a join point.
type JoinFloat = OutBind
-- | A sequence of floated join-point bindings.
type JoinFloats = OrdList JoinFloat
-- | Summary of how freely a group of let-floats may be floated; consulted by
-- 'doFloatFromRhs' and assigned per binding by 'unitLetFloat'.
data FloatFlag
  = FltLifted
      -- ^ 'doFloatFromRhs' places no extra restriction on these.  Assigned
      -- to recursive groups, to non-strict binders, and to ticked-string
      -- right-hand sides.
  | FltOkSpec
      -- ^ Floatable only from a non-top-level, non-recursive binding.
      -- Assigned to strict binders whose right-hand side satisfies
      -- 'exprOkForSpeculation'.
  | FltCareful
      -- ^ As 'FltOkSpec', but additionally requires the enclosing binding to
      -- be strict.  Assigned to every remaining strict binder.
-- | Debug printer: the flag on the first line, then the list of bindings.
instance Outputable LetFloats where
  ppr (LetFloats binds ff) = ppr ff $$ ppr (fromOL binds)
-- | Prints the constructor name.
instance Outputable FloatFlag where
  ppr FltLifted  = text "FltLifted"
  ppr FltOkSpec  = text "FltOkSpec"
  ppr FltCareful = text "FltCareful"
-- | Combine two flags: the result is 'FltCareful' if either argument is,
-- otherwise 'FltOkSpec' if either argument is, otherwise 'FltLifted'.
andFF :: FloatFlag -> FloatFlag -> FloatFlag
andFF FltCareful _          = FltCareful
andFF FltOkSpec  FltCareful = FltCareful
andFF FltOkSpec  _          = FltOkSpec
andFF FltLifted  flt        = flt
-- | Decide whether the let-floats should be floated out of a right-hand
-- side.
--
-- Arguments: the top-level flag and recursiveness of the enclosing binding,
-- whether that binding is strict, the floats, and the right-hand side that
-- remains.  The answer is 'True' only when there is at least one let-float,
-- floating is wanted, and the floats' 'FloatFlag' permits it.  Join floats
-- are not consulted.
doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> SimplFloats -> OutExpr -> Bool
doFloatFromRhs lvl rec strict_bind (SimplFloats { sfLetFloats = LetFloats fs ff }) rhs
  = not (isNilOL fs) && want_to_float && can_float
  where
    -- Floating is wanted at top level, or when the remaining right-hand
    -- side is cheap or expandable.
    want_to_float = isTopLevel lvl || exprIsCheap rhs || exprIsExpandable rhs

    can_float = case ff of
      FltLifted  -> True
      FltOkSpec  -> isNotTopLevel lvl && isNonRec rec
      FltCareful -> isNotTopLevel lvl && isNonRec rec && strict_bind
-- | No bindings, flagged 'FltLifted' (the left identity of 'andFF').
emptyLetFloats :: LetFloats
emptyLetFloats = LetFloats nilOL FltLifted
-- | The empty sequence of join-point bindings.
emptyJoinFloats :: JoinFloats
emptyJoinFloats = nilOL
-- | A 'LetFloats' holding a single binding, with its 'FloatFlag' computed
-- from the binding.
--
-- Asserts that none of the binders is a join point.
unitLetFloat :: OutBind -> LetFloats
unitLetFloat bind
  = assert (all (not . isJoinId) (bindersOf bind)) $
    LetFloats (unitOL bind) (flag bind)
  where
    flag (Rec {}) = FltLifted
    flag (NonRec bndr rhs)
      | not (isStrictId bndr)    = FltLifted
      | exprIsTickedString rhs   = FltLifted
      | exprOkForSpeculation rhs = FltOkSpec
      | otherwise
      -- A strict binder whose right-hand side is not ok-for-speculation is
      -- asserted not to have an unlifted type.
      = assertPpr (not (isUnliftedType (idType bndr))) (ppr bndr)
        FltCareful
-- | A 'JoinFloats' holding a single binding.
--
-- Asserts that every binder of the binding is a join point.
unitJoinFloat :: OutBind -> JoinFloats
unitJoinFloat bind
  = assert (all isJoinId (bindersOf bind)) $
    unitOL bind
-- | Make a singleton 'SimplFloats' from one binding and bring its binders
-- into scope.
--
-- The binding goes into the join floats if 'isJoinBind' holds, otherwise
-- into the let floats.  Both the returned floats and the returned
-- environment carry the extended in-scope set, which is bound strictly.
mkFloatBind :: SimplEnv -> OutBind -> (SimplFloats, SimplEnv)
mkFloatBind env bind
  = (floats, env { seInScope = in_scope' })
  where
    floats
      | isJoinBind bind
      = SimplFloats { sfLetFloats  = emptyLetFloats
                    , sfJoinFloats = unitJoinFloat bind
                    , sfInScope    = in_scope' }
      | otherwise
      = SimplFloats { sfLetFloats  = unitLetFloat bind
                    , sfJoinFloats = emptyJoinFloats
                    , sfInScope    = in_scope' }

    !in_scope' = seInScope env `extendInScopeSetBind` bind
-- | Append one binding to the floats and add its binders to the floats'
-- in-scope set.
--
-- The binding is appended to the join floats if 'isJoinBind' holds,
-- otherwise to the let floats; the other component is returned unchanged.
extendFloats :: SimplFloats -> OutBind -> SimplFloats
extendFloats (SimplFloats { sfLetFloats  = floats
                          , sfJoinFloats = jfloats
                          , sfInScope    = in_scope })
             bind
  | isJoinBind bind
  = SimplFloats { sfInScope    = in_scope'
                , sfLetFloats  = floats
                , sfJoinFloats = jfloats' }
  | otherwise
  = SimplFloats { sfInScope    = in_scope'
                , sfLetFloats  = floats'
                , sfJoinFloats = jfloats }
  where
    in_scope' = in_scope `extendInScopeSetBind` bind
    floats'   = floats  `addLetFlts`  unitLetFloat bind
    jfloats'  = jfloats `addJoinFlts` unitJoinFloat bind
-- | Append a batch of let-floats, and add their binders to the in-scope
-- set.  The join floats are left unchanged.
addLetFloats :: SimplFloats -> LetFloats -> SimplFloats
addLetFloats floats let_floats
  = floats { sfLetFloats = sfLetFloats floats `addLetFlts` let_floats
           , sfInScope   = sfInScope floats `extendInScopeFromLF` let_floats }
-- | Extend an in-scope set with the binders of every binding in the
-- let-floats.  The float flag is ignored.
extendInScopeFromLF :: InScopeSet -> LetFloats -> InScopeSet
extendInScopeFromLF in_scope (LetFloats binds _)
  = foldlOL extendInScopeSetBind in_scope binds
-- | Append a batch of join floats, and add their binders to the in-scope
-- set.  The let floats are left unchanged.
addJoinFloats :: SimplFloats -> JoinFloats -> SimplFloats
addJoinFloats floats join_floats
  = floats { sfJoinFloats = sfJoinFloats floats `addJoinFlts` join_floats
           , sfInScope    = foldlOL extendInScopeSetBind
                                    (sfInScope floats) join_floats }
-- | Extend an in-scope set with all the binders of a binding.
extendInScopeSetBind :: InScopeSet -> CoreBind -> InScopeSet
extendInScopeSetBind in_scope bind
  = extendInScopeSetList in_scope (bindersOf bind)
-- | Concatenate two sets of floats, first argument's bindings first.
--
-- The in-scope set of the result is taken from the second argument alone;
-- the first argument's in-scope set is not consulted.
-- NOTE(review): this presumably relies on the second in-scope set already
-- including the first's binders -- confirm at the call sites.
addFloats :: SimplFloats -> SimplFloats -> SimplFloats
addFloats (SimplFloats { sfLetFloats = lf1, sfJoinFloats = jf1 })
          (SimplFloats { sfLetFloats = lf2, sfJoinFloats = jf2, sfInScope = in_scope })
  = SimplFloats { sfLetFloats  = lf1 `addLetFlts` lf2
                , sfJoinFloats = jf1 `addJoinFlts` jf2
                , sfInScope    = in_scope }
-- | Append two 'LetFloats': bindings are concatenated in order and the
-- flags are combined with 'andFF'.
addLetFlts :: LetFloats -> LetFloats -> LetFloats
addLetFlts (LetFloats bs1 l1) (LetFloats bs2 l2)
  = LetFloats (bs1 `appOL` bs2) (l1 `andFF` l2)
-- | The bindings of a 'LetFloats' as a list, in order; the flag is dropped.
letFloatBinds :: LetFloats -> [CoreBind]
letFloatBinds (LetFloats bs _) = fromOL bs
-- | Append two sequences of join floats.
addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
addJoinFlts = appOL
-- | Collapse the floats into a single recursive group.
--
-- All let-bindings are flattened into one 'Rec' (or none, if there were no
-- let-bindings), and likewise all join bindings.  Asserts that the let
-- floats and the join floats are not both non-empty.  The in-scope set is
-- passed through unchanged, and both new components are bound strictly.
mkRecFloats :: SimplFloats -> SimplFloats
mkRecFloats floats@(SimplFloats { sfLetFloats  = LetFloats bs _ff
                                , sfJoinFloats = jbs
                                , sfInScope    = in_scope })
  = assertPpr (isNilOL bs || isNilOL jbs) (ppr floats) $
    SimplFloats { sfLetFloats  = floats'
                , sfJoinFloats = jfloats'
                , sfInScope    = in_scope }
  where
    !floats'
      | isNilOL bs = emptyLetFloats
      | otherwise  = unitLetFloat (Rec (flattenBinds (fromOL bs)))

    !jfloats'
      | isNilOL jbs = emptyJoinFloats
      | otherwise   = unitJoinFloat (Rec (flattenBinds (fromOL jbs)))
-- | Wrap an expression in all the floats as nested 'Let's.
--
-- The join floats are wrapped around the body first (innermost, via
-- 'wrapJoinFloats'), and the let floats are then wrapped around that, with
-- the first let float outermost.  The in-scope set is ignored.
wrapFloats :: SimplFloats -> OutExpr -> OutExpr
wrapFloats (SimplFloats { sfLetFloats  = LetFloats bs _
                        , sfJoinFloats = jbs }) body
  = foldrOL Let (wrapJoinFloats jbs body) bs
-- | Wrap only the join floats around the expression, returning the floats
-- with their join component emptied alongside the wrapped expression.  The
-- let floats and the in-scope set are returned unchanged.
wrapJoinFloatsX :: SimplFloats -> OutExpr -> (SimplFloats, OutExpr)
wrapJoinFloatsX floats body
  = ( floats { sfJoinFloats = emptyJoinFloats }
    , wrapJoinFloats (sfJoinFloats floats) body )
-- | Wrap an expression in the join floats as nested 'Let's, with the first
-- float outermost.
wrapJoinFloats :: JoinFloats -> OutExpr -> OutExpr
wrapJoinFloats join_floats body
  = foldrOL Let body join_floats
-- | The let-bindings of the floats as a list.
--
-- Asserts that there are no join floats.
getTopFloatBinds :: SimplFloats -> [CoreBind]
getTopFloatBinds (SimplFloats { sfLetFloats  = lbs
                              , sfJoinFloats = jbs })
  = assert (isNilOL jbs) $
    letFloatBinds lbs
{-# INLINE mapLetFloats #-}
-- | Apply a function to every (binder, right-hand side) pair in the
-- let-floats, preserving the binding structure and the float flag.
--
-- The traversal uses the primed/strict mapping functions ('mapOL'',
-- 'strictMap'), a non-recursive pair is rebuilt via a @case@ on the
-- function's result, and the new binding sequence is bound with a bang.
mapLetFloats :: LetFloats -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> LetFloats
mapLetFloats (LetFloats fs ff) fun
  = LetFloats fs1 ff
  where
    app (NonRec b e) = case fun (b,e) of (b',e') -> NonRec b' e'
    app (Rec bs)     = Rec (strictMap fun bs)

    !fs1 = mapOL' app fs
-- | Look an Id up in the Id substitution.
--
-- * Not in the substitution: the Id itself, refined from the in-scope set,
--   wrapped in 'DoneId'.
-- * Mapped to @DoneId v@: @v@ refined from the in-scope set, wrapped in
--   'DoneId'.
-- * Mapped to anything else: that entry, unchanged.
substId :: SimplEnv -> InId -> SimplSR
substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = case lookupVarEnv ids v of
      Nothing         -> DoneId (refineFromInScope in_scope v)
      Just (DoneId v) -> DoneId (refineFromInScope in_scope v)
      Just res        -> res
-- | Replace a local Id by the copy of it stored in the in-scope set.
--
-- Variables for which 'isLocalId' is false are returned unchanged.  A local
-- Id that is missing from the in-scope set is a panic, reporting the
-- in-scope set and the offending variable.
refineFromInScope :: InScopeSet -> Var -> Var
refineFromInScope in_scope v
  | isLocalId v = case lookupInScope in_scope v of
                    Just v' -> v'
                    Nothing -> pprPanic "refineFromInScope" (ppr in_scope $$ ppr v)
  | otherwise   = v
-- | Find the 'OutId' corresponding to a binder.
--
-- If the identifier substitution maps the binder to @DoneId v'@, the
-- result is @v'@.  If the binder is not in the substitution at all it is
-- refined against the in-scope set.  Any other substitution result is
-- unexpected and triggers a panic that prints the binder.
lookupRecBndr :: SimplEnv -> InId -> OutId
lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = id_env }) bndr
  = case lookupVarEnv id_env bndr of
      Nothing             -> refineFromInScope in_scope bndr
      Just (DoneId bndr') -> bndr'
      Just _              -> pprPanic "lookupRecBndr" (ppr bndr)
-- | Run 'simplBinder' over a list of binders from left to right, threading
-- the environment through, and return the final environment together
-- with the resulting binders.  The incoming environment is forced.
simplBinders :: SimplEnv -> [InBndr] -> SimplM (SimplEnv, [OutBndr])
simplBinders !env0 in_bndrs
  = mapAccumLM simplBinder env0 in_bndrs
-- | Substitute into a single binder and extend the environment with it.
--
-- Type variables go through 'substTyVarBndr'; everything else goes
-- through 'substIdBndr'.  In both cases the new binder is forced
-- ('seqTyVar' / 'seqId') before the result is returned.
simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplBinder !env bndr
  | isTyVar bndr
  = do let (env_tv, out_tv) = substTyVarBndr env bndr
       seqTyVar out_tv `seq` return (env_tv, out_tv)
  | otherwise
  = do let (env_id, out_id) = substIdBndr env bndr
       seqId out_id `seq` return (env_id, out_id)
-- | Substitute into one binder with 'substIdBndr' and return the
-- extended environment together with the new binder.
--
-- The new binder is forced with 'seqId' before returning; the bang on
-- the environment component means it is evaluated as soon as the
-- result pair is.
simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplNonRecBndr !env bndr
  = do let (!env_out, bndr_out) = substIdBndr env bndr
       seqId bndr_out `seq` return (env_out, bndr_out)
-- | Substitute into a group of binders, none of which may be a join point
-- (this is asserted), and return the extended environment.
--
-- The binders are processed left to right with 'substIdBndr'; the
-- resulting Ids are forced with 'seqIds' before the environment is
-- returned.
simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
simplRecBndrs env@(SimplEnv {}) bndrs
  = assert (not (any isJoinId bndrs)) $
    do let (!env_out, bndrs_out) = mapAccumL substIdBndr env bndrs
       seqIds bndrs_out `seq` return env_out
-- | Substitute into an Id binder: coercion variables are handled by
-- 'substCoVarBndr', all other Ids by 'substNonCoVarIdBndr'.
substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
substIdBndr env bndr =
  if isCoVar bndr
    then substCoVarBndr env bndr
    else substNonCoVarIdBndr env bndr
-- | Substitute into a binder that is not a coercion variable.
--
-- This is 'subst_id_bndr' with a type-adjustment function that leaves
-- the Id unchanged.
substNonCoVarIdBndr
  :: SimplEnv
  -> InBndr
  -> (SimplEnv, OutBndr)
substNonCoVarIdBndr env bndr
  = subst_id_bndr env bndr (\unchanged -> unchanged)
{-# INLINE subst_id_bndr #-}
-- | Worker for substituting into a non-coercion-variable Id binder.
--
-- Asserts that the binder is not a coercion variable, then builds the new
-- binder in four strictly evaluated steps:
--
--   1. 'uniqAway' against the current in-scope set,
--   2. 'substIdType' to substitute in its type,
--   3. 'zapFragileIdInfo',
--   4. the caller-supplied adjustment function.
--
-- The returned environment has the final binder added to the in-scope
-- set.  The identifier substitution maps the old binder to
-- @DoneId new@ if the binder changed, and otherwise has any existing
-- entry for the old binder removed.
subst_id_bndr :: SimplEnv
              -> InBndr
              -> (OutId -> OutId)
              -> (SimplEnv, OutBndr)
subst_id_bndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
              old_bndr adjust
  = assertPpr (not (isCoVar old_bndr)) (ppr old_bndr)
    ( env { seInScope = in_scope_out
          , seIdSubst = id_subst_out }
    , new_bndr )
  where
    -- The bangs make each stage strict; the order of the bindings is
    -- deliberate and must be kept.
    !fresh_bndr   = uniqAway in_scope old_bndr
    !typed_bndr   = substIdType env fresh_bndr
    !zapped_bndr  = zapFragileIdInfo typed_bndr
    !new_bndr     = adjust zapped_bndr

    -- Both the substitution and the in-scope set are built from the
    -- binder /after/ the adjustment function has been applied.
    !id_subst_out
      | new_bndr /= old_bndr = extendVarEnv id_subst old_bndr (DoneId new_bndr)
      | otherwise            = delVarEnv id_subst old_bndr

    !in_scope_out = in_scope `extendInScopeSet` new_bndr
-- | Force a type variable to weak head normal form.
seqTyVar :: TyVar -> ()
seqTyVar tv = seq tv ()
-- | Force an Id: first its type (via 'seqType'), then its 'IdInfo'
-- (to weak head normal form).
seqId :: Id -> ()
seqId var
  = seqType (idType var) `seq` (idInfo var `seq` ())
-- | Force every Id in the list with 'seqId', from left to right.
seqIds :: [Id] -> ()
seqIds vars = foldr (\var rest -> seqId var `seq` rest) () vars
-- | Substitute into a single join-point binder using 'simplJoinBndr',
-- passing along the given multiplicity and result type, and return the
-- extended environment with the new binder.  The new binder is forced
-- with 'seqId' before returning.
simplNonRecJoinBndr :: SimplEnv -> InBndr
                    -> Mult -> OutType
                    -> SimplM (SimplEnv, OutBndr)
simplNonRecJoinBndr env bndr mult res_ty
  = do let (env_out, bndr_out) = simplJoinBndr mult res_ty env bndr
       seqId bndr_out `seq` return (env_out, bndr_out)
-- | Substitute into a group of join-point binders (asserted to all be
-- join Ids) and return the extended environment.
--
-- Each binder goes through 'simplJoinBndr' with the given multiplicity
-- and result type, left to right; the new Ids are forced with 'seqIds'
-- before the environment is returned.
simplRecJoinBndrs :: SimplEnv -> [InBndr]
                  -> Mult -> OutType
                  -> SimplM SimplEnv
simplRecJoinBndrs env@(SimplEnv {}) bndrs mult res_ty
  = assert (all isJoinId bndrs) $
    do let step                = simplJoinBndr mult res_ty
           (env_out, bndrs_out) = mapAccumL step env bndrs
       seqIds bndrs_out `seq` return env_out
-- | Substitute into a join-point binder: 'subst_id_bndr' with
-- 'adjustJoinPointType' (applied to the given multiplicity and result
-- type) as the type-adjustment function.
simplJoinBndr :: Mult -> OutType
              -> SimplEnv -> InBndr
              -> (SimplEnv, OutBndr)
simplJoinBndr mult res_ty env bndr
  = subst_id_bndr env bndr adjust
  where
    adjust = adjustJoinPointType mult res_ty
-- | Give a join point a new type (the Id is asserted to be a join Id).
--
-- The original type is peeled apart one binder at a time, as many times
-- as the Id's join arity.  Each anonymous binder is rescaled with
-- 'scaleScaled' by the given multiplicity; named binders are kept as
-- they are.  Once the arity is used up, the remaining result type is
-- replaced by the supplied new result type.
--
-- Panics if the original type has fewer binders than the join arity.
adjustJoinPointType :: Mult
                    -> Type
                    -> Id
                    -> Id
adjustJoinPointType mult new_res_ty join_id
  = assert (isJoinId join_id) $
    setIdType join_id rebuilt_ty
  where
    arity   = idJoinArity join_id
    full_ty = idType join_id

    rebuilt_ty :: Type
    rebuilt_ty = rebuild arity full_ty

    -- Walk down the type, one binder per unit of join arity.
    rebuild :: JoinArity -> Type -> Type
    rebuild 0 _ = new_res_ty
    rebuild n ty
      | Just (bndr, body_ty) <- splitPiTy_maybe ty
      = mkPiTy (rescale bndr) $
        rebuild (n - 1) body_ty
      | otherwise
      = pprPanic "adjustJoinPointType" (ppr arity <+> ppr full_ty)

    -- The scaled argument is forced before being wrapped up again.
    rescale :: TyCoBinder -> TyCoBinder
    rescale (Anon flag arg)  = Anon flag $! (scaleScaled mult arg)
    rescale named@(Named _)  = named
-- | Build a 'TCvSubst' from the environment's in-scope set, type-variable
-- substitution and coercion-variable substitution.
getTCvSubst :: SimplEnv -> TCvSubst
getTCvSubst (SimplEnv { seInScope = in_scope
                      , seTvSubst = tvs
                      , seCvSubst = cvs })
  = mkTCvSubst in_scope (tvs, cvs)
-- | Apply the environment's type/coercion substitution to a type,
-- using 'Type.substTy'.
substTy :: SimplEnv -> Type -> Type
substTy env ty = Type.substTy subst ty
  where
    subst = getTCvSubst env
-- | Apply the environment's type/coercion substitution to a type
-- variable occurrence, using 'Type.substTyVar'.
substTyVar :: SimplEnv -> TyVar -> Type
substTyVar env tv = Type.substTyVar subst tv
  where
    subst = getTCvSubst env
-- | Substitute into a type-variable binder.
--
-- Delegates to 'Type.substTyVarBndr' on the substitution obtained from
-- 'getTCvSubst', then copies the resulting in-scope set and both
-- substitution environments back into the 'SimplEnv'.
substTyVarBndr :: SimplEnv -> TyVar -> (SimplEnv, TyVar)
substTyVarBndr env tv
  = case Type.substTyVarBndr (getTCvSubst env) tv of
      (TCvSubst in_scope_out tvs_out cvs_out, tv_out) ->
        ( env { seInScope = in_scope_out
              , seTvSubst = tvs_out
              , seCvSubst = cvs_out }
        , tv_out )
-- | Apply the environment's type/coercion substitution to a coercion
-- variable occurrence, using 'Coercion.substCoVar'.
substCoVar :: SimplEnv -> CoVar -> Coercion
substCoVar env cv = Coercion.substCoVar subst cv
  where
    subst = getTCvSubst env
-- | Substitute into a coercion-variable binder.
--
-- Delegates to 'Coercion.substCoVarBndr' on the substitution obtained
-- from 'getTCvSubst', then copies the resulting in-scope set and both
-- substitution environments back into the 'SimplEnv'.
substCoVarBndr :: SimplEnv -> CoVar -> (SimplEnv, CoVar)
substCoVarBndr env cv
  = case Coercion.substCoVarBndr (getTCvSubst env) cv of
      (TCvSubst in_scope_out tvs_out cvs_out, cv_out) ->
        ( env { seInScope = in_scope_out
              , seTvSubst = tvs_out
              , seCvSubst = cvs_out }
        , cv_out )
-- | Apply the environment's type/coercion substitution to a coercion,
-- using 'Coercion.substCo'.
substCo :: SimplEnv -> Coercion -> Coercion
substCo env co = Coercion.substCo subst co
  where
    subst = getTCvSubst env
-- | Apply the environment's type/coercion substitution to an Id's type
-- and multiplicity.
--
-- The Id is returned unchanged when there is nothing to do: either both
-- substitution environments are empty, or neither the Id's type nor its
-- multiplicity has any free variables.  Otherwise
-- 'Type.substTyUnchecked' is applied through 'Id.updateIdTypeAndMult'.
substIdType :: SimplEnv -> Id -> Id
substIdType (SimplEnv { seInScope = in_scope
                      , seTvSubst = tvs
                      , seCvSubst = cvs }) var
  | nothing_to_do = var
  | otherwise     = Id.updateIdTypeAndMult (Type.substTyUnchecked subst) var
  where
    -- The emptiness test comes first, so the free-variable check is
    -- only reached when there is a non-empty substitution.
    nothing_to_do = (isEmptyVarEnv tvs && isEmptyVarEnv cvs) || is_closed
    is_closed     = noFreeVarsOfType (idType var)
                    && noFreeVarsOfType (varMult var)
    subst         = TCvSubst in_scope tvs cvs