{-# LANGUAGE TypeFamilies #-}
module StgCse (stgCse) where
import GhcPrelude
import DataCon
import Id
import StgSyn
import Outputable
import VarEnv
import CoreSyn (AltCon(..))
import Data.List (mapAccumL)
import Data.Maybe (fromMaybe)
import TrieMap
import NameEnv
import Control.Monad( (>=>) )
-- | A finite map keyed by 'StgArg'. Variable arguments and literal arguments
-- are kept in two separate sub-maps, one per 'StgArg' constructor.
data StgArgMap a = SAM
    { sam_var :: DVarEnv a     -- ^ entries whose key is an 'StgVarArg'
    , sam_lit :: LiteralMap a  -- ^ entries whose key is an 'StgLitArg'
    }
-- | 'TrieMap' operations for 'StgArgMap'. Lookup and alteration inspect the
-- constructor of the 'StgArg' key and delegate to the matching sub-map;
-- folding and mapping visit both sub-maps.
instance TrieMap StgArgMap where
    type Key StgArgMap = StgArg

    emptyTM = SAM { sam_var = emptyTM, sam_lit = emptyTM }

    lookupTM key = case key of
        StgVarArg var -> sam_var >.> lkDFreeVar var
        StgLitArg lit -> sam_lit >.> lookupTM lit

    alterTM key f m = case key of
        StgVarArg var -> m { sam_var = sam_var m |> xtDFreeVar var f }
        StgLitArg lit -> m { sam_lit = sam_lit m |> alterTM lit f }

    -- Fold the literal sub-map first, then the variable sub-map.
    foldTM k m = foldTM k (sam_var m) . foldTM k (sam_lit m)

    mapTM f (SAM { sam_var = vars, sam_lit = lits }) =
        SAM { sam_var = mapTM f vars
            , sam_lit = mapTM f lits
            }
-- | A finite map keyed by a constructor application: first indexed by the
-- 'DataCon' (through a 'DNameEnv'), then by the list of 'StgArg's.
newtype ConAppMap a = CAM { un_cam :: DNameEnv (ListMap StgArgMap a) }
-- | 'TrieMap' operations for 'ConAppMap'. The key is a pair of a 'DataCon' and
-- its argument list: the constructor selects an inner 'ListMap', which is then
-- indexed by the arguments.
instance TrieMap ConAppMap where
    type Key ConAppMap = (DataCon, [StgArg])

    emptyTM = CAM emptyTM

    -- Unwrap, find the per-constructor map, then look up the arguments in it.
    lookupTM (dataCon, args) = un_cam >.> (lkDNamed dataCon >=> lookupTM args)

    alterTM (dataCon, args) f m = m { un_cam = alterInner (un_cam m) }
      where
        alterInner = xtDNamed dataCon |>> alterTM args f

    foldTM k = foldTM (foldTM k) . un_cam

    mapTM f = CAM . mapTM (mapTM f) . un_cam
-- | The environment threaded through the CSE traversal.
data CseEnv = CseEnv
    { ce_conAppMap :: ConAppMap OutId
        -- ^ Known constructor applications, mapped to a binder that already
        -- holds that value (filled by 'addDataCon', queried by 'envLookup').
    , ce_subst :: IdEnv OutId
        -- ^ Substitution applied to variable occurrences by 'substVar'.
    , ce_bndrMap :: IdEnv OutId
        -- ^ Extra variable mapping extended by 'addTrivCaseBndr'; 'envLookup'
        -- applies it to variable arguments before querying 'ce_conAppMap'.
    , ce_in_scope :: InScopeSet
        -- ^ In-scope set that 'substBndr' passes to 'uniqAway' when
        -- bringing a binder into scope.
    }
-- | Build a fresh environment around the given in-scope set: no known
-- constructor applications, an empty substitution and an empty binder map.
initEnv :: InScopeSet -> CseEnv
initEnv scope = CseEnv
    { ce_in_scope  = scope
    , ce_conAppMap = emptyTM
    , ce_subst     = emptyVarEnv
    , ce_bndrMap   = emptyVarEnv
    }
-- | Look up a constructor application in the environment, returning the
-- binder recorded for it, if any. Variable arguments are first mapped through
-- 'ce_bndrMap' (variables without an entry are left unchanged); literal
-- arguments are used as they are.
envLookup :: DataCon -> [OutStgArg] -> CseEnv -> Maybe OutId
envLookup dataCon args env =
    lookupTM (dataCon, map normalise args) (ce_conAppMap env)
  where
    bndr_map = ce_bndrMap env

    normalise arg = case arg of
        StgVarArg v   -> StgVarArg (fromMaybe v (lookupVarEnv bndr_map v))
        StgLitArg lit -> StgLitArg lit
-- | Record that the given binder holds the given constructor application.
-- Applications with no arguments are not recorded; the environment is returned
-- unchanged for them.
addDataCon :: OutId -> DataCon -> [OutStgArg] -> CseEnv -> CseEnv
addDataCon bndr dataCon args env
    | null args = env
    | otherwise = env { ce_conAppMap = con_apps' }
  where
    con_apps' = insertTM (dataCon, args) bndr (ce_conAppMap env)
-- | Drop every recorded constructor application, keeping the substitution,
-- the binder map and the in-scope set.
forgetCse :: CseEnv -> CseEnv
forgetCse cse_env = cse_env { ce_conAppMap = emptyTM }
-- | Extend the substitution so that occurrences of the first variable are
-- replaced by the second.
addSubst :: OutId -> OutId -> CseEnv -> CseEnv
addSubst from to env = env { ce_subst = subst' }
  where
    subst' = extendVarEnv (ce_subst env) from to
-- | Extend 'ce_bndrMap' with a mapping from the first variable to the second.
-- 'envLookup' consults this map when normalising variable arguments.
addTrivCaseBndr :: OutId -> OutId -> CseEnv -> CseEnv
addTrivCaseBndr from to env = env { ce_bndrMap = bndr_map' }
  where
    bndr_map' = extendVarEnv (ce_bndrMap env) from to
-- | Apply 'substArg' to every argument in the list.
substArgs :: CseEnv -> [InStgArg] -> [OutStgArg]
substArgs env args = [ substArg env arg | arg <- args ]
-- | Substitute inside a single argument: variables go through 'substVar',
-- literals are returned unchanged.
substArg :: CseEnv -> InStgArg -> OutStgArg
substArg env arg = case arg of
    StgVarArg from -> StgVarArg (substVar env from)
    StgLitArg lit  -> StgLitArg lit
-- | Apply 'substVar' to every variable in the list.
substVars :: CseEnv -> [InId] -> [OutId]
substVars env vars = [ substVar env var | var <- vars ]
-- | Replace a variable by its image under 'ce_subst'; a variable with no
-- entry in the substitution is returned as is.
substVar :: CseEnv -> InId -> OutId
substVar env var =
    case lookupVarEnv (ce_subst env) var of
        Just replacement -> replacement
        Nothing          -> var
-- | Bring a binder into scope, renaming it if necessary.
--
-- The binder is passed through 'uniqAway' with the current in-scope set, and
-- the resulting identifier is added to that set. Returns the updated
-- environment together with the (possibly renamed) binder.
--
-- When 'uniqAway' had to pick a different identifier, the mapping from the old
-- binder to the new one is recorded in 'ce_subst', so that 'substVar' rewrites
-- later occurrences of the old binder. When the binder is unchanged, no
-- substitution entry is needed and none is added.
substBndr :: CseEnv -> InId -> (CseEnv, OutId)
substBndr env old_id
  = (new_env, new_id)
  where
    new_id = uniqAway (ce_in_scope env) old_id
    no_change = new_id == old_id
    env' = env { ce_in_scope = ce_in_scope env `extendInScopeSet` new_id }
    -- Only a real renaming needs to be remembered in the substitution.
    new_env | no_change = env'
            | otherwise = env' { ce_subst = extendVarEnv (ce_subst env) old_id new_id }
-- | Run 'substBndr' over a list of binders from left to right, threading the
-- environment through, and return the final environment with the new binders.
substBndrs :: CseEnv -> [InVar] -> (CseEnv, [OutVar])
substBndrs = mapAccumL substBndr
-- | Like 'substBndrs', but for a list of pairs whose first component is the
-- binder; the second component of each pair is carried along untouched.
substPairs :: CseEnv -> [(InVar, a)] -> (CseEnv, [(OutVar, a)])
substPairs = mapAccumL rename
  where
    rename acc (bndr, payload) = (acc', (bndr', payload))
      where
        (acc', bndr') = substBndr acc bndr
-- | Entry point of the pass: process the top-level bindings in order, starting
-- from an empty in-scope set and threading the growing set from one binding to
-- the next. Only the transformed bindings are returned.
stgCse :: [InStgTopBinding] -> [OutStgTopBinding]
stgCse = snd . mapAccumL stgCseTopLvl emptyInScopeSet
-- | Process one top-level binding. Returns the in-scope set extended with the
-- binders of this binding, and the binding with its right-hand sides
-- transformed.
--
-- * String literals are returned untouched.
-- * For a non-recursive binding the right-hand side is processed with the
--   incoming in-scope set (the binder itself is not yet in scope there).
-- * For a recursive group every right-hand side is processed with the set
--   already extended by all binders of the group.
stgCseTopLvl :: InScopeSet -> InStgTopBinding -> (InScopeSet, OutStgTopBinding)
stgCseTopLvl in_scope top_bind = case top_bind of
    StgTopStringLit _ _ ->
        (in_scope, top_bind)

    StgTopLifted (StgNonRec bndr rhs) ->
        ( in_scope `extendInScopeSet` bndr
        , StgTopLifted (StgNonRec bndr (stgCseTopLvlRhs in_scope rhs))
        )

    StgTopLifted (StgRec eqs) ->
        let group_bndrs = [ bndr | (bndr, _) <- eqs ]
            in_scope'   = in_scope `extendInScopeSetList` group_bndrs
            eqs'        = [ (bndr, stgCseTopLvlRhs in_scope' rhs)
                          | (bndr, rhs) <- eqs ]
        in (in_scope', StgTopLifted (StgRec eqs'))
