{-# LANGUAGE CPP #-}
module GHC.Core.Opt.Simplify.Env (
setMode, getMode, updMode, seDynFlags,
SimplEnv(..), pprSimplEnv,
mkSimplEnv, extendIdSubst,
extendTvSubst, extendCvSubst,
zapSubstEnv, setSubstEnv,
getInScope, setInScopeFromE, setInScopeFromF,
setInScopeSet, modifyInScope, addNewInScopeIds,
getSimplRules,
SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
simplBinder, simplBinders,
substTy, substTyVar, getTCvSubst,
substCo, substCoVar,
SimplFloats(..), emptyFloats, mkRecFloats,
mkFloatBind, addLetFloats, addJoinFloats, addFloats,
extendFloats, wrapFloats,
doFloatFromRhs, getTopFloatBinds,
LetFloats, letFloatBinds, emptyLetFloats, unitLetFloat,
addLetFlts, mapLetFloats,
JoinFloat, JoinFloats, emptyJoinFloats,
wrapJoinFloats, wrapJoinFloatsX, unitJoinFloat, addJoinFlts
) where
#include "HsVersions.h"
import GHC.Prelude
import GHC.Core.Opt.Simplify.Monad
import GHC.Core.Opt.Monad ( SimplMode(..) )
import GHC.Core
import GHC.Core.Utils
import GHC.Core.Multiplicity ( scaleScaled )
import GHC.Types.Var
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Data.OrdList
import GHC.Types.Id as Id
import GHC.Core.Make ( mkWildValBinder )
import GHC.Driver.Session ( DynFlags )
import GHC.Builtin.Types
import GHC.Core.TyCo.Rep ( TyCoBinder(..) )
import qualified GHC.Core.Type as Type
import GHC.Core.Type hiding ( substTy, substTyVar, substTyVarBndr, extendTvSubst, extendCvSubst )
import qualified GHC.Core.Coercion as Coercion
import GHC.Core.Coercion hiding ( substCo, substCoVar, substCoVarBndr )
import GHC.Types.Basic
import GHC.Utils.Monad
import GHC.Utils.Outputable
import GHC.Utils.Misc
import GHC.Types.Unique.FM ( pprUniqFM )
import Data.List (mapAccumL)
-- | The simplifier's environment: the current 'SimplMode', the type,
-- coercion and Id substitutions, and the in-scope set.
data SimplEnv
  = SimplEnv {
        seMode    :: SimplMode      -- Simplifier mode (its sm_dflags is exposed by seDynFlags)

        -- The current substitution
      , seTvSubst :: TvSubstEnv     -- TyVar |--> Type      (see extendTvSubst)
      , seCvSubst :: CvSubstEnv     -- CoVar |--> Coercion  (see extendCvSubst)
      , seIdSubst :: SimplIdSubst   -- Id    |--> SimplSR   (see extendIdSubst)

        -- Variables currently in scope; the binders produced by
        -- subst_id_bndr are added here
      , seInScope :: InScopeSet
    }
-- | Bindings floated out during simplification, together with an in-scope
-- set that has been extended with their binders (see mkFloatBind and
-- extendFloats).
data SimplFloats
  = SimplFloats
      { sfLetFloats  :: LetFloats   -- Ordinary (non-join) let bindings
      , sfJoinFloats :: JoinFloats  -- Join-point bindings, kept separately
      , sfInScope    :: InScopeSet  -- In-scope set covering the floated binders
      }
-- | Debug rendering: the let-floats, join-floats and in-scope set, in braces.
instance Outputable SimplFloats where
  ppr (SimplFloats { sfLetFloats = lf, sfJoinFloats = jf, sfInScope = is })
    = text "SimplFloats"
      <+> braces (vcat [ text "lets: "    <+> ppr lf
                       , text "joins:"    <+> ppr jf
                       , text "in_scope:" <+> ppr is ])
-- | No floats at all; the in-scope set is inherited from the environment.
emptyFloats :: SimplEnv -> SimplFloats
emptyFloats env
  = SimplFloats { sfLetFloats  = emptyLetFloats
                , sfJoinFloats = emptyJoinFloats
                , sfInScope    = seInScope env }
-- | Debug rendering of a 'SimplEnv': the three substitutions, then every
-- in-scope variable (Ids are shown together with their unfolding).
-- The 'SimplMode' is not printed.
pprSimplEnv :: SimplEnv -> SDoc
pprSimplEnv env
  = vcat [ text "TvSubst:" <+> ppr (seTvSubst env)
         , text "CvSubst:" <+> ppr (seCvSubst env)
         , text "IdSubst:" <+> id_subst_doc
         , text "InScope:" <+> in_scope_vars_doc
         ]
  where
    id_subst_doc      = pprUniqFM ppr (seIdSubst env)
    in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
                                  (vcat . map ppr_one)

    ppr_one v | isId v    = ppr v <+> ppr (idUnfolding v)
              | otherwise = ppr v
-- | The Id substitution: each InId maps to a 'SimplSR' describing what
-- replaces it.
type SimplIdSubst = IdEnv SimplSR

-- | The result of looking up an Id in the 'SimplIdSubst'.
data SimplSR
  -- An already-simplified expression; the optional JoinArity is printed
  -- in parentheses by the Outputable instance
  = DoneEx OutExpr (Maybe JoinArity)

  -- An already-simplified Id (substId refines it against the in-scope set)
  | DoneId OutId

  -- An unsimplified InExpr paired with the type, coercion and Id
  -- substitutions that were current when it was captured (see mkContEx)
  | ContEx TvSubstEnv
           CvSubstEnv
           SimplIdSubst
           InExpr
-- | Debug rendering of a substitution range.  For 'ContEx' only the
-- expression is shown; the captured substitutions are omitted.
instance Outputable SimplSR where
  ppr (DoneId v)    = text "DoneId" <+> ppr v
  ppr (DoneEx e mj) = text "DoneEx" <> pp_mj <+> ppr e
    where
      pp_mj = case mj of
                Nothing -> empty
                Just n  -> parens (int n)
  ppr (ContEx _tv _cv _id e) = vcat [text "ContEx" <+> ppr e]
-- | A fresh environment for the given mode.  All substitutions start empty;
-- the in-scope set starts as 'init_in_scope'.
mkSimplEnv :: SimplMode -> SimplEnv
mkSimplEnv mode
  = SimplEnv { seMode    = mode
             , seInScope = init_in_scope
             , seTvSubst = emptyVarEnv
             , seCvSubst = emptyVarEnv
             , seIdSubst = emptyVarEnv }
-- | The initial in-scope set: just the wildcard value binder of type
-- @()@ with multiplicity 'Many'.
-- NOTE(review): presumably seeded so that uniqAway (in subst_id_bndr)
-- renames any binder sharing the wildcard's unique — confirm.
init_in_scope :: InScopeSet
init_in_scope = mkInScopeSet (unitVarSet (mkWildValBinder Many unitTy))
-- | The simplifier mode stored in the environment.
getMode :: SimplEnv -> SimplMode
getMode env = seMode env

-- | The 'DynFlags' carried by the environment's mode.
seDynFlags :: SimplEnv -> DynFlags
seDynFlags env = sm_dflags (seMode env)

-- | Replace the simplifier mode.
setMode :: SimplMode -> SimplEnv -> SimplEnv
setMode mode env = env { seMode = mode }

-- | Modify the simplifier mode with the given function.
updMode :: (SimplMode -> SimplMode) -> SimplEnv -> SimplEnv
updMode upd env = env { seMode = upd (seMode env) }
-- | Add an Id |--> 'SimplSR' mapping to the Id substitution.
-- The variable must be an Id that is not a coercion variable
-- (coercion variables belong in seCvSubst).
extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
  = ASSERT2( isId var && not (isCoVar var), ppr var )
    env { seIdSubst = extendVarEnv subst var res }
-- | Add a TyVar |--> Type mapping to the type substitution.
extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
  = ASSERT2( isTyVar var, ppr var $$ ppr res )
    env {seTvSubst = extendVarEnv tsubst var res}
-- | Add a CoVar |--> Coercion mapping to the coercion substitution.
extendCvSubst :: SimplEnv -> CoVar -> Coercion -> SimplEnv
extendCvSubst env@(SimplEnv {seCvSubst = csubst}) var co
  = ASSERT( isCoVar var )
    env {seCvSubst = extendVarEnv csubst var co}
-- | The environment's in-scope set.
getInScope :: SimplEnv -> InScopeSet
getInScope env = seInScope env

-- | Replace the environment's in-scope set.
setInScopeSet :: SimplEnv -> InScopeSet -> SimplEnv
setInScopeSet env in_scope = env {seInScope = in_scope}

-- | Keep everything from @rhs_env@ except its in-scope set, which is taken
-- from @here_env@.
setInScopeFromE :: SimplEnv -> SimplEnv -> SimplEnv
setInScopeFromE rhs_env here_env = rhs_env { seInScope = seInScope here_env }

-- | Replace the environment's in-scope set with the one carried by the floats.
setInScopeFromF :: SimplEnv -> SimplFloats -> SimplEnv
setInScopeFromF env floats = env { seInScope = sfInScope floats }
-- | Bring the given binders into scope.  Any existing Id-substitution
-- entries for them are deleted, so later lookups see the binders
-- themselves (via the in-scope set) rather than a stale mapping.
addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
addNewInScopeIds env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }) vs
  = env { seInScope = in_scope `extendInScopeSetList` vs
        , seIdSubst = id_subst `delVarEnvList` vs }
