{-# LANGUAGE CPP #-}
module SimplEnv (
setMode, getMode, updMode, seDynFlags,
SimplEnv(..), pprSimplEnv,
mkSimplEnv, extendIdSubst,
SimplEnv.extendTvSubst, SimplEnv.extendCvSubst,
zapSubstEnv, setSubstEnv,
getInScope, setInScopeFromE, setInScopeFromF,
setInScopeSet, modifyInScope, addNewInScopeIds,
getSimplRules,
SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
simplBinder, simplBinders,
substTy, substTyVar, getTCvSubst,
substCo, substCoVar,
SimplFloats(..), emptyFloats, mkRecFloats,
mkFloatBind, addLetFloats, addJoinFloats, addFloats,
extendFloats, wrapFloats,
doFloatFromRhs, getTopFloatBinds,
LetFloats, letFloatBinds, emptyLetFloats, unitLetFloat,
addLetFlts, mapLetFloats,
JoinFloat, JoinFloats, emptyJoinFloats,
wrapJoinFloats, wrapJoinFloatsX, unitJoinFloat, addJoinFlts
) where
#include "HsVersions.h"
import GhcPrelude
import SimplMonad
import CoreMonad ( SimplMode(..) )
import CoreSyn
import CoreUtils
import Var
import VarEnv
import VarSet
import OrdList
import Id
import MkCore ( mkWildValBinder )
import DynFlags ( DynFlags )
import TysWiredIn
import qualified Type
import Type hiding ( substTy, substTyVar, substTyVarBndr )
import qualified Coercion
import Coercion hiding ( substCo, substCoVar, substCoVarBndr )
import BasicTypes
import MonadUtils
import Outputable
import Util
import UniqFM ( pprUniqFM )
import Data.List
-- | The environment threaded through the simplifier: the current mode,
-- the substitutions applied to input binders and expressions, and the set
-- of variables currently in scope.
data SimplEnv
  = SimplEnv {
      seMode :: SimplMode
        -- ^ Simplifier mode; 'seDynFlags' reads its DynFlags from here.
    , seTvSubst :: TvSubstEnv
        -- ^ Substitution for type variables.
    , seCvSubst :: CvSubstEnv
        -- ^ Substitution for coercion variables.
    , seIdSubst :: SimplIdSubst
        -- ^ Substitution for term-level Ids; see 'SimplSR'.
    , seInScope :: InScopeSet
        -- ^ Variables currently in scope; consulted by 'refineFromInScope'
        -- and extended as binders are processed.
    }
-- | Bindings floated out during simplification, kept in two separate
-- groups, together with the in-scope set that covers their binders.
data SimplFloats
  = SimplFloats
      {
        sfLetFloats :: LetFloats
          -- ^ Floated ordinary (non-join) bindings.
      , sfJoinFloats :: JoinFloats
          -- ^ Floated join-point bindings; 'wrapFloats' places them
          -- inside the let-floats, nearest the body.
      , sfInScope :: InScopeSet
          -- ^ In-scope set extended with the binders of both groups
          -- (see 'mkFloatBind', 'extendFloats', 'addLetFloats').
      }
-- Debug rendering: the let-floats, join-floats and in-scope set, one per
-- line, inside braces after the "SimplFloats" header.
instance Outputable SimplFloats where
  ppr (SimplFloats { sfInScope = scope, sfJoinFloats = joins, sfLetFloats = lets })
    = text "SimplFloats" <+> braces (vcat rows)
    where
      rows = [ text "lets: " <+> ppr lets
             , text "joins:" <+> ppr joins
             , text "in_scope:" <+> ppr scope ]
-- | No floats at all, with the in-scope set copied from the environment.
emptyFloats :: SimplEnv -> SimplFloats
emptyFloats env
  = SimplFloats { sfInScope    = scope
                , sfLetFloats  = emptyLetFloats
                , sfJoinFloats = emptyJoinFloats }
  where
    scope = seInScope env
-- | Debug dump of a 'SimplEnv': one labelled row per substitution plus the
-- in-scope set, where each in-scope Id is shown with its unfolding.
pprSimplEnv :: SimplEnv -> SDoc
pprSimplEnv env = vcat (map ppr_row rows)
  where
    ppr_row (label, doc) = text label <+> doc

    rows = [ ("TvSubst:", ppr (seTvSubst env))
           , ("CvSubst:", ppr (seCvSubst env))
           , ("IdSubst:", pprUniqFM ppr (seIdSubst env))
           , ("InScope:", pprVarSet in_scope_vars (vcat . map ppr_in_scope)) ]

    in_scope_vars = getInScopeVars (seInScope env)

    -- Ids get their unfolding appended; other variables print plainly.
    ppr_in_scope v
      | isId v    = ppr v <+> ppr (idUnfolding v)
      | otherwise = ppr v
-- | The Id substitution of a 'SimplEnv': maps each input Id to a 'SimplSR'.
type SimplIdSubst = IdEnv SimplSR

-- | What an input Id is substituted by.
data SimplSR
  = DoneEx OutExpr (Maybe JoinArity)
      -- An already-processed output expression.  The optional join arity
      -- is printed by 'ppr' when present; presumably it is set when the
      -- substituted Id is a join point (TODO confirm at call sites).
  | DoneId OutId
      -- An output Id; 'substId' refines it against the current in-scope set.
  | ContEx TvSubstEnv
           CvSubstEnv
           SimplIdSubst
           InExpr
      -- An input expression together with the type, coercion and Id
      -- substitutions captured by 'mkContEx'; presumably to be processed
      -- later under those substitutions.
-- Debug rendering of a substitution entry.  For ContEx only the captured
-- expression is shown, not the captured substitutions.
instance Outputable SimplSR where
  ppr sr = case sr of
    DoneId v       -> text "DoneId" <+> ppr v
    DoneEx e mj    -> text "DoneEx" <> maybe empty (parens . int) mj <+> ppr e
    ContEx _ _ _ e -> vcat [text "ContEx" <+> ppr e ]
-- | A fresh environment for the given mode: every substitution is empty and
-- the in-scope set is 'init_in_scope'.
mkSimplEnv :: SimplMode -> SimplEnv
mkSimplEnv mode
  = SimplEnv { seIdSubst = emptyVarEnv
             , seTvSubst = emptyVarEnv
             , seCvSubst = emptyVarEnv
             , seInScope = init_in_scope
             , seMode    = mode }

-- | Initial in-scope set: a singleton holding the binder produced by
-- @mkWildValBinder unitTy@.
init_in_scope :: InScopeSet
init_in_scope = mkInScopeSet $ unitVarSet $ mkWildValBinder unitTy
-- | The simplifier mode stored in the environment.
getMode :: SimplEnv -> SimplMode
getMode = seMode

-- | The DynFlags carried by the environment's simplifier mode.
seDynFlags :: SimplEnv -> DynFlags
seDynFlags = sm_dflags . getMode

-- | Replace the simplifier mode wholesale.
setMode :: SimplMode -> SimplEnv -> SimplEnv
setMode mode = updMode (const mode)

-- | Apply a function to the simplifier mode.
updMode :: (SimplMode -> SimplMode) -> SimplEnv -> SimplEnv
updMode upd env = env { seMode = upd (seMode env) }
-- | Bind an Id (which must not be a coercion variable) in the Id substitution.
extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
extendIdSubst env@(SimplEnv { seIdSubst = id_env }) var res
  = ASSERT2( isId var && not (isCoVar var), ppr var )
    env { seIdSubst = id_env' }
  where
    id_env' = extendVarEnv id_env var res

-- | Bind a type variable in the type substitution.
extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
extendTvSubst env@(SimplEnv { seTvSubst = tv_env }) var res
  = ASSERT( isTyVar var )
    env { seTvSubst = tv_env' }
  where
    tv_env' = extendVarEnv tv_env var res

-- | Bind a coercion variable in the coercion substitution.
extendCvSubst :: SimplEnv -> CoVar -> Coercion -> SimplEnv
extendCvSubst env@(SimplEnv { seCvSubst = cv_env }) var co
  = ASSERT( isCoVar var )
    env { seCvSubst = cv_env' }
  where
    cv_env' = extendVarEnv cv_env var co
-- | The environment's in-scope set.
getInScope :: SimplEnv -> InScopeSet
getInScope = seInScope

-- | Replace the environment's in-scope set.
setInScopeSet :: SimplEnv -> InScopeSet -> SimplEnv
setInScopeSet env in_scope = env { seInScope = in_scope }

-- | Keep everything from the first environment except its in-scope set,
-- which is taken from the second.
setInScopeFromE :: SimplEnv -> SimplEnv -> SimplEnv
setInScopeFromE rhs_env here_env = setInScopeSet rhs_env (seInScope here_env)

-- | Adopt the in-scope set recorded in a 'SimplFloats'.
setInScopeFromF :: SimplEnv -> SimplFloats -> SimplEnv
setInScopeFromF env floats = setInScopeSet env (sfInScope floats)

-- | Add binders to the in-scope set and remove them from the Id substitution.
addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
addNewInScopeIds env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }) vs
  = env { seIdSubst = delVarEnvList id_subst vs
        , seInScope = extendInScopeSetList in_scope vs }

-- | Extend the in-scope set with one binder; the Id substitution is untouched.
modifyInScope :: SimplEnv -> CoreBndr -> SimplEnv
modifyInScope env@(SimplEnv { seInScope = in_scope }) v
  = setInScopeSet env (extendInScopeSet in_scope v)
-- | Empty all three substitutions (type, coercion, Id).
zapSubstEnv :: SimplEnv -> SimplEnv
zapSubstEnv env = setSubstEnv env emptyVarEnv emptyVarEnv emptyVarEnv

-- | Install the given type, coercion and Id substitutions.
setSubstEnv :: SimplEnv -> TvSubstEnv -> CvSubstEnv -> SimplIdSubst -> SimplEnv
setSubstEnv env tvs cvs ids
  = env { seIdSubst = ids, seCvSubst = cvs, seTvSubst = tvs }

-- | Capture the environment's current substitutions alongside an input
-- expression, as a 'ContEx'.
mkContEx :: SimplEnv -> InExpr -> SimplSR
mkContEx env e = case env of
  SimplEnv { seTvSubst = tvs, seCvSubst = cvs, seIdSubst = ids }
    -> ContEx tvs cvs ids e
-- | Floated non-join bindings, with a 'FloatFlag' that summarises how freely
-- the whole group may be floated.  When wrapped by 'wrapFloats' the first
-- binding ends up outermost.
data LetFloats = LetFloats (OrdList OutBind) FloatFlag

-- | One floated join-point binding.
type JoinFloat = OutBind

-- | Floated join-point bindings, in order (first binding outermost when
-- wrapped by 'wrapJoinFloats').
type JoinFloats = OrdList JoinFloat
-- | How freely a group of let-floats may be floated out of a right-hand
-- side, from most to least permissive.  'andFF' combines two flags into the
-- more restrictive one; 'doFloatFromRhs' interprets them.
data FloatFlag
  = FltLifted
      -- No extra restriction in 'doFloatFromRhs'.  'unitLetFloat' assigns it
      -- to recursive groups, non-strict binders and ticked-string right-hand
      -- sides.
  | FltOkSpec
      -- The right-hand side is ok-for-speculation; floating additionally
      -- needs a non-top-level, non-recursive binding.
  | FltCareful
      -- Anything else; on top of the FltOkSpec conditions it also needs the
      -- Bool flag passed to 'doFloatFromRhs' to be True.
-- Debug rendering: the float flag on one line, then the bindings.
instance Outputable LetFloats where
  ppr (LetFloats binds flag) = header $$ body
    where
      header = ppr flag
      body   = ppr (fromOL binds)

-- Each flag prints as its constructor name.
instance Outputable FloatFlag where
  ppr flag = text $ case flag of
    FltLifted  -> "FltLifted"
    FltOkSpec  -> "FltOkSpec"
    FltCareful -> "FltCareful"
-- | Combine two flags, keeping the more restrictive one
-- (FltLifted < FltOkSpec < FltCareful).  The second flag is only inspected
-- when the first is FltLifted or FltOkSpec.
andFF :: FloatFlag -> FloatFlag -> FloatFlag
andFF first second = case first of
  FltCareful -> FltCareful
  FltLifted  -> second
  FltOkSpec  -> case second of
                  FltCareful -> FltCareful
                  _          -> FltOkSpec
-- | Decide whether the let-floats produced while simplifying a right-hand
-- side should be floated out of it.  There must be at least one float.
-- The binding must be top-level or have a cheap/expandable RHS.  The float
-- flag must also permit it: FltOkSpec needs a non-top-level, non-recursive
-- binding, and FltCareful additionally needs @str@ to be True.
doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> SimplFloats -> OutExpr -> Bool
doFloatFromRhs lvl rec str (SimplFloats { sfLetFloats = LetFloats fs ff }) rhs
  | isNilOL fs = False
  | otherwise  = want_to_float && can_float
  where
    want_to_float = isTopLevel lvl || exprIsCheap rhs || exprIsExpandable rhs

    can_float = case ff of
      FltLifted  -> True
      FltOkSpec  -> local_nonrec
      FltCareful -> local_nonrec && str

    local_nonrec = isNotTopLevel lvl && isNonRec rec
-- | No let-floats; the flag is the most permissive one.
emptyLetFloats :: LetFloats
emptyLetFloats = LetFloats nilOL FltLifted

-- | No join-floats.
emptyJoinFloats :: JoinFloats
emptyJoinFloats = nilOL

-- | A single non-join binding as let-floats, with its flag computed as
-- follows.  Recursive groups, non-strict binders and ticked-string RHSs
-- get FltLifted.  Other ok-for-speculation RHSs get FltOkSpec.  Everything
-- else gets FltCareful, which is asserted not to bind an unlifted type.
unitLetFloat :: OutBind -> LetFloats
unitLetFloat bind
  = ASSERT(not (any isJoinId (bindersOf bind)))
    LetFloats (unitOL bind) (classify bind)
  where
    classify (Rec {}) = FltLifted
    classify (NonRec bndr rhs)
      | not (isStrictId bndr) || exprIsTickedString rhs = FltLifted
      | exprOkForSpeculation rhs                         = FltOkSpec
      | otherwise = ASSERT2( not (isUnliftedType (idType bndr)), ppr bndr )
                    FltCareful

-- | A single join-point binding as join-floats.
unitJoinFloat :: OutBind -> JoinFloats
unitJoinFloat bind
  = ASSERT(all isJoinId (bindersOf bind))
    unitOL bind
-- | Turn one output binding into floats: join bindings go to the join group,
-- everything else to the let group.  Both the returned floats and the
-- returned environment get the in-scope set extended with the binders.
mkFloatBind :: SimplEnv -> OutBind -> (SimplFloats, SimplEnv)
mkFloatBind env bind = (floats, env { seInScope = in_scope' })
  where
    in_scope' = seInScope env `extendInScopeSetBind` bind

    floats
      | isJoinBind bind = build_floats emptyLetFloats (unitJoinFloat bind)
      | otherwise       = build_floats (unitLetFloat bind) emptyJoinFloats

    build_floats lets joins
      = SimplFloats { sfLetFloats  = lets
                    , sfJoinFloats = joins
                    , sfInScope    = in_scope' }
-- | Append one binding to the appropriate group (join or let) and extend the
-- in-scope set with its binders.
extendFloats :: SimplFloats -> OutBind -> SimplFloats
extendFloats (SimplFloats { sfLetFloats  = lets
                          , sfJoinFloats = joins
                          , sfInScope    = in_scope }) bind
  | isJoinBind bind = rebuild lets (joins `addJoinFlts` unitJoinFloat bind)
  | otherwise       = rebuild (lets `addLetFlts` unitLetFloat bind) joins
  where
    rebuild lets' joins'
      = SimplFloats { sfLetFloats  = lets'
                    , sfJoinFloats = joins'
                    , sfInScope    = in_scope `extendInScopeSetBind` bind }
-- | Append a group of let-floats, bringing all of their binders into scope.
addLetFloats :: SimplFloats -> LetFloats -> SimplFloats
addLetFloats floats let_floats@(LetFloats binds _)
  = floats { sfInScope   = scope'
           , sfLetFloats = sfLetFloats floats `addLetFlts` let_floats }
  where
    scope' = foldlOL extendInScopeSetBind (sfInScope floats) binds

-- | Append a group of join-floats, bringing all of their binders into scope.
addJoinFloats :: SimplFloats -> JoinFloats -> SimplFloats
addJoinFloats floats join_floats
  = floats { sfInScope    = scope'
           , sfJoinFloats = sfJoinFloats floats `addJoinFlts` join_floats }
  where
    scope' = foldlOL extendInScopeSetBind (sfInScope floats) join_floats

-- | Extend an in-scope set with every binder of a binding group.
extendInScopeSetBind :: InScopeSet -> CoreBind -> InScopeSet
extendInScopeSetBind in_scope = extendInScopeSetList in_scope . bindersOf
-- | Concatenate two sets of floats (first argument's bindings first).  The
-- result's in-scope set is the second argument's; the first one's is
-- discarded.
addFloats :: SimplFloats -> SimplFloats -> SimplFloats
addFloats (SimplFloats { sfLetFloats = lets1, sfJoinFloats = joins1 })
          later@(SimplFloats { sfLetFloats = lets2, sfJoinFloats = joins2 })
  = later { sfLetFloats  = addLetFlts lets1 lets2
          , sfJoinFloats = addJoinFlts joins1 joins2 }

-- | Concatenate let-floats, combining their flags with 'andFF'.
addLetFlts :: LetFloats -> LetFloats -> LetFloats
addLetFlts (LetFloats binds1 flag1) (LetFloats binds2 flag2)
  = LetFloats (appOL binds1 binds2) (andFF flag1 flag2)

-- | The floated let-bindings as a list, in order.
letFloatBinds :: LetFloats -> [CoreBind]
letFloatBinds lf = case lf of
  LetFloats binds _ -> fromOL binds

-- | Concatenate join-floats.
addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
addJoinFlts joins1 joins2 = joins1 `appOL` joins2
-- | Collapse each non-empty group of floats into a single Rec binding.
-- Asserts that the let-floats are FltLifted and that at most one of the two
-- groups is non-empty.  The in-scope set is kept unchanged.
mkRecFloats :: SimplFloats -> SimplFloats
mkRecFloats floats@(SimplFloats { sfLetFloats  = LetFloats lets flag
                                , sfJoinFloats = joins
                                , sfInScope    = in_scope })
  = ASSERT2( lets_all_lifted, ppr (fromOL lets) )
    ASSERT2( isNilOL lets || isNilOL joins, ppr floats )
    SimplFloats { sfInScope    = in_scope
                , sfLetFloats  = lets'
                , sfJoinFloats = joins' }
  where
    lets_all_lifted = case flag of { FltLifted -> True; _ -> False }

    lets'  = if isNilOL lets
               then emptyLetFloats
               else unitLetFloat (Rec (flattenBinds (fromOL lets)))

    joins' = if isNilOL joins
               then emptyJoinFloats
               else unitJoinFloat (Rec (flattenBinds (fromOL joins)))
-- | Wrap an expression in all floats: let-floats outermost, then join-floats,
-- then the body.
wrapFloats :: SimplFloats -> OutExpr -> OutExpr
wrapFloats (SimplFloats { sfLetFloats  = LetFloats lets _
                        , sfJoinFloats = joins }) body
  = foldrOL Let inner lets
  where
    inner = wrapJoinFloats joins body

-- | Wrap only the join-floats around the body, returning the floats with the
-- join group emptied.
wrapJoinFloatsX :: SimplFloats -> OutExpr -> (SimplFloats, OutExpr)
wrapJoinFloatsX floats body = (floats_without_joins, wrapped_body)
  where
    floats_without_joins = floats { sfJoinFloats = emptyJoinFloats }
    wrapped_body         = wrapJoinFloats (sfJoinFloats floats) body

-- | Wrap join bindings around an expression, first binding outermost.
wrapJoinFloats :: JoinFloats -> OutExpr -> OutExpr
wrapJoinFloats joins body = foldrOL Let body joins
-- | The let-floats as a list of bindings; asserts that there are no
-- join-floats.
getTopFloatBinds :: SimplFloats -> [CoreBind]
getTopFloatBinds (SimplFloats { sfLetFloats = lets, sfJoinFloats = joins })
  = ASSERT( isNilOL joins )
    letFloatBinds lets

-- | Apply a function to every (binder, rhs) pair of the let-floats, keeping
-- each binding's Rec/NonRec shape and the overall float flag.
mapLetFloats :: LetFloats -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> LetFloats
mapLetFloats (LetFloats binds flag) fun = LetFloats (mapOL rewrite binds) flag
  where
    rewrite (NonRec b e)
      | (b', e') <- fun (b, e) = NonRec b' e'
    rewrite (Rec pairs) = Rec (map fun pairs)
-- | Look an input Id up in the Id substitution.  Unmapped Ids and DoneId
-- results are refined against the in-scope set; any other entry is
-- returned unchanged.
substId :: SimplEnv -> InId -> SimplSR
substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = maybe (done v) refine (lookupVarEnv ids v)
  where
    done x = DoneId (refineFromInScope in_scope x)

    refine (DoneId x) = done x
    refine other      = other

-- | For a local variable, return the copy held in the in-scope set (warning
-- if it is missing).  Non-local variables are returned as they are.
refineFromInScope :: InScopeSet -> Var -> Var
refineFromInScope in_scope v
  | not (isLocalId v)                    = v
  | Just v' <- lookupInScope in_scope v = v'
  | otherwise                            = WARN( True, ppr v ) v

-- | The output binder for a recursive binder.  Panics if the substitution
-- maps it to anything other than a DoneId; falls back to the in-scope set
-- when it is unmapped.
lookupRecBndr :: SimplEnv -> InId -> OutId
lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = maybe (refineFromInScope in_scope v) from_subst (lookupVarEnv ids v)
  where
    from_subst (DoneId v') = v'
    from_subst _           = pprPanic "lookupRecBndr" (ppr v)
-- | Process a list of binders left to right with 'simplBinder', threading
-- the environment.
simplBinders :: SimplEnv -> [InBndr] -> SimplM (SimplEnv, [OutBndr])
simplBinders = mapAccumLM simplBinder

-- | Process one binder.  Type variables go through 'substTyVarBndr', all
-- other binders through 'substIdBndr' with no join result type.  The new
-- binder is forced before it is returned.
simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplBinder env bndr
  | isTyVar bndr = forced seqTyVar (substTyVarBndr env bndr)
  | otherwise    = forced seqId (substIdBndr Nothing env bndr)
  where
    forced force (env', b) = force b `seq` return (env', b)

-- | Process the binder of a non-recursive, non-join binding.
simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplNonRecBndr env id
  = case substIdBndr Nothing env id of
      (env1, id1) -> seqId id1 `seq` return (env1, id1)

-- | Process the binder of a non-recursive join point, giving it the supplied
-- result type.
simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
                    -> SimplM (SimplEnv, OutBndr)
simplNonRecJoinBndr env res_ty id
  = case substIdBndr (Just res_ty) env id of
      (env1, id1) -> seqId id1 `seq` return (env1, id1)

-- | Process the binders of a recursive, non-join group, returning only the
-- extended environment.
simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
simplRecBndrs env@(SimplEnv {}) ids
  = ASSERT(not (any isJoinId ids))
    case mapAccumL (substIdBndr Nothing) env ids of
      (env1, ids1) -> seqIds ids1 `seq` return env1

-- | Process the binders of a recursive group of join points, giving each the
-- supplied result type, and return the extended environment.
simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
  = ASSERT(all isJoinId ids)
    case mapAccumL (substIdBndr (Just res_ty)) env ids of
      (env1, ids1) -> seqIds ids1 `seq` return env1
-- | Process a term-level binder.  Coercion variables go through
-- 'substCoVarBndr'; every other Id goes through 'substNonCoVarIdBndr', which
-- also receives the optional join result type.
substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
substIdBndr new_res_ty env bndr
  = if isCoVar bndr
      then substCoVarBndr env bndr
      else substNonCoVarIdBndr new_res_ty env bndr
-- | Process a non-coercion Id binder in these steps:
--
-- * make it unique with respect to the in-scope set;
-- * substitute into its type;
-- * when a join result type is given, install it via 'setJoinResTy';
-- * drop fragile IdInfo with 'zapFragileIdInfo'.
--
-- The new Id is added to the in-scope set.  The Id substitution maps the old
-- Id to DoneId of the new one if they differ; otherwise any existing entry
-- for the old Id is deleted.
substNonCoVarIdBndr
  :: Maybe OutType
  -> SimplEnv
  -> InBndr
  -> (SimplEnv, OutBndr)
substNonCoVarIdBndr new_res_ty
                    env@(SimplEnv { seInScope = in_scope
                                  , seIdSubst = id_subst })
                    old_id
  = ASSERT2( not (isCoVar old_id), ppr old_id )
    (env', new_id)
  where
    env' = env { seInScope = extendInScopeSet in_scope new_id
               , seIdSubst = new_subst }

    fresh_id = substIdType env (uniqAway in_scope old_id)

    typed_id = case new_res_ty of
      Nothing     -> fresh_id
      Just res_ty -> fresh_id `setIdType`
                       setJoinResTy (idJoinArity fresh_id) res_ty (idType fresh_id)

    new_id = zapFragileIdInfo typed_id

    new_subst
      | new_id == old_id = delVarEnv id_subst old_id
      | otherwise        = extendVarEnv id_subst old_id (DoneId new_id)
-- | Force a type variable to weak head normal form.
seqTyVar :: TyVar -> ()
seqTyVar tv = tv `seq` ()

-- | Force an Id's type (via 'seqType') and then its IdInfo to WHNF.
seqId :: Id -> ()
seqId i = case seqType (idType i) of
  () -> idInfo i `seq` ()

-- | Apply 'seqId' to every Id in the list, in order.
seqIds :: [Id] -> ()
seqIds = foldr (\i rest -> seqId i `seq` rest) ()
-- | Package the environment's in-scope set and type/coercion substitutions
-- as a TCvSubst.
getTCvSubst :: SimplEnv -> TCvSubst
getTCvSubst env = case env of
  SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env }
    -> mkTCvSubst in_scope (tv_env, cv_env)

-- | Apply the environment's substitution to a type.
substTy :: SimplEnv -> Type -> Type
substTy env ty = Type.substTy subst ty
  where subst = getTCvSubst env

-- | Apply the environment's substitution to a type variable.
substTyVar :: SimplEnv -> TyVar -> Type
substTyVar env tv = Type.substTyVar subst tv
  where subst = getTCvSubst env

-- | Process a type-variable binder, writing the resulting in-scope set and
-- substitutions back into the environment.
substTyVarBndr :: SimplEnv -> TyVar -> (SimplEnv, TyVar)
substTyVarBndr env tv
  = case Type.substTyVarBndr (getTCvSubst env) tv of
      (TCvSubst scope tvs cvs, tv') -> (write_back scope tvs cvs, tv')
  where
    write_back scope tvs cvs
      = env { seInScope = scope, seTvSubst = tvs, seCvSubst = cvs }

-- | Apply the environment's substitution to a coercion variable.
substCoVar :: SimplEnv -> CoVar -> Coercion
substCoVar env cv = Coercion.substCoVar subst cv
  where subst = getTCvSubst env

-- | Process a coercion-variable binder, writing the resulting in-scope set
-- and substitutions back into the environment.
substCoVarBndr :: SimplEnv -> CoVar -> (SimplEnv, CoVar)
substCoVarBndr env cv
  = case Coercion.substCoVarBndr (getTCvSubst env) cv of
      (TCvSubst scope tvs cvs, cv') -> (write_back scope tvs cvs, cv')
  where
    write_back scope tvs cvs
      = env { seInScope = scope, seTvSubst = tvs, seCvSubst = cvs }

-- | Apply the environment's substitution to a coercion.
substCo :: SimplEnv -> Coercion -> Coercion
substCo env co = Coercion.substCo subst co
  where subst = getTCvSubst env
-- | Substitute into an Id's type.  The Id is returned untouched when both the
-- type and coercion substitutions are empty or its type has no free
-- variables.
substIdType :: SimplEnv -> Id -> Id
substIdType (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env }) id
  | nothing_to_do = id
  | otherwise     = Id.setIdType id (Type.substTy subst old_ty)
  where
    old_ty        = idType id
    subst         = TCvSubst in_scope tv_env cv_env
    nothing_to_do = (isEmptyVarEnv tv_env && isEmptyVarEnv cv_env)
                    || noFreeVarsOfType old_ty