{-# LANGUAGE TypeFamilies #-}

{- |
Non-global free variable analysis on STG terms. This pass annotates
non-top-level closure bindings with captured variables. Global variables are not
captured. For example, in a top-level binding like (pseudo-STG)

    f = \[x,y] .
      let g = \[p] . reverse (x ++ p)
      in g y

In g, `reverse` and `(++)` are global variables so they're not considered free.
`p` is an argument, so `x` is the only actual free variable here. The annotated
version is thus:

    f = \[x,y] .
      let g = [x] \[p] . reverse (x ++ p)
      in g y

Note that non-top-level recursive bindings are also considered free within the
group:

    map = {} \r [f xs0]
      let {
        Rec {
          go = {f, go} \r [xs1]
            case xs1 of {
              [] -> [] [];
              : x xs2 ->
                  let { xs' = {go, xs2} \u [] go xs2; } in
                  let { x' = {f, x} \u [] f x; } in
                  : [x' xs'];
            };
        end Rec }
      } in go xs0;

Here `go` is free in its own RHS, so it appears in its own captured set `{f, go}`.

Top-level closure bindings never capture variables as all of their free
variables are global.
-}
module GHC.Stg.FVs (
    annTopBindingsFreeVars,
    annBindingFreeVars
  ) where

import GHC.Prelude

import GHC.Stg.Syntax
import GHC.Types.Id
import GHC.Types.Var.Set
import GHC.Types.Tickish ( GenTickish(Breakpoint) )
import GHC.Utils.Misc

import Data.Maybe ( mapMaybe )

-- | The analysis environment. 'locals' holds the non-top-level,
-- non-imported binders in scope at the current point of the traversal;
-- see Note [Tracking local binders].
newtype Env = Env { locals :: IdSet }

-- | The starting environment: no local binders are in scope.
emptyEnv :: Env
emptyEnv = Env { locals = emptyVarSet }

-- | Brings the given binders into scope as local variables.
addLocals :: [Id] -> Env -> Env
addLocals bndrs (Env in_scope) = Env (extendVarSetList in_scope bndrs)

-- | Annotates a top-level STG binding group with its free variables.
--
-- String literals are passed through unchanged; lifted bindings are
-- annotated with 'annBindingFreeVars'.
annTopBindingsFreeVars :: [StgTopBinding] -> [CgStgTopBinding]
annTopBindingsFreeVars top_binds = [ annotate top_bind | top_bind <- top_binds ]
  where
    annotate :: StgTopBinding -> CgStgTopBinding
    annotate top_bind = case top_bind of
      StgTopStringLit lit_id str -> StgTopStringLit lit_id str
      StgTopLifted bind          -> StgTopLifted (annBindingFreeVars bind)

-- | Annotates an STG binding with its free variables.
--
-- The analysis starts from an empty environment (no locals in scope) and
-- an empty set of body free variables. Only the annotated binding is
-- returned; the binding's own free-variable set is discarded.
annBindingFreeVars :: StgBinding -> CgStgBinding
annBindingFreeVars bind = annotated
  where
    (annotated, _fvs) = binding emptyEnv emptyDVarSet bind

-- | The binders introduced by a (possibly recursive) binding group.
boundIds :: StgBinding -> [Id]
boundIds bind = case bind of
  StgNonRec bndr _ -> [bndr]
  StgRec pairs     -> [ bndr | (bndr, _) <- pairs ]

-- Note [Tracking local binders]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- 'locals' contains non-toplevel, non-imported binders.
-- We maintain the set in 'expr', 'alt' and 'rhs', which are the only
-- places where new local binders are introduced.
-- Why do it there rather than in 'binding'? Two reasons:
--
--   1. We call 'binding' from 'annTopBindingsFreeVars', which would
--      add top-level bindings to the 'locals' set.
--   2. In the let(-no-escape) case, we need to extend the environment
--      prior to analysing the body, but we also need the fvs from the
--      body to analyse the RHSs. No way to do this without some
--      knot-tying.

-- | This makes sure that only local, non-global free vars make it into the set.
--
-- Ids that are not in 'locals' (i.e. top-level or imported ids) are
-- dropped; the remaining ones are collected into a deterministic set.
mkFreeVarSet :: Env -> [Id] -> DIdSet
mkFreeVarSet env ids = mkDVarSet [ v | v <- ids, v `elemVarSet` (locals env) ]

-- | The local free variables of a list of arguments. Only 'StgVarArg'
-- arguments mention variables; all other arguments contribute nothing.
args :: Env -> [StgArg] -> DIdSet
args env as = mkFreeVarSet env [ occ | StgVarArg occ <- as ]

-- | Annotates a binding group and computes the free variables of the
-- whole construct: the free variables of the RHSs together with
-- @body_fv@ (the free variables of whatever the group scopes over),
-- minus the binders that are in scope there.
--
-- @env@ is used as-is for the RHSs. For local lets, the caller
-- (@go_bind@ in 'expr') has already added the group's binders. At top
-- level they are deliberately not added. See Note [Tracking local binders].
binding :: Env -> DIdSet -> StgBinding -> (CgStgBinding, DIdSet)
binding env body_fv bind = case bind of
  StgNonRec bndr r ->
    let (r', rhs_fvs) = rhs env r
        -- A non-recursive binder scopes over the body only, so it is
        -- removed from the body's free variables, not from the RHS's.
        fvs = unionDVarSet (delDVarSet body_fv bndr) rhs_fvs
    in (StgNonRec bndr r', fvs)
  StgRec pairs ->
    let bndrs = map fst pairs
        (rhss, rhs_fvss) = mapAndUnzip (\(_, r) -> rhs env r) pairs
        -- Recursive binders scope over every RHS and over the body.
        fvs = delDVarSetList (unionDVarSets (body_fv : rhs_fvss)) bndrs
    in (StgRec (zip bndrs rhss), fvs)

-- | Annotates an expression and returns it together with its local free
-- variables. These are the variables in 'locals' (see 'mkFreeVarSet')
-- that the expression mentions but does not bind itself.
expr :: Env -> StgExpr -> (CgStgExpr, DIdSet)
expr env = go
  where
    go :: StgExpr -> (CgStgExpr, DIdSet)
    go (StgApp occ as)
      = (StgApp occ as, unionDVarSet (args env as) (mkFreeVarSet env [occ]))
    go (StgLit lit) = (StgLit lit, emptyDVarSet)
    go (StgConApp dc n as tys) = (StgConApp dc n as tys, args env as)
    go (StgOpApp op as ty) = (StgOpApp op as ty, args env as)
    go (StgCase scrut bndr ty alts) = (StgCase scrut' bndr ty alts', fvs)
      where
        -- The case binder is not in scope in the scrutinee, so the
        -- scrutinee is analysed in the unextended environment.
        (scrut', scrut_fvs) = go scrut
        -- See Note [Tracking local binders]
        (alts', alt_fvss) = mapAndUnzip (alt (addLocals [bndr] env)) alts
        alt_fvs = unionDVarSets alt_fvss
        fvs = delDVarSet (unionDVarSet scrut_fvs alt_fvs) bndr
    go (StgLet ext bind body) = go_bind (StgLet ext) bind body
    go (StgLetNoEscape ext bind body) = go_bind (StgLetNoEscape ext) bind body
    go (StgTick tick e) = (StgTick tick e', unionDVarSet tick_fvs body_fvs)
      where
        (e', body_fvs) = go e
        -- The ids listed in a breakpoint count as occurrences. Like every
        -- other occurrence, they are filtered through 'mkFreeVarSet', so
        -- global ids never end up in a closure's captured set.
        tick_fvs = case tick of
          Breakpoint _ _ ids -> mkFreeVarSet env ids
          _                  -> emptyDVarSet

    -- Shared handling for let and let-no-escape. @dc@ rebuilds the
    -- annotated construct from the annotated binding and body.
    go_bind :: (CgStgBinding -> CgStgExpr -> CgStgExpr)
            -> StgBinding -> StgExpr -> (CgStgExpr, DIdSet)
    go_bind dc bind body = (dc bind' body', fvs)
      where
        -- See Note [Tracking local binders]
        env' = addLocals (boundIds bind) env
        (body', body_fvs) = expr env' body
        (bind', fvs) = binding env' body_fvs bind

-- | Annotates a right-hand side and returns its local free variables.
--
-- A closure stores its free variables (its own lambda binders excluded)
-- in the first field of the annotated 'StgRhsClosure'. A constructor
-- application just reports the free variables of its arguments.
rhs :: Env -> StgRhs -> (CgStgRhs, DIdSet)
rhs env r = case r of
  StgRhsClosure _ ccs uf bndrs body ->
    -- See Note [Tracking local binders]
    let (body', body_fvs) = expr (addLocals bndrs env) body
        captured = delDVarSetList body_fvs bndrs
    in (StgRhsClosure captured ccs uf bndrs body', captured)
  StgRhsCon ccs dc mu ts as ->
    (StgRhsCon ccs dc mu ts as, args env as)

-- | Annotates a case alternative. The pattern binders scope over the
-- alternative's right-hand side only, so they are removed from its free
-- variables.
alt :: Env -> StgAlt -> (CgStgAlt, DIdSet)
alt env (con, bndrs, e) = ((con, bndrs, e'), delDVarSetList e_fvs bndrs)
  where
    -- See Note [Tracking local binders]
    (e', e_fvs) = expr (addLocals bndrs env) e