module SimplEnv (
setMode, getMode, updMode,
SimplEnv(..), StaticEnv, pprSimplEnv,
mkSimplEnv, extendIdSubst,
SimplEnv.extendTvSubst, SimplEnv.extendCvSubst,
zapSubstEnv, setSubstEnv,
getInScope, setInScopeAndZapFloats,
setInScopeSet, modifyInScope, addNewInScopeIds,
getSimplRules,
SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
isJoinIdInEnv_maybe,
simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
simplBinder, simplBinders,
substTy, substTyVar, getTCvSubst,
substCo, substCoVar,
Floats, emptyFloats, isEmptyFloats,
addNonRec, addFloats, extendFloats,
wrapFloats, setFloats, zapFloats, addRecFloats, mapFloats,
doFloatFromRhs, getFloatBinds,
JoinFloats, emptyJoinFloats, isEmptyJoinFloats,
wrapJoinFloats, zapJoinFloats, restoreJoinFloats, getJoinFloatBinds,
) where
#include "HsVersions.h"
import SimplMonad
import CoreMonad ( SimplifierMode(..) )
import CoreSyn
import CoreUtils
import Var
import VarEnv
import VarSet
import OrdList
import Id
import MkCore ( mkWildValBinder )
import TysWiredIn
import qualified Type
import Type hiding ( substTy, substTyVar, substTyVarBndr )
import qualified Coercion
import Coercion hiding ( substCo, substCoVar, substCoVarBndr )
import BasicTypes
import MonadUtils
import Outputable
import Util
import UniqFM ( pprUniqFM )
import Data.List
-- | State threaded through the simplifier: the current mode, the
-- substitutions for type, coercion and term variables, the in-scope set,
-- and the bindings floated out so far.
data SimplEnv
  = SimplEnv
      { seMode       :: SimplifierMode
          -- ^ Current simplifier mode
      , seTvSubst    :: TvSubstEnv
          -- ^ Substitution applied to type variables
      , seCvSubst    :: CvSubstEnv
          -- ^ Substitution applied to coercion variables
      , seIdSubst    :: SimplIdSubst
          -- ^ Substitution applied to term-level Ids
      , seInScope    :: InScopeSet
          -- ^ Variables currently in scope
      , seFloats     :: Floats
          -- ^ Floated bindings that are not join points
      , seJoinFloats :: JoinFloats
          -- ^ Floated join-point bindings
      }

-- | A synonym for 'SimplEnv'.
type StaticEnv = SimplEnv
-- | Debug rendering of the substitutions and the in-scope set of a
-- 'SimplEnv'. The mode and the floats are not shown.
pprSimplEnv :: SimplEnv -> SDoc
pprSimplEnv env
  = vcat [ text "TvSubst:" <+> ppr (seTvSubst env)
         , text "CvSubst:" <+> ppr (seCvSubst env)
         , text "IdSubst:" <+> pprUniqFM ppr_entry (seIdSubst env)
         , text "InScope:" <+> pprVarSet in_scope_vars (vcat . map ppr_var) ]
  where
    in_scope_vars = getInScopeVars (seInScope env)

    -- An Id-substitution entry, prefixed by a "[join n]" tag for join points
    ppr_entry (mb_arity, sr) = join_tag mb_arity <+> ppr sr
    join_tag = maybe empty (\ar -> brackets (text "join" <+> int ar))

    -- In-scope Ids are shown together with their unfoldings
    ppr_var v
      | isId v    = ppr v <+> ppr (idUnfolding v)
      | otherwise = ppr v
-- | Substitution for term-level Ids. Each entry pairs the join arity of the
-- substituted binder (Nothing if it is not a join point) with what the
-- binder is replaced by.
type SimplIdSubst = IdEnv (Maybe JoinArity, SimplSR)

-- | The replacement recorded for an Id in the substitution.
data SimplSR
  = DoneEx OutExpr        -- an already-simplified expression
  | DoneId OutId          -- an already-simplified Id
  | ContEx TvSubstEnv     -- an unsimplified expression, paired with the
           CvSubstEnv     -- type, coercion and Id substitutions that were
           SimplIdSubst   -- in force when it was suspended (see mkContEx)
           InExpr