-- | Add (or overwrite) a single binder in the in-scope set.
-- NOTE(review): the name suggests callers use this to replace an
-- already-in-scope variable with a better-informed version — confirm.
modifyInScope :: SimplEnv -> CoreBndr -> SimplEnv
modifyInScope env@(SimplEnv {seInScope = in_scope}) v
  = env {seInScope = extendInScopeSet in_scope v}
-- | Empty all three substitutions; the mode and in-scope set are kept.
zapSubstEnv :: SimplEnv -> SimplEnv
zapSubstEnv env
  = env { seTvSubst = emptyVarEnv
        , seCvSubst = emptyVarEnv
        , seIdSubst = emptyVarEnv }

-- | Install the given type, coercion and Id substitutions; the mode and
-- in-scope set are kept.
setSubstEnv :: SimplEnv -> TvSubstEnv -> CvSubstEnv -> SimplIdSubst -> SimplEnv
setSubstEnv env tvs cvs ids
  = env { seTvSubst = tvs
        , seCvSubst = cvs
        , seIdSubst = ids }
-- | Capture the current substitutions together with an unsimplified
-- expression, as a 'ContEx'.
mkContEx :: SimplEnv -> InExpr -> SimplSR
mkContEx (SimplEnv { seTvSubst = tvs, seCvSubst = cvs, seIdSubst = ids }) e
  = ContEx tvs cvs ids e