-- | Process the right-hand side of a top-level binding. The body of a closure
-- is transformed starting from a fresh environment ('initEnv'); a constructor
-- application is rebuilt as it is.
stgCseTopLvlRhs :: InScopeSet -> InStgRhs -> OutStgRhs
stgCseTopLvlRhs in_scope rhs = case rhs of
    StgRhsClosure ccs info occs upd args body ->
        StgRhsClosure ccs info occs upd args
                      (stgCseExpr (initEnv in_scope) body)

    StgRhsCon ccs dataCon args ->
        StgRhsCon ccs dataCon args
-- | Transform an expression under the given environment.
--
-- * Applications and primitive operations have the substitution applied to
--   their variables and arguments.
-- * A constructor application that is already known to the environment is
--   replaced by a reference to the binder recorded for it.
-- * Case expressions and let bindings are rebuilt through 'mkStgCase' and
--   'mkStgLet', which may simplify them.
-- * 'StgLam' is not expected here and causes a panic.
stgCseExpr :: CseEnv -> InStgExpr -> OutStgExpr
stgCseExpr env expr = case expr of
    StgApp fun args ->
        StgApp (substVar env fun) (substArgs env args)

    StgLit lit ->
        StgLit lit

    StgOpApp op args tys ->
        StgOpApp op (substArgs env args) tys

    StgLam _ _ ->
        pprPanic "stgCseExp" (text "StgLam")

    StgTick tick body ->
        StgTick tick (stgCseExpr env body)

    StgCase scrut bndr ty alts ->
        let scrut'        = stgCseExpr env scrut
            (env1, bndr') = substBndr env bndr
            -- When the transformed scrutinee is a bare variable, remember
            -- that the case binder stands for that variable.
            alt_env = case scrut' of
                StgApp trivial_scrut [] -> addTrivCaseBndr bndr trivial_scrut env1
                _                       -> env1
        in mkStgCase scrut' bndr' ty (map (stgCseAlt alt_env bndr') alts)

    StgConApp dataCon args tys ->
        let args' = substArgs env args
        in case envLookup dataCon args' env of
            Just bndr' -> StgApp bndr' []
            Nothing    -> StgConApp dataCon args' tys

    StgLet binds body ->
        cseLet StgLet binds body

    StgLetNoEscape binds body ->
        cseLet StgLetNoEscape binds body
  where
    -- Shared handling of both let forms: transform the binding, then the body
    -- under the resulting environment; 'mkStgLet' drops the let when the
    -- binding has been eliminated.
    cseLet mk_let binds body = mkStgLet mk_let binds' (stgCseExpr env' body)
      where
        (binds', env') = stgCseBind env binds
-- | Transform one case alternative. The pattern variables are brought into
-- scope with 'substBndrs'. For a data constructor alternative the environment
-- additionally records that the case binder equals the constructor applied to
-- the pattern variables, before the right-hand side is processed.
stgCseAlt :: CseEnv -> OutId -> InStgAlt -> OutStgAlt
stgCseAlt env case_bndr (altCon, args, rhs) =
    case altCon of
        DataAlt dataCon ->
            finish (addDataCon case_bndr dataCon (map StgVarArg args') env1)
        _ ->
            finish env1
  where
    (env1, args') = substBndrs env args

    finish rhs_env = (altCon, args', stgCseExpr rhs_env rhs)
-- | Transform a let binding. Returns 'Nothing' in place of the binding when
-- CSE removed it entirely (the single right-hand side of a non-recursive
-- binding, or every pair of a recursive group), together with the environment
-- to use for the body of the let.
stgCseBind :: CseEnv -> InStgBinding -> (Maybe OutStgBinding, CseEnv)
stgCseBind env bind = case bind of
    StgNonRec b e ->
        let (env1, b') = substBndr env b
        in case stgCseRhs env1 b' e of
            (mb_pair, env2) -> (uncurry StgNonRec <$> mb_pair, env2)

    StgRec pairs ->
        let (env1, pairs1) = substPairs env pairs
            (pairs2, env2) = stgCsePairs env1 pairs1
        in if null pairs2
              then (Nothing, env2)
              else (Just (StgRec pairs2), env2)
-- | Transform the pairs of a recursive group from left to right with
-- 'stgCseRhs', threading the environment through. Pairs for which 'stgCseRhs'
-- returns 'Nothing' are dropped from the result.
stgCsePairs :: CseEnv -> [(OutId, InStgRhs)] -> ([(OutId, OutStgRhs)], CseEnv)
stgCsePairs env0 pairs = ([ pair | Just pair <- mb_pairs ], env_final)
  where
    (env_final, mb_pairs) = mapAccumL step env0 pairs

    step env (b, e) = (env', mb_pair)
      where
        (mb_pair, env') = stgCseRhs env b e
-- | Transform the right-hand side of a let-bound variable.
--
-- For a constructor application: if the same application is already known,
-- the binding is dropped ('Nothing') and the binder is substituted by the
-- existing one; otherwise the binding is kept and the application is recorded
-- for this binder.
--
-- For a closure: the arguments are brought into scope, the body is processed
-- after forgetting all known constructor applications ('forgetCse'), and the
-- incoming environment is returned unchanged.
stgCseRhs :: CseEnv -> OutId -> InStgRhs -> (Maybe (OutId, OutStgRhs), CseEnv)
stgCseRhs env bndr rhs = case rhs of
    StgRhsCon ccs dataCon args ->
        let args' = substArgs env args
        in case envLookup dataCon args' env of
            Just other_bndr ->
                (Nothing, addSubst bndr other_bndr env)
            Nothing ->
                ( Just (bndr, StgRhsCon ccs dataCon args')
                , addDataCon bndr dataCon args' env
                )

    StgRhsClosure ccs info occs upd args body ->
        let (env1, args') = substBndrs env args
            body'         = stgCseExpr (forgetCse env1) body
            occs'         = substVars env occs
            closure'      = StgRhsClosure ccs info occs' upd args' body'
        in (Just (substVar env bndr, closure'), env)
-- | Smart constructor for 'StgCase'. If every alternative merely returns the
-- case binder (an application of the binder to no arguments), the case is
-- replaced by its scrutinee; otherwise a regular 'StgCase' is built.
mkStgCase :: StgExpr -> OutId -> AltType -> [StgAlt] -> StgExpr
mkStgCase scrut bndr ty alts
    | any (not . returnsBndr) alts = StgCase scrut bndr ty alts
    | otherwise                    = scrut
  where
    returnsBndr (_, _, rhs) = case rhs of
        StgApp f [] -> f == bndr
        _           -> False
-- | Smart constructor for let forms: with 'Nothing' there is no binding left,
-- so the body is returned on its own; with 'Just' the supplied constructor is
-- applied to the binding and the body.
mkStgLet :: (a -> b -> b) -> Maybe a -> b -> b
mkStgLet stgLet mb_binds body =
    maybe body (\binds -> stgLet binds body) mb_binds