instance Outputable SimplSR where
  ppr sr = case sr of
    DoneEx e       -> text "DoneEx" <+> ppr e
    DoneId v       -> text "DoneId" <+> ppr v
    ContEx _ _ _ e -> vcat [text "ContEx" <+> ppr e]
-- | A fresh environment for the given mode: empty substitutions, no
-- floats, and 'init_in_scope' as the in-scope set.
mkSimplEnv :: SimplifierMode -> SimplEnv
mkSimplEnv mode
  = SimplEnv { seMode       = mode
             , seTvSubst    = emptyVarEnv
             , seCvSubst    = emptyVarEnv
             , seIdSubst    = emptyVarEnv
             , seInScope    = init_in_scope
             , seFloats     = emptyFloats
             , seJoinFloats = emptyJoinFloats }

-- | The starting in-scope set. It contains only the wild-card value binder
-- of type unitTy built by mkWildValBinder.
init_in_scope :: InScopeSet
init_in_scope = mkInScopeSet wild_vars
  where
    wild_vars = unitVarSet (mkWildValBinder unitTy)
-- | The simplifier mode stored in the environment.
getMode :: SimplEnv -> SimplifierMode
getMode = seMode

-- | Replace the simplifier mode.
setMode :: SimplifierMode -> SimplEnv -> SimplEnv
setMode new_mode env = env { seMode = new_mode }

-- | Apply a function to the simplifier mode.
updMode :: (SimplifierMode -> SimplifierMode) -> SimplEnv -> SimplEnv
updMode upd env = setMode (upd (getMode env)) env
-- | Map @var@ to @res@ in the Id substitution. The entry also records the
-- join arity of @var@, if it has one. An assertion checks that @var@ is
-- an Id and not a coercion variable.
extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
extendIdSubst env var res
  = ASSERT2( isId var && not (isCoVar var), ppr var )
    env { seIdSubst = extendVarEnv (seIdSubst env) var entry }
  where
    entry = (isJoinId_maybe var, res)
-- | Map the type variable @var@ to the type @res@ in the type substitution.
-- An assertion checks that @var@ really is a type variable.
extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
extendTvSubst env var res
  = ASSERT( isTyVar var )
    env { seTvSubst = extendVarEnv (seTvSubst env) var res }
-- | Map the coercion variable @var@ to @co@ in the coercion substitution.
-- An assertion checks that @var@ really is a coercion variable.
extendCvSubst :: SimplEnv -> CoVar -> Coercion -> SimplEnv
extendCvSubst env var co
  = ASSERT( isCoVar var )
    env { seCvSubst = extendVarEnv (seCvSubst env) var co }
-- | The current in-scope set.
getInScope :: SimplEnv -> InScopeSet
getInScope = seInScope

-- | Replace the in-scope set and leave every other field unchanged.
setInScopeSet :: SimplEnv -> InScopeSet -> SimplEnv
setInScopeSet env new_scope = env { seInScope = new_scope }
-- | Take the in-scope set from the second environment and drop every
-- float (ordinary and join) of the first.
setInScopeAndZapFloats :: SimplEnv -> SimplEnv -> SimplEnv
setInScopeAndZapFloats env env_with_scope
  = env { seFloats     = emptyFloats
        , seJoinFloats = emptyJoinFloats
        , seInScope    = seInScope env_with_scope }
-- | Take the ordinary floats, the join floats and the in-scope set from
-- the second environment. All other fields come from the first.
setFloats :: SimplEnv -> SimplEnv -> SimplEnv
setFloats env env_with_floats
  = env { seFloats     = seFloats env_with_floats
        , seJoinFloats = seJoinFloats env_with_floats
        , seInScope    = seInScope env_with_floats }
-- | Replace the join floats of the first environment with those of the
-- second. Judging by the name, the second is presumably an earlier
-- snapshot; verify at the call sites.
restoreJoinFloats :: SimplEnv -> SimplEnv -> SimplEnv
restoreJoinFloats env old_env
  = env { seJoinFloats = seJoinFloats old_env }