-- | Ordinary (non-join) floated bindings, plus a 'FloatFlag' summarising
-- how freely the whole group may be floated (combined with 'andFF').
data LetFloats = LetFloats (OrdList OutBind) FloatFlag

-- | Join-point floats are kept as a plain ordered list of bindings.
type JoinFloat  = OutBind
type JoinFloats = OrdList JoinFloat

-- | How freely a group of let-floats may be floated (see doFloatFromRhs).
-- Constructors run from least to most restrictive.
data FloatFlag
  = FltLifted    -- Recursive, non-strict, or ticked-string bindings; always floatable
  | FltOkSpec    -- Some RHS is exprOkForSpeculation; only non-top-level, non-rec
  | FltCareful   -- Otherwise; additionally requires doFloatFromRhs's Bool argument
-- | Debug rendering: the float flag above the list of bindings.
instance Outputable LetFloats where
  ppr (LetFloats binds ff) = ppr ff $$ ppr (fromOL binds)
-- | Render a 'FloatFlag' as its constructor name.
instance Outputable FloatFlag where
  ppr FltLifted  = text "FltLifted"
  ppr FltOkSpec  = text "FltOkSpec"
  ppr FltCareful = text "FltCareful"
-- | Combine two float flags, keeping the more restrictive one under the
-- ordering FltLifted < FltOkSpec < FltCareful.
andFF :: FloatFlag -> FloatFlag -> FloatFlag
andFF FltCareful _          = FltCareful
andFF FltOkSpec  FltCareful = FltCareful
andFF FltOkSpec  _          = FltOkSpec
andFF FltLifted  flt        = flt
-- | Should the let-floats produced by a binding's RHS be floated out of it?
--
-- All three conditions must hold:
--   * there is at least one let-float;
--   * floating is worthwhile: the binding is at top level, or the RHS is
--     cheap or expandable;
--   * the floats' 'FloatFlag' allows it.  FltLifted always allows it.
--     FltOkSpec needs a non-top-level, non-recursive binding.  FltCareful
--     additionally needs the 'Bool' argument (NOTE(review): presumably
--     "the binding is strict" — confirm at call sites).
-- Join floats are not consulted.
doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> SimplFloats -> OutExpr -> Bool
doFloatFromRhs lvl rec str (SimplFloats { sfLetFloats = LetFloats fs ff }) rhs
  = not (isNilOL fs) && want_to_float && can_float
  where
    want_to_float = isTopLevel lvl || exprIsCheap rhs || exprIsExpandable rhs

    can_float = case ff of
                  FltLifted  -> True
                  FltOkSpec  -> isNotTopLevel lvl && isNonRec rec
                  FltCareful -> isNotTopLevel lvl && isNonRec rec && str
