{-# LANGUAGE CPP #-}

module GHC.Stg.DepAnal (depSortStgPgm) where

import GHC.Prelude

import GHC.Stg.Syntax
import GHC.Types.Id
import GHC.Types.Name (Name, nameIsLocalOrFrom)
import GHC.Types.Name.Env
import GHC.Utils.Outputable
import GHC.Types.Unique.Set (nonDetEltsUniqSet)
import GHC.Types.Var.Set
import GHC.Unit.Module (Module)

import Data.Graph (SCC (..))
import Data.Bifunctor (first)

--------------------------------------------------------------------------------
-- * Dependency analysis

-- | Set of bound variables: ids brought into scope by an enclosing binder
-- (a let/letrec binder, a closure parameter, a case binder, or a binder of a
-- case alternative). Occurrences of these ids are not reported as free.
type BVs = VarSet

-- | Set of free variables. As computed by 'annTopBindingsDeps', it only
-- contains ids whose names are local to, or from, the module being compiled;
-- imported ids are left out.
type FVs = VarSet

-- | Dependency analysis on STG terms.
--
-- Dependencies of a binding are the free variables of the binding. For a
-- recursive group we return a single set of free variables: the union of the
-- dependencies of all bindings in the group, excluding the group's own
-- binders.
--
-- Implementation: pass bound variables (BVs) to recursive calls, get free
-- variables (FVs) back. Imported ids are not included in the result: they
-- cannot affect the ordering of this module's bindings, and skipping them
-- improves performance.
--
annTopBindingsDeps :: Module -> [StgTopBinding] -> [(StgTopBinding, FVs)]
annTopBindingsDeps this_mod bs = zip bs (map top_bind bs)
  where
    -- Free variables of one top-level binding. String literals mention no
    -- variables; a lifted binding is analysed with nothing bound yet.
    top_bind :: StgTopBinding -> FVs
    top_bind StgTopStringLit{} =
      emptyVarSet

    top_bind (StgTopLifted bind) =
      binding emptyVarSet bind

    -- Free variables of a binding group. In a recursive group every binder
    -- is in scope in every right-hand side, so references between members of
    -- the same group are not reported.
    binding :: BVs -> StgBinding -> FVs
    binding bounds (StgNonRec _ r) =
      rhs bounds r
    binding bounds (StgRec pairs) =
      unionVarSets (map (rhs bounds' . snd) pairs)
      where
        bounds' = extendVarSetList bounds (map fst pairs)

    -- Free variables of a right-hand side; closure parameters are bound in
    -- the closure body.
    rhs :: BVs -> StgRhs -> FVs
    rhs bounds (StgRhsClosure _ _ _ as e) =
      expr (extendVarSetList bounds as) e

    rhs bounds (StgRhsCon _ _ as) =
      args bounds as

    -- An occurrence is reported only when it is not locally bound and its
    -- name is local to, or from, this module (imported ids are dropped).
    var :: BVs -> Id -> FVs
    var bounds v
      | not (elemVarSet v bounds)
      , nameIsLocalOrFrom this_mod (idName v)
      = unitVarSet v
      | otherwise
      = emptyVarSet

    arg :: BVs -> StgArg -> FVs
    arg bounds (StgVarArg v) = var bounds v
    arg _ StgLitArg{} = emptyVarSet

    args :: BVs -> [StgArg] -> FVs
    args bounds as = unionVarSets (map (arg bounds) as)

    expr :: BVs -> StgExpr -> FVs
    expr bounds (StgApp f as) =
      var bounds f `unionVarSet` args bounds as

    expr _ StgLit{} =
      emptyVarSet

    expr bounds (StgConApp _ as _) =
      args bounds as
    expr bounds (StgOpApp _ as _) =
      args bounds as
    -- Lambdas are not expected in the STG this pass receives.
    expr _ lam@StgLam{} =
      pprPanic "annTopBindingsDeps" (text "Found lambda:" $$ pprStgExpr panicStgPprOpts lam)
    -- The case binder is in scope in the alternatives, not in the scrutinee.
    expr bounds (StgCase scrut scrut_bndr _ as) =
      expr bounds scrut `unionVarSet`
        alts (extendVarSet bounds scrut_bndr) as
    expr bounds (StgLet _ bind body) =
      let_expr bounds bind body
    expr bounds (StgLetNoEscape _ bind body) =
      let_expr bounds bind body

    expr bounds (StgTick _ e) =
      expr bounds e

    -- Shared by 'StgLet' and 'StgLetNoEscape': free variables of the binding
    -- group plus those of the body, where the group's binders are in scope.
    let_expr :: BVs -> StgBinding -> StgExpr -> FVs
    let_expr bounds bind body =
      binding bounds bind `unionVarSet`
        expr (extendVarSetList bounds (bindersOf bind)) body

    alts :: BVs -> [StgAlt] -> FVs
    alts bounds = unionVarSets . map (alt bounds)

    -- Pattern binders of an alternative are in scope in its right-hand side.
    alt :: BVs -> StgAlt -> FVs
    alt bounds (_, bndrs, e) =
      expr (extendVarSetList bounds bndrs) e

--------------------------------------------------------------------------------
-- * Dependency sorting

-- | Dependency-sort an STG program so that dependencies come before uses.
depSortStgPgm :: Module -> [StgTopBinding] -> [StgTopBinding]
depSortStgPgm this_mod binds =
    {-# SCC "STG.depSort" #-}
    -- Annotate each binding with its free variables, sort on those, then
    -- drop the annotations again.
    fst <$> depSort (annTopBindingsDeps this_mod binds)

-- | Sort free-variable-annotated STG bindings so that dependencies come before
-- uses.
depSort :: [(StgTopBinding, FVs)] -> [(StgTopBinding, FVs)]
depSort annotated = concatMap flatten (depAnal defined_names used_names annotated)
  where
    defined_names, used_names :: (StgTopBinding, FVs) -> [Name]

    -- TODO (osa): I'm unhappy about two things in this code:
    --
    --     * Why do we need Name instead of Id for uses and dependencies?
    --     * Why do we need a [Name] instead of `Set Name`? Surely depAnal
    --       doesn't need any ordering.

    -- Names a binding refers to. String literals refer to nothing; for a
    -- lifted binding the free-variable set is enumerated with
    -- nonDetEltsUniqSet, so the order of this list is unspecified.
    used_names (top_bind, fvs) =
      case top_bind of
        StgTopStringLit{} -> []
        StgTopLifted{}    -> idName <$> nonDetEltsUniqSet fvs

    -- Names a binding defines.
    defined_names (top_bind, _) = idName <$> bindersOfTop top_bind

    -- An acyclic component holds exactly one binding. A cycle between
    -- top-level bindings is reported as a compiler panic.
    flatten :: SCC (StgTopBinding, FVs) -> [(StgTopBinding, FVs)]
    flatten (AcyclicSCC node) = [node]
    flatten (CyclicSCC nodes) =
      pprPanic "depSortStgBinds" $
        text "Found cyclic SCC:"
          $$ ppr [ first (pprStgTopBinding panicStgPprOpts) node | node <- nodes ]