module GHC.Stg.DepAnal (depSortStgPgm) where
import GHC.Prelude
import GHC.Stg.Syntax
import GHC.Types.Id
import GHC.Types.Name (Name, nameIsLocalOrFrom)
import GHC.Types.Name.Env
import GHC.Utils.Outputable
import GHC.Types.Unique.Set (nonDetEltsUniqSet)
import GHC.Types.Var.Set
import GHC.Unit.Module (Module)
import Data.Graph (SCC (..))
import Data.Bifunctor (first)
-- | Variables bound in the scope currently being traversed; occurrences of
-- these are not reported as free.
type BVs = VarSet
-- | Free variables collected for a binding, expression or argument.
type FVs = VarSet
-- | Pair each top-level binding with the set of free variables it mentions
-- whose names are local or from @this_mod@ (see 'nameIsLocalOrFrom').
--
-- Variables bound inside the binding are not reported. These are closure
-- arguments, case binders, alternative binders, let-bound binders and the
-- binders of a recursive group within that group. String-literal bindings
-- have no free variables. Top-level binders are not put in scope here, so
-- references from one top-level binding to another in the same module
-- appear in the result. The exception is references within a single
-- top-level 'StgRec' group.
--
-- Panics on 'StgLam'. NOTE(review): presumably no lambdas remain in STG
-- at the point this runs; confirm against the pipeline.
annTopBindingsDeps :: Module -> [StgTopBinding] -> [(StgTopBinding, FVs)]
annTopBindingsDeps this_mod bs = zip bs (map top_bind bs)
  where
    top_bind :: StgTopBinding -> FVs
    top_bind StgTopStringLit{} =
      emptyVarSet
    top_bind (StgTopLifted bind) =
      binding emptyVarSet bind

    -- A non-recursive binder is not in scope in its own right-hand side.
    -- Every binder of a recursive group is in scope in all of the group's
    -- right-hand sides.
    binding :: BVs -> StgBinding -> FVs
    binding bounds (StgNonRec _ r) =
      rhs bounds r
    binding bounds (StgRec pairs) =
      let bounds' = extendVarSetList bounds (map fst pairs)
      in unionVarSets (map (rhs bounds' . snd) pairs)

    rhs :: BVs -> StgRhs -> FVs
    rhs bounds (StgRhsClosure _ _ _ as e) =
      expr (extendVarSetList bounds as) e
    rhs bounds (StgRhsCon _ _ as) =
      args bounds as

    -- A variable occurrence is free only if it is not locally bound and
    -- its name is local or from this module.
    var :: BVs -> Var -> FVs
    var bounds v
      | not (elemVarSet v bounds)
      , nameIsLocalOrFrom this_mod (idName v)
      = unitVarSet v
      | otherwise
      = emptyVarSet

    arg :: BVs -> StgArg -> FVs
    arg bounds (StgVarArg v) = var bounds v
    arg _ StgLitArg{} = emptyVarSet

    args :: BVs -> [StgArg] -> FVs
    args bounds as = unionVarSets (map (arg bounds) as)

    expr :: BVs -> StgExpr -> FVs
    expr bounds (StgApp f as) =
      var bounds f `unionVarSet` args bounds as
    expr _ StgLit{} =
      emptyVarSet
    expr bounds (StgConApp _ as _) =
      args bounds as
    expr bounds (StgOpApp _ as _) =
      args bounds as
    expr _ lam@StgLam{} =
      pprPanic "annTopBindingsDeps" (text "Found lambda:" $$ pprStgExpr panicStgPprOpts lam)
    -- The case binder scopes over the alternatives only, not the scrutinee.
    expr bounds (StgCase scrut scrut_bndr _ as) =
      expr bounds scrut `unionVarSet`
        alts (extendVarSet bounds scrut_bndr) as
    expr bounds (StgLet _ bs e) =
      let_expr bounds bs e
    expr bounds (StgLetNoEscape _ bs e) =
      let_expr bounds bs e
    expr bounds (StgTick _ e) =
      expr bounds e

    -- Shared by 'StgLet' and 'StgLetNoEscape'. The result is the free
    -- variables of the binding group, plus those of the body with the
    -- group's binders in scope.
    let_expr :: BVs -> StgBinding -> StgExpr -> FVs
    let_expr bounds bs e =
      binding bounds bs `unionVarSet`
        expr (extendVarSetList bounds (bindersOf bs)) e

    alts :: BVs -> [StgAlt] -> FVs
    alts bounds = unionVarSets . map (alt bounds)

    -- Constructor-pattern binders scope over the alternative's right-hand side.
    alt :: BVs -> StgAlt -> FVs
    alt bounds (_, bndrs, e) =
      expr (extendVarSetList bounds bndrs) e
-- | Reorder a module's top-level STG bindings by their dependencies.
--
-- Each binding is annotated with its free variables from @this_mod@, the
-- annotated list is dependency-sorted, and the annotations are dropped.
depSortStgPgm :: Module -> [StgTopBinding] -> [StgTopBinding]
depSortStgPgm this_mod binds =
  [ bind | (bind, _fvs) <- depSort (annTopBindingsDeps this_mod binds) ]
-- | Dependency-sort annotated top-level bindings using 'depAnal'.
--
-- A binding defines the names of its binders and uses the names of its
-- annotated free variables. NOTE(review): the result is presumably
-- dependencies-first, per the SCC order 'depAnal' produces; confirm.
-- Any cyclic strongly-connected component among the top-level bindings
-- is a panic.
depSort :: [(StgTopBinding, FVs)] -> [(StgTopBinding, FVs)]
depSort annotated = concatMap from_scc (depAnal defs uses annotated)
  where
    defs, uses :: (StgTopBinding, FVs) -> [Name]

    defs (top_bind, _) = map idName (bindersOfTop top_bind)

    -- String literals refer to nothing. The order of the list produced by
    -- 'nonDetEltsUniqSet' is unspecified.
    uses (top_bind, fvs) = case top_bind of
      StgTopStringLit{} -> []
      StgTopLifted{}    -> map idName (nonDetEltsUniqSet fvs)

    from_scc scc = case scc of
      AcyclicSCC single -> [single]
      CyclicSCC group ->
        pprPanic "depSortStgBinds" $
          text "Found cyclic SCC:"
            $$ ppr (map (first (pprStgTopBinding panicStgPprOpts)) group)