-- | No let-floats.  The flag is the least restrictive one, FltLifted, so it
-- is the identity for 'andFF'.
emptyLetFloats :: LetFloats
emptyLetFloats = LetFloats nilOL FltLifted

-- | No join-floats.
emptyJoinFloats :: JoinFloats
emptyJoinFloats = nilOL
-- | A singleton let-float, classified with the appropriate 'FloatFlag'.
-- The binding must not bind join points.
unitLetFloat :: OutBind -> LetFloats
unitLetFloat bind = ASSERT(all (not . isJoinId) (bindersOf bind))
                    LetFloats (unitOL bind) (flag bind)
  where
    flag (Rec {})                = FltLifted
    flag (NonRec bndr rhs)
      | not (isStrictId bndr)    = FltLifted
      | exprIsTickedString rhs   = FltLifted   -- string literals float freely
      | exprOkForSpeculation rhs = FltOkSpec
      | otherwise                = ASSERT2( not (isUnliftedType (idType bndr)), ppr bndr )
                                   -- An unlifted binder reaching here would not be
                                   -- ok-for-speculation, which the assertion rules out
                                   FltCareful
-- | A singleton join-float.  Every binder must be a join point.
unitJoinFloat :: OutBind -> JoinFloats
unitJoinFloat bind = ASSERT(all isJoinId (bindersOf bind))
                     unitOL bind
-- | Turn a single binding into floats, and bring its binders into scope.
-- Join bindings become join-floats; all others become let-floats.  The
-- returned floats and environment share the same extended in-scope set.
mkFloatBind :: SimplEnv -> OutBind -> (SimplFloats, SimplEnv)
mkFloatBind env bind
  = (floats, env { seInScope = in_scope' })
  where
    floats
      | isJoinBind bind
      = SimplFloats { sfLetFloats  = emptyLetFloats
                    , sfJoinFloats = unitJoinFloat bind
                    , sfInScope    = in_scope' }
      | otherwise
      = SimplFloats { sfLetFloats  = unitLetFloat bind
                    , sfJoinFloats = emptyJoinFloats
                    , sfInScope    = in_scope' }

    in_scope' = seInScope env `extendInScopeSetBind` bind
-- | Append one binding to the floats (to the join-floats if it is a join
-- binding, otherwise to the let-floats) and extend the in-scope set with
-- its binders.
extendFloats :: SimplFloats -> OutBind -> SimplFloats
extendFloats (SimplFloats { sfLetFloats  = floats
                          , sfJoinFloats = jfloats
                          , sfInScope    = in_scope })
             bind
  | isJoinBind bind
  = SimplFloats { sfInScope    = in_scope'
                , sfLetFloats  = floats
                , sfJoinFloats = jfloats' }
  | otherwise
  = SimplFloats { sfInScope    = in_scope'
                , sfLetFloats  = floats'
                , sfJoinFloats = jfloats }
  where
    in_scope' = in_scope `extendInScopeSetBind` bind
    floats'   = floats   `addLetFlts`  unitLetFloat bind
    jfloats'  = jfloats  `addJoinFlts` unitJoinFloat bind
-- | Append a group of let-floats and bring each of their binders into
-- scope.  Join-floats are untouched.
addLetFloats :: SimplFloats -> LetFloats -> SimplFloats
addLetFloats floats let_floats@(LetFloats binds _)
  = floats { sfLetFloats = sfLetFloats floats `addLetFlts` let_floats
           , sfInScope   = foldlOL extendInScopeSetBind
                                   (sfInScope floats) binds }
-- | Append a group of join-floats and bring each of their binders into
-- scope.  Let-floats are untouched.
addJoinFloats :: SimplFloats -> JoinFloats -> SimplFloats
addJoinFloats floats join_floats
  = floats { sfJoinFloats = sfJoinFloats floats `addJoinFlts` join_floats
           , sfInScope    = foldlOL extendInScopeSetBind
                                    (sfInScope floats) join_floats }
-- | Add every binder of a (possibly recursive) binding to the in-scope set.
extendInScopeSetBind :: InScopeSet -> CoreBind -> InScopeSet
extendInScopeSetBind in_scope bind
  = extendInScopeSetList in_scope (bindersOf bind)
-- | Concatenate two sets of floats: first argument's floats first, for both
-- the let- and the join-floats.  The in-scope set comes from the second
-- argument only; the first argument's in-scope set is discarded.
addFloats :: SimplFloats -> SimplFloats -> SimplFloats
addFloats (SimplFloats { sfLetFloats = lf1, sfJoinFloats = jf1 })
          (SimplFloats { sfLetFloats = lf2, sfJoinFloats = jf2, sfInScope = in_scope })
  = SimplFloats { sfLetFloats  = lf1 `addLetFlts`  lf2
                , sfJoinFloats = jf1 `addJoinFlts` jf2
                , sfInScope    = in_scope }
-- | Append two groups of let-floats, combining their flags with 'andFF'.
addLetFlts :: LetFloats -> LetFloats -> LetFloats
addLetFlts (LetFloats bs1 l1) (LetFloats bs2 l2)
  = LetFloats (bs1 `appOL` bs2) (l1 `andFF` l2)
-- | The let-float bindings as a list, in order; the flag is dropped.
letFloatBinds :: LetFloats -> [CoreBind]
letFloatBinds (LetFloats bs _) = fromOL bs
-- | Append two groups of join-floats.
addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
addJoinFlts = appOL
-- | Collapse the let-floats into a single 'Rec' binding, and likewise the
-- join-floats; empty groups stay empty.  The assertions require that the
-- let-floats are FltLifted and that let- and join-floats are not both
-- non-empty.
mkRecFloats :: SimplFloats -> SimplFloats
mkRecFloats floats@(SimplFloats { sfLetFloats  = LetFloats bs ff
                                , sfJoinFloats = jbs
                                , sfInScope    = in_scope })
  = ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
    ASSERT2( isNilOL bs || isNilOL jbs, ppr floats )
    SimplFloats { sfLetFloats  = floats'
                , sfJoinFloats = jfloats'
                , sfInScope    = in_scope }
  where
    floats'  | isNilOL bs  = emptyLetFloats
             | otherwise   = unitLetFloat (Rec (flattenBinds (fromOL bs)))

    jfloats' | isNilOL jbs = emptyJoinFloats
             | otherwise   = unitJoinFloat (Rec (flattenBinds (fromOL jbs)))
-- | Wrap all floats around an expression as nested 'Let's.  The join-floats
-- are placed innermost (directly around the body) and the let-floats
-- outside them.
wrapFloats :: SimplFloats -> OutExpr -> OutExpr
wrapFloats (SimplFloats { sfLetFloats  = LetFloats bs _
                        , sfJoinFloats = jbs }) body
  = foldrOL Let (wrapJoinFloats jbs body) bs
-- | Wrap only the join-floats around the body, and return the floats with
-- their join-floats cleared.  Let-floats and the in-scope set are unchanged.
wrapJoinFloatsX :: SimplFloats -> OutExpr -> (SimplFloats, OutExpr)
wrapJoinFloatsX floats body
  = ( floats { sfJoinFloats = emptyJoinFloats }
    , wrapJoinFloats (sfJoinFloats floats) body )
-- | Wrap the join bindings around the body as nested 'Let's, in list order.
wrapJoinFloats :: JoinFloats -> OutExpr -> OutExpr
wrapJoinFloats join_floats body
  = foldrOL Let body join_floats
-- | The let-float bindings, for use at top level.  Asserts that there are
-- no join-floats.
getTopFloatBinds :: SimplFloats -> [CoreBind]
getTopFloatBinds (SimplFloats { sfLetFloats  = lbs
                              , sfJoinFloats = jbs })
  = ASSERT( isNilOL jbs )
    letFloatBinds lbs
-- | Apply a function to every (binder, RHS) pair of the let-floats,
-- preserving the NonRec/Rec structure and the float flag unchanged.
mapLetFloats :: LetFloats -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> LetFloats
mapLetFloats (LetFloats fs ff) fun
  = LetFloats (mapOL app fs) ff
  where
    -- The 'case' (rather than a lazy let-pattern) forces the result pair
    -- before rebuilding the NonRec.
    app (NonRec b e) = case fun (b, e) of (b', e') -> NonRec b' e'
    app (Rec bs)     = Rec (map fun bs)
-- | Look up an Id in the Id substitution.
--
-- If the Id is unmapped, or is mapped to a 'DoneId', the result is a
-- 'DoneId' whose Id has been refined against the in-scope set.  Any other
-- mapping ('DoneEx' or 'ContEx') is returned as is.
substId :: SimplEnv -> InId -> SimplSR
substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = case lookupVarEnv ids v of
      Nothing           -> DoneId (refineFromInScope in_scope v)
      Just (DoneId v')  -> DoneId (refineFromInScope in_scope v')
      Just res          -> res
-- | Replace a local variable with the version recorded in the in-scope set.
-- Non-local variables are returned unchanged.  A local variable missing
-- from the in-scope set triggers a WARN and is returned unchanged.
refineFromInScope :: InScopeSet -> Var -> Var
refineFromInScope in_scope v
  | isLocalId v = case lookupInScope in_scope v of
                    Just v' -> v'
                    Nothing -> WARN( True, ppr v ) v   -- unexpected: local var not in scope
  | otherwise   = v
-- | Find the output Id for a binder.
--
-- If the binder is mapped to a 'DoneId', that Id is returned (without
-- refinement).  If it is unmapped, it is refined against the in-scope set.
-- Any other mapping is a panic.
-- NOTE(review): presumably used for binders of a recursive group already
-- processed by simplRecBndrs — confirm against callers.
lookupRecBndr :: SimplEnv -> InId -> OutId
lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = case lookupVarEnv ids v of
      Just (DoneId v') -> v'
      Just _           -> pprPanic "lookupRecBndr" (ppr v)
      Nothing          -> refineFromInScope in_scope v
-- | Run 'simplBinder' over a list of binders, left to right, threading the
-- environment through.
simplBinders :: SimplEnv -> [InBndr] -> SimplM (SimplEnv, [OutBndr])
simplBinders env bndrs = mapAccumLM simplBinder env bndrs
-- | Substitute a single binder.  Type variables go through substTyVarBndr;
-- everything else goes through 'substIdBndr'.  The result binder is forced
-- (seqTyVar / seqId) before returning, so no thunk is left behind.
simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplBinder env bndr
  | isTyVar bndr = do { let (env', tv) = substTyVarBndr env bndr
                      ; seqTyVar tv `seq` return (env', tv) }
  | otherwise    = do { let (env', id) = substIdBndr env bndr
                      ; seqId id `seq` return (env', id) }
-- | Substitute the binder of a non-recursive binding using 'substIdBndr',
-- forcing the resulting Id before returning.
simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplNonRecBndr env id
  = do { let (env1, id1) = substIdBndr env id
       ; seqId id1 `seq` return (env1, id1) }
-- | Substitute all binders of a recursive (non-join) group at once,
-- returning only the extended environment.  The new binders are forced
-- before returning; callers can recover them via the substitution.
simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
simplRecBndrs env@(SimplEnv {}) ids
  = ASSERT(all (not . isJoinId) ids)
    do { let (env1, ids1) = mapAccumL substIdBndr env ids
       ; seqIds ids1 `seq` return env1 }
-- | Substitute an Id binder.  Coercion variables go through substCoVarBndr;
-- all other Ids go through 'substNonCoVarIdBndr'.
substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
substIdBndr env bndr
  | isCoVar bndr = substCoVarBndr env bndr
  | otherwise    = substNonCoVarIdBndr env bndr

-- | Substitute a non-coercion Id binder, without any extra adjustment to
-- its type (the adjustment passed to subst_id_bndr is the identity).
substNonCoVarIdBndr
   :: SimplEnv
   -> InBndr    -- Env and binder to transform
   -> (SimplEnv, OutBndr)
substNonCoVarIdBndr env old_id = subst_id_bndr env old_id id
-- | Core of Id-binder substitution.
--
-- The new binder is built in four steps: rename away from the in-scope
-- set (uniqAway), substitute its type, zap fragile IdInfo, then apply the
-- caller's @adjust_type@.  The new binder is added to the in-scope set.
-- If it differs from the old one, the Id substitution maps old |--> new;
-- otherwise any existing mapping for the old binder is deleted.
subst_id_bndr :: SimplEnv
              -> InBndr              -- Env and binder to transform
              -> (OutId -> OutId)    -- Final adjustment applied to the new Id
              -> (SimplEnv, OutBndr)
subst_id_bndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
              old_id adjust_type
  = ASSERT2( not (isCoVar old_id), ppr old_id )
    (env { seInScope = in_scope `extendInScopeSet` new_id,
           seIdSubst = new_subst }, new_id)
  where
    id1    = uniqAway in_scope old_id
    id2    = substIdType env id1
    id3    = zapFragileIdInfo id2
    new_id = adjust_type id3

    -- NOTE(review): Var equality is presumably by unique, so this tests
    -- whether uniqAway renamed the binder — confirm.
    new_subst | new_id /= old_id
              = extendVarEnv id_subst old_id (DoneId new_id)
              | otherwise
              = delVarEnv id_subst old_id
-- | Force a type variable to weak head normal form. Nothing inside it is
-- evaluated.
seqTyVar :: TyVar -> ()
seqTyVar tv = seq tv ()
-- | Force an 'Id'.
--
-- The type is forced first with 'seqType'. The 'IdInfo' is then forced to
-- weak head normal form.
seqId :: Id -> ()
seqId v = seqType (idType v) `seq` (idInfo v `seq` ())
-- | Apply 'seqId' to each element of the list, left to right.
seqIds :: [Id] -> ()
seqIds = foldr (\v rest -> seqId v `seq` rest) ()
-- | Simplify a non-recursive join-point binder.
--
-- @mult@ and @res_ty@ are handed to 'simplJoinBndr'. The resulting binder
-- is forced with 'seqId' before the new environment and binder are
-- returned.
simplNonRecJoinBndr :: SimplEnv -> InBndr
                    -> Mult -> OutType
                    -> SimplM (SimplEnv, OutBndr)
simplNonRecJoinBndr env id mult res_ty
  = seqId id1 `seq` return (env1, id1)
  where
    (env1, id1) = simplJoinBndr mult res_ty env id
-- | Simplify a group of recursive join-point binders.
--
-- The environment is threaded through 'simplJoinBndr' with 'mapAccumL',
-- and every resulting binder is forced with 'seqIds'. Only the final
-- environment is returned. The new binders live in it, because
-- 'subst_id_bndr' adds each one to the in-scope set.
--
-- The @SimplEnv {}@ pattern forces the incoming environment to weak head
-- normal form.
simplRecJoinBndrs :: SimplEnv -> [InBndr]
                  -> Mult -> OutType
                  -> SimplM SimplEnv
simplRecJoinBndrs env@(SimplEnv {}) ids mult res_ty
  = ASSERT(all isJoinId ids)
    seqIds ids1 `seq` return env1
  where
    (env1, ids1) = mapAccumL (simplJoinBndr mult res_ty) env ids
-- | Substitute a join-point binder.
--
-- This is 'subst_id_bndr' with 'adjustJoinPointType' (for @mult@ and
-- @res_ty@) as the final adjustment of the new binder.
simplJoinBndr :: Mult -> OutType
              -> SimplEnv -> InBndr
              -> (SimplEnv, OutBndr)
simplJoinBndr mult res_ty env bndr
  = subst_id_bndr env bndr retype
  where
    retype = adjustJoinPointType mult res_ty
-- | Rewrite the type of a join point.
--
-- The first @idJoinArity join_id@ binders of the Id's type are kept.
-- Among them:
--
--   * each anonymous ('Anon') binder is scaled by @mult@ with 'scaleScaled';
--   * each 'Named' binder is kept unchanged.
--
-- Whatever type remains after those binders is replaced by @new_res_ty@.
-- If the type has fewer binders than the join arity, this panics with
-- "adjustJoinPointType".
adjustJoinPointType :: Mult
                    -> Type  -- New result type
                    -> Id    -- Old join-point Id
                    -> Id    -- Adjusted join-point Id
adjustJoinPointType mult new_res_ty join_id
  = ASSERT( isJoinId join_id )
    setIdType join_id new_join_ty
  where
    orig_ar = idJoinArity join_id
    orig_ty = idType join_id

    new_join_ty = rebuild orig_ar orig_ty

    -- Peel one binder per remaining arity step, then plug in the new result.
    rebuild :: JoinArity -> Type -> Type
    rebuild 0 _ = new_res_ty
    rebuild n ty
      = case splitPiTy_maybe ty of
          Just (arg_bndr, res_ty) ->
            mkPiTy (scale_bndr arg_bndr) (rebuild (n - 1) res_ty)
          Nothing ->
            pprPanic "adjustJoinPointType" (ppr orig_ar <+> ppr orig_ty)

    scale_bndr :: TyCoBinder -> TyCoBinder
    scale_bndr (Anon af t)   = Anon af (scaleScaled mult t)
    scale_bndr b@(Named _)   = b
-- | Build a 'TCvSubst' from the environment.
--
-- It combines the environment's in-scope set with its type-variable and
-- coercion-variable substitutions.
getTCvSubst :: SimplEnv -> TCvSubst
getTCvSubst (SimplEnv { seInScope = scope
                      , seTvSubst = tvs
                      , seCvSubst = cvs })
  = mkTCvSubst scope subst_envs
  where
    subst_envs = (tvs, cvs)
-- | Apply the environment's type and coercion substitution to a 'Type'.
substTy :: SimplEnv -> Type -> Type
substTy env = Type.substTy (getTCvSubst env)
-- | Look up a type variable in the environment's substitution and return
-- the 'Type' it maps to.
substTyVar :: SimplEnv -> TyVar -> Type
substTyVar env var = Type.substTyVar subst var
  where
    subst = getTCvSubst env
-- | Substitute a type-variable binder.
--
-- 'Type.substTyVarBndr' produces a new in-scope set and new type and
-- coercion substitutions. All three replace those in the environment, and
-- the new binder is returned alongside.
substTyVarBndr :: SimplEnv -> TyVar -> (SimplEnv, TyVar)
substTyVarBndr env tv
  = case Type.substTyVarBndr (getTCvSubst env) tv of
      (TCvSubst in_scope1 tv_env1 cv_env1, tv1) ->
        let env1 = env { seInScope = in_scope1
                       , seTvSubst = tv_env1
                       , seCvSubst = cv_env1 }
        in (env1, tv1)
-- | Look up a coercion variable in the environment's substitution and
-- return the 'Coercion' it maps to.
substCoVar :: SimplEnv -> CoVar -> Coercion
substCoVar env cv = Coercion.substCoVar subst cv
  where
    subst = getTCvSubst env
-- | Substitute a coercion-variable binder.
--
-- 'Coercion.substCoVarBndr' produces a new in-scope set and new type and
-- coercion substitutions. All three replace those in the environment, and
-- the new binder is returned alongside.
substCoVarBndr :: SimplEnv -> CoVar -> (SimplEnv, CoVar)
substCoVarBndr env cv
  = case Coercion.substCoVarBndr (getTCvSubst env) cv of
      (TCvSubst in_scope1 tv_env1 cv_env1, cv1) ->
        let env1 = env { seInScope = in_scope1
                       , seTvSubst = tv_env1
                       , seCvSubst = cv_env1 }
        in (env1, cv1)
-- | Apply the environment's type and coercion substitution to a
-- 'Coercion'.
substCo :: SimplEnv -> Coercion -> Coercion
substCo env = Coercion.substCo (getTCvSubst env)
-- | Apply the environment's type and coercion substitution to the type and
-- multiplicity of an 'Id'.
--
-- The 'Id' is returned unchanged in either of two cases:
--
--   * both substitutions are empty;
--   * neither its type nor its multiplicity has free variables.
--
-- Otherwise 'Id.updateIdTypeAndMult' applies 'Type.substTyUnchecked'.
substIdType :: SimplEnv -> Id -> Id
substIdType (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env }) bndr
  | nothing_to_do = bndr
  | otherwise     = Id.updateIdTypeAndMult (Type.substTyUnchecked subst) bndr
  where
    nothing_to_do  = subst_is_empty || closed_bndr
    subst_is_empty = isEmptyVarEnv tv_env && isEmptyVarEnv cv_env
    closed_bndr    = noFreeVarsOfType (idType bndr)
                  && noFreeVarsOfType (varMult bndr)
    subst          = TCvSubst in_scope tv_env cv_env