-- | Bring the given binders into scope. Any Id-substitution entries for
-- them are removed, so later occurrences resolve via the in-scope set.
addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
addNewInScopeIds env vs
  = env { seInScope = extendInScopeSetList (seInScope env) vs
        , seIdSubst = delVarEnvList (seIdSubst env) vs }
-- | Add @v@ to the in-scope set. NOTE(review): the name suggests this is
-- used to overwrite an already in-scope binder with an updated copy of
-- itself; confirm at the call sites.
modifyInScope :: SimplEnv -> CoreBndr -> SimplEnv
modifyInScope env v
  = env { seInScope = seInScope env `extendInScopeSet` v }
-- | Discard the type, coercion and Id substitutions.
zapSubstEnv :: SimplEnv -> SimplEnv
zapSubstEnv env = setSubstEnv env emptyVarEnv emptyVarEnv emptyVarEnv

-- | Install the given type, coercion and Id substitutions, replacing the
-- current ones.
setSubstEnv :: SimplEnv -> TvSubstEnv -> CvSubstEnv -> SimplIdSubst -> SimplEnv
setSubstEnv env tvs cvs ids
  = env { seTvSubst = tvs
        , seCvSubst = cvs
        , seIdSubst = ids }
-- | Suspend an unsimplified expression together with the type, coercion
-- and Id substitutions currently held in the environment.
mkContEx :: SimplEnv -> InExpr -> SimplSR
mkContEx (SimplEnv { seTvSubst = tv_env, seCvSubst = cv_env, seIdSubst = id_env }) expr
  = ContEx tv_env cv_env id_env expr
-- | The ordinary (non-join) floated bindings, in order, together with a
-- flag that summarises how freely the group as a whole may float.
data Floats = Floats (OrdList OutBind) FloatFlag

-- | The floated join-point bindings, in order.
type JoinFloats = OrdList OutBind

-- | How freely a group of floats may be floated further. See
-- doFloatFromRhs for how each flag is consulted.
data FloatFlag
  = FltLifted
      -- every binding may float anywhere
  | FltOkSpec
      -- ok for speculation: may float only out of a non-top-level,
      -- non-recursive RHS
  | FltCareful
      -- like FltOkSpec, but additionally requires the strictness flag
      -- passed to doFloatFromRhs
-- Floats are shown as their flag followed by the list of bindings.
instance Outputable Floats where
  ppr (Floats binds flag) = ppr flag $$ ppr (fromOL binds)

-- Each flag is shown as its constructor name.
instance Outputable FloatFlag where
  ppr flag = text (flag_name flag)
    where
      flag_name FltLifted  = "FltLifted"
      flag_name FltOkSpec  = "FltOkSpec"
      flag_name FltCareful = "FltCareful"
-- | Combine two flags, keeping the more restrictive one. FltCareful wins
-- over FltOkSpec, which wins over FltLifted. The second argument is
-- inspected only when the first is FltOkSpec.
andFF :: FloatFlag -> FloatFlag -> FloatFlag
andFF FltLifted  other      = other
andFF FltCareful _          = FltCareful
andFF FltOkSpec  FltCareful = FltCareful
andFF FltOkSpec  _          = FltOkSpec
-- | Should the ordinary floats in the environment be floated out of the
-- right-hand side @rhs@? This requires all of the following:
--
--   * there is at least one float;
--   * the binding is top-level, or @rhs@ is cheap or expandable;
--   * the float flag permits it. FltLifted always does. FltOkSpec needs
--     a non-top-level, non-recursive binding. FltCareful also needs @str@.
doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> OutExpr -> SimplEnv -> Bool
doFloatFromRhs lvl rec str rhs (SimplEnv { seFloats = Floats fs ff })
  | isNilOL fs = False
  | otherwise  = want_to_float && can_float
  where
    want_to_float = isTopLevel lvl || exprIsCheap rhs || exprIsExpandable rhs
    local_nonrec  = isNotTopLevel lvl && isNonRec rec
    can_float     = case ff of
                      FltLifted  -> True
                      FltOkSpec  -> local_nonrec
                      FltCareful -> local_nonrec && str
-- | No ordinary floats. The empty group is flagged FltLifted, the most
-- permissive flag and the identity of andFF.
emptyFloats :: Floats
emptyFloats = Floats nilOL FltLifted

-- | No join floats.
emptyJoinFloats :: JoinFloats
emptyJoinFloats = nilOL
-- | Wrap a single non-join binding (checked by an assertion) as a float
-- group. The flag is chosen as follows:
--
--   * FltLifted for recursive groups, non-strict binders, and RHSs that
--     are literal strings;
--   * FltOkSpec for an RHS that is ok for speculation;
--   * FltCareful otherwise. An assertion checks that such a binder does
--     not have an unlifted type.
unitFloat :: OutBind -> Floats
unitFloat bind
  = ASSERT(all (not . isJoinId) (bindersOf bind))
    Floats (unitOL bind) (flag_of bind)
  where
    flag_of (Rec {}) = FltLifted
    flag_of (NonRec bndr rhs)
      | not (isStrictId bndr) || exprIsLiteralString rhs
      = FltLifted
      | exprOkForSpeculation rhs
      = FltOkSpec
      | otherwise
      = ASSERT2( not (isUnliftedType (idType bndr)), ppr bndr )
        FltCareful
-- | Wrap a single join binding as a join-float group. An assertion checks
-- that every binder is a join point.
unitJoinFloat :: OutBind -> JoinFloats
unitJoinFloat bind
  = ASSERT(all isJoinId (bindersOf bind))
    unitOL bind
-- | Float the non-recursive binding @id = rhs@ and bring @id@ into scope.
-- A join point goes to the join floats; any other binding goes to the
-- ordinary floats. The Id is forced to WHNF before the result is built.
addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
addNonRec env id rhs
  = id `seq`
    env { seInScope    = extendInScopeSet (seInScope env) id
        , seFloats     = if is_join
                           then seFloats env
                           else seFloats env `addFlts` unitFloat bind
        , seJoinFloats = if is_join
                           then seJoinFloats env `addJoinFlts` unitJoinFloat bind
                           else seJoinFloats env }
  where
    is_join = isJoinId id
    bind    = NonRec id rhs
-- | Add a binding to the floats and extend the in-scope set with all of
-- its binders. A join binding (isJoinBind) goes to the join floats; any
-- other binding goes to the ordinary floats.
--
-- NOTE(review): the assertion below rejects any join binder. That makes
-- the isJoinBind branches unreachable in debug builds, while they remain
-- live in release builds. One of the two looks stale; confirm which
-- behaviour the callers rely on.
extendFloats :: SimplEnv -> OutBind -> SimplEnv
extendFloats env bind
  = ASSERT(all (not . isJoinId) (bindersOf bind))
    env { seFloats = floats',
          seJoinFloats = jfloats',
          seInScope = extendInScopeSetList (seInScope env) bndrs }
  where
    bndrs = bindersOf bind
    -- ordinary floats: unchanged for a join binding
    floats' | isJoinBind bind = seFloats env
            | otherwise = seFloats env `addFlts` unitFloat bind
    -- join floats: extended only for a join binding
    jfloats' | isJoinBind bind = seJoinFloats env `addJoinFlts`
                                 unitJoinFloat bind
             | otherwise = seJoinFloats env
-- | Append the floats of the second environment to those of the first,
-- for ordinary and join floats alike. The in-scope set is taken from the
-- second environment; all other fields come from the first.
addFloats :: SimplEnv -> SimplEnv -> SimplEnv
addFloats env1 env2
  = env1 { seInScope    = seInScope env2
         , seFloats     = seFloats env1     `addFlts`     seFloats env2
         , seJoinFloats = seJoinFloats env1 `addJoinFlts` seJoinFloats env2 }
-- | Concatenate two ordinary float groups, keeping the more restrictive
-- flag (see andFF).
addFlts :: Floats -> Floats -> Floats
addFlts (Floats binds1 flag1) (Floats binds2 flag2)
  = Floats (appOL binds1 binds2) (andFF flag1 flag2)

-- | Concatenate two join-float groups.
addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
addJoinFlts js1 js2 = js1 `appOL` js2
-- | Drop all floats, both ordinary and join.
zapFloats :: SimplEnv -> SimplEnv
zapFloats env = zapJoinFloats env { seFloats = emptyFloats }

-- | Drop the join floats only.
zapJoinFloats :: SimplEnv -> SimplEnv
zapJoinFloats env = env { seJoinFloats = emptyJoinFloats }
-- | Gather the floats of @env2@ into one recursive group and append it to
-- the floats of @env1@. Ordinary and join floats each get their own Rec
-- group, and an empty group adds nothing. The result is @env2@ with those
-- floats. An assertion checks that @env2@'s ordinary floats are all
-- flagged FltLifted.
addRecFloats :: SimplEnv -> SimplEnv -> SimplEnv
addRecFloats env1 env2@(SimplEnv { seFloats = Floats bs ff, seJoinFloats = jbs })
  = ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
    env2 { seFloats     = seFloats env1     `addFlts`     rec_floats
         , seJoinFloats = seJoinFloats env1 `addJoinFlts` rec_jfloats }
  where
    rec_floats  = if isNilOL bs  then emptyFloats     else unitFloat     (as_rec bs)
    rec_jfloats = if isNilOL jbs then emptyJoinFloats else unitJoinFloat (as_rec jbs)
    -- merge a sequence of bindings into a single recursive group
    as_rec      = Rec . flattenBinds . fromOL
-- | Wrap @body@ in a Let for every float. The ordinary floats form the
-- outer lets and the join floats the inner ones, in list order.
wrapFloats :: SimplEnv -> OutExpr -> OutExpr
wrapFloats env@(SimplEnv { seFloats = Floats binds _ }) body
  = foldrOL Let inner binds
  where
    inner = wrapJoinFloats env body

-- | Wrap @body@ in a Let for every join float, in list order.
wrapJoinFloats :: SimplEnv -> OutExpr -> OutExpr
wrapJoinFloats env body = foldrOL Let body (seJoinFloats env)
-- | All floated bindings as a list: the ordinary floats first, then the
-- join floats.
getFloatBinds :: SimplEnv -> [CoreBind]
getFloatBinds env@(SimplEnv { seFloats = Floats binds _ })
  = fromOL binds ++ getJoinFloatBinds env

-- | The floated join bindings as a list.
getJoinFloatBinds :: SimplEnv -> [CoreBind]
getJoinFloatBinds = fromOL . seJoinFloats
-- | True when there are neither ordinary nor join floats.
isEmptyFloats :: SimplEnv -> Bool
isEmptyFloats env@(SimplEnv { seFloats = Floats binds _ })
  = isNilOL binds && isEmptyJoinFloats env

-- | True when there are no join floats.
isEmptyJoinFloats :: SimplEnv -> Bool
isEmptyJoinFloats = isNilOL . seJoinFloats
-- | Apply @fun@ to every (binder, rhs) pair of both the ordinary and the
-- join floats. The float flag is kept unchanged.
mapFloats :: SimplEnv -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> SimplEnv
mapFloats env@SimplEnv { seFloats = Floats fs ff, seJoinFloats = jfs } fun
  = env { seFloats     = Floats (mapOL on_bind fs) ff
        , seJoinFloats = mapOL on_bind jfs }
  where
    on_bind (NonRec b e) = case fun (b, e) of
                             (b', e') -> NonRec b' e'
    on_bind (Rec pairs)  = Rec (map fun pairs)
-- | Look up an Id in the substitution:
--
--   * no entry, or an entry that is just a variable (DoneId, or DoneEx of
--     a Var): return that variable, refined against the in-scope set;
--   * any other entry: return it unchanged.
substId :: SimplEnv -> InId -> SimplSR
substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = case lookupVarEnv ids v of
      Nothing                   -> refined v
      Just (_, DoneId v')       -> refined v'
      Just (_, DoneEx (Var v')) -> refined v'
      Just (_, other)           -> other
  where
    refined = DoneId . refineFromInScope in_scope
-- | The join arity of an Id, or Nothing if it is not a join point.
--
--   * Non-local Ids always give Nothing.
--   * Otherwise the arity recorded in the Id substitution is used, if
--     there is an entry.
--   * Failing that, the copy of the Id in the in-scope set is consulted.
--   * If the Id is in neither, a warning is emitted and the Id itself is
--     inspected.
isJoinIdInEnv_maybe :: SimplEnv -> InId -> Maybe JoinArity
isJoinIdInEnv_maybe (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  | not (isLocalId v)
  = Nothing
  | otherwise
  = case lookupVarEnv ids v of
      Just (mb_arity, _) -> mb_arity
      Nothing ->
        case lookupInScope in_scope v of
          Just v' -> isJoinId_maybe v'
          Nothing -> WARN( True, ppr v ) isJoinId_maybe v
-- | Replace a local variable by the copy of it held in the in-scope set.
-- NOTE(review): presumably that copy carries more up-to-date information.
-- Non-local variables are returned as they are. A local variable missing
-- from the in-scope set is also returned as is, with a warning.
refineFromInScope :: InScopeSet -> Var -> Var
refineFromInScope in_scope v
  | not (isLocalId v)                   = v
  | Just v' <- lookupInScope in_scope v = v'
  | otherwise                           = WARN( True, ppr v ) v
-- | The output binder for a recursive binder that has already been through
-- simplRecBndrs:
--
--   * a DoneId entry gives its Id;
--   * no entry means the binder is refined against the in-scope set;
--   * any other entry is a panic.
lookupRecBndr :: SimplEnv -> InId -> OutId
lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
  = case lookupVarEnv ids v of
      Nothing             -> refineFromInScope in_scope v
      Just (_, DoneId v') -> v'
      Just _              -> pprPanic "lookupRecBndr" (ppr v)
-- | Simplify a list of binders from left to right, threading the
-- environment through.
simplBinders :: SimplEnv -> [InBndr] -> SimplM (SimplEnv, [OutBndr])
simplBinders = mapAccumLM simplBinder

-- | Simplify one binder. Type variables go through substTyVarBndr; Ids go
-- through substIdBndr with no new join result type. The new binder is
-- forced before it is returned.
simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplBinder env bndr
  | isTyVar bndr
  = let (env', tv) = substTyVarBndr env bndr
    in seqTyVar tv `seq` return (env', tv)
  | otherwise
  = let (env', id') = substIdBndr Nothing env bndr
    in seqId id' `seq` return (env', id')
-- | Simplify the binder of a non-recursive, non-join binding. The new
-- binder is forced before it is returned.
simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplNonRecBndr env id
  = let (env1, id1) = substIdBndr Nothing env id
    in seqId id1 `seq` return (env1, id1)

-- | Simplify the binder of a non-recursive join point. @res_ty@ is passed
-- to substIdBndr as its new result type. The new binder is forced before
-- it is returned.
simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
                    -> SimplM (SimplEnv, OutBndr)
simplNonRecJoinBndr env res_ty id
  = let (env1, id1) = substIdBndr (Just res_ty) env id
    in seqId id1 `seq` return (env1, id1)
-- | Substitute all the binders of a recursive, non-join group up front
-- (an assertion checks that none is a join point). Only the extended
-- environment is returned; lookupRecBndr recovers the new binders.
simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
simplRecBndrs env@(SimplEnv {}) ids
  = ASSERT(all (not . isJoinId) ids)
    let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids
    in seqIds ids1 `seq` return env1

-- | Like simplRecBndrs, but for a group of join points (checked by an
-- assertion). Each binder is given @res_ty@ as its new result type.
simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
  = ASSERT(all isJoinId ids)
    let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids
    in seqIds ids1 `seq` return env1
-- | Substitute a single term-level binder. Coercion variables go through
-- substCoVarBndr. All other Ids go through substNonCoVarIdBndr, which uses
-- @new_res_ty@ (if given) as the new join-point result type.
substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
substIdBndr new_res_ty env bndr
  | not (isCoVar bndr) = substNonCoVarIdBndr new_res_ty env bndr
  | otherwise          = substCoVarBndr env bndr
-- | Substitute a binder that is an Id but not a coercion variable:
--
--   1. rename it away from the in-scope set (uniqAway);
--   2. apply the type substitution to its type (substIdType);
--   3. if @new_res_ty@ is given, install it as the join-point result type;
--   4. zap its fragile IdInfo.
--
-- The new binder is added to the in-scope set. If it differs from the old
-- one, the old binder is mapped to it in the Id substitution (with its
-- join arity). Otherwise any existing entry for the old binder is deleted.
substNonCoVarIdBndr
  :: Maybe OutType
  -> SimplEnv
  -> InBndr
  -> (SimplEnv, OutBndr)
substNonCoVarIdBndr new_res_ty
                    env@(SimplEnv { seInScope = in_scope
                                  , seIdSubst = id_subst })
                    old_id
  = ASSERT2( not (isCoVar old_id), ppr old_id )
    (env', new_id)
  where
    env' = env { seInScope = in_scope `extendInScopeSet` new_id
               , seIdSubst = new_subst }

    renamed  = uniqAway in_scope old_id
    retyped  = substIdType env renamed
    adjusted = case new_res_ty of
                 Nothing     -> retyped
                 Just res_ty -> retyped `setIdType`
                                  setJoinResTy (idJoinArity retyped) res_ty (idType retyped)

    new_id = zapFragileIdInfo adjusted

    new_subst
      | new_id == old_id = delVarEnv id_subst old_id
      | otherwise        = extendVarEnv id_subst old_id
                                        (isJoinId_maybe new_id, DoneId new_id)
-- | Force a type variable to weak head normal form.
seqTyVar :: TyVar -> ()
seqTyVar tv = tv `seq` ()

-- | Force an Id's type (via seqType) and then its IdInfo (to WHNF).
seqId :: Id -> ()
seqId v = case seqType (idType v) of
            () -> idInfo v `seq` ()

-- | Force every Id in the list with seqId, from left to right.
seqIds :: [Id] -> ()
seqIds = foldr (seq . seqId) ()
-- | The environment's type and coercion substitutions, packaged with its
-- in-scope set as a TCvSubst.
getTCvSubst :: SimplEnv -> TCvSubst
getTCvSubst (SimplEnv { seInScope = scope, seTvSubst = tenv, seCvSubst = cenv })
  = mkTCvSubst scope (tenv, cenv)

-- | Apply the environment's substitution to a type.
substTy :: SimplEnv -> Type -> Type
substTy env = Type.substTy (getTCvSubst env)

-- | Apply the environment's substitution to a type variable.
substTyVar :: SimplEnv -> TyVar -> Type
substTyVar env = Type.substTyVar (getTCvSubst env)
-- | Substitute into a type-variable binder. The environment is updated with
-- the in-scope set and substitutions that Type.substTyVarBndr returns.
substTyVarBndr :: SimplEnv -> TyVar -> (SimplEnv, TyVar)
substTyVarBndr env tv
  = case Type.substTyVarBndr (getTCvSubst env) tv of
      (TCvSubst scope' tenv' cenv', tv') ->
        let env' = env { seInScope = scope', seTvSubst = tenv', seCvSubst = cenv' }
        in (env', tv')
-- | Apply the environment's substitution to a coercion variable.
substCoVar :: SimplEnv -> CoVar -> Coercion
substCoVar env cv = Coercion.substCoVar (getTCvSubst env) cv

-- | Substitute into a coercion-variable binder. The environment is updated
-- with the in-scope set and substitutions that Coercion.substCoVarBndr
-- returns.
substCoVarBndr :: SimplEnv -> CoVar -> (SimplEnv, CoVar)
substCoVarBndr env cv
  = case Coercion.substCoVarBndr (getTCvSubst env) cv of
      (TCvSubst scope' tenv' cenv', cv') ->
        let env' = env { seInScope = scope', seTvSubst = tenv', seCvSubst = cenv' }
        in (env', cv')

-- | Apply the environment's substitution to a coercion.
substCo :: SimplEnv -> Coercion -> Coercion
substCo env = Coercion.substCo (getTCvSubst env)
-- | Apply the environment's type/coercion substitution to the type of an
-- Id. The Id is returned untouched when both substitutions are empty or
-- its type has no free variables.
substIdType :: SimplEnv -> Id -> Id
substIdType (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env }) id
  | nothing_to_do = id
  | otherwise     = Id.setIdType id (Type.substTy subst old_ty)
  where
    old_ty        = idType id
    nothing_to_do = (isEmptyVarEnv tv_env && isEmptyVarEnv cv_env)
                    || noFreeVarsOfType old_ty
    subst         = TCvSubst in_scope tv_env cv_env