{-# LANGUAGE PatternSynonyms #-}
module GHC.Core.Opt.Exitify ( exitifyProgram ) where
import GHC.Prelude
import GHC.Types.Var
import GHC.Types.Id
import GHC.Types.Id.Info
import GHC.Core
import GHC.Core.Utils
import GHC.Utils.Monad.State
import GHC.Types.Unique
import GHC.Types.Var.Set
import GHC.Types.Var.Env
import GHC.Core.FVs
import GHC.Data.FastString
import GHC.Core.Type
import GHC.Utils.Misc( mapSnd )
import Data.Bifunctor
import Control.Monad
-- | Traverse a whole program and exitify every recursive group that binds
-- join points (a @joinrec@), wherever it occurs.
--
-- Top-level bindings themselves are left in place; only their right-hand
-- sides are walked. Whenever the walk meets @Let (Rec pairs) body@ where at
-- least one binder of @pairs@ is a join point, the (already traversed) group
-- is handed to 'exitifyRec'. 'exitifyRec' may float exit paths out of the group
-- as new non-recursive join points bound just outside it.
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
  where
    goTopLvl :: CoreBind -> CoreBind
    goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
    goTopLvl (Rec pairs)  = Rec (map (second (go in_scope_toplvl)) pairs)

    -- Every top-level binder is in scope in every right-hand side.
    in_scope_toplvl :: InScopeSet
    in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds

    -- Walk an expression while tracking the binders in scope.
    -- 'exitifyRec' uses that set to choose fresh names for exit join points.
    go :: InScopeSet -> CoreExpr -> CoreExpr
    go _ e@(Var {})      = e
    go _ e@(Lit {})      = e
    go _ e@(Type {})     = e
    go _ e@(Coercion {}) = e
    go in_scope (Cast e' c) = Cast (go in_scope e') c
    go in_scope (Tick t e') = Tick t (go in_scope e')
    go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)

    go in_scope (Lam v e')
      = Lam v (go in_scope' e')
      where
        in_scope' = in_scope `extendInScopeSet` v

    go in_scope (Case scrut bndr ty alts)
      = Case (go in_scope scrut) bndr ty (map go_alt alts)
      where
        -- The case binder scopes over every alternative.
        in_scope1 = in_scope `extendInScopeSet` bndr
        go_alt (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
          where
            in_scope' = in_scope1 `extendInScopeSetList` pats

    go in_scope (Let (NonRec bndr rhs) body)
      -- A non-recursive binder is not in scope in its own right-hand side.
      = Let (NonRec bndr (go in_scope rhs)) (go in_scope' body)
      where
        in_scope' = in_scope `extendInScopeSet` bndr

    go in_scope (Let (Rec pairs) body)
      | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
      | otherwise   = Let (Rec pairs') body'
      where
        is_join_rec = any (isJoinId . fst) pairs
        -- Recursive binders scope over their own right-hand sides and the body.
        in_scope'   = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
        -- Inner groups are processed before this one is given to exitifyRec.
        pairs'      = mapSnd (go in_scope') pairs
        body'       = go in_scope' body
-- | State threaded through exitification of one recursive group.
-- It holds the exit join points floated out so far, each paired with its
-- right-hand side. New exits are consed onto the front, so the most
-- recently created one comes first.
type ExitifyM = State [(JoinId, CoreExpr)]

-- | Exitify one recursive join-point group.
--
-- Walks the tail positions of every right-hand side (case alternatives and
-- the bodies and right-hand sides of local join points). A tail expression
-- whose free variables include none of the group's binders is an exit path.
-- It is lifted out as a fresh non-recursive join point and replaced by a
-- jump to it, provided that:
--
--   * it is not already a jump to a join point whose arguments are all
--     locally bound variables,
--   * it mentions a local identifier ('isLocalId') that is not bound inside
--     the group, and
--   * it would not need to abstract over a local join point.
--
-- The result lists the new exits as 'NonRec' bindings, followed by the
-- rewritten group as a single 'Rec'.
--
-- The 'InScopeSet' should describe everything in scope at the group; it is
-- extended with the locally bound variables and used to pick fresh exit names.
exitifyRec :: InScopeSet -> [(Var, CoreExpr)] -> [CoreBind]
exitifyRec in_scope pairs
  = [ NonRec xid rhs | (xid, rhs) <- exits ] ++ [Rec pairs']
  where
    -- Free-variable annotations tell us which sub-expressions call back
    -- into the group.
    ann_pairs = map (second freeVars) pairs

    recursive_calls = mkVarSet (map fst pairs)

    (pairs', exits) = (`runState` []) $
      forM ann_pairs $ \(x, rhs) -> do
        -- Step past the join point's own parameters; they count as captured.
        let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
        body' <- go args body
        return (x, mkLams args body')

    ---------------------
    -- Rewrite an expression in tail position.
    go :: [Var]            -- variables bound between the group and here,
                           -- in binding order
       -> CoreExprWithFVs  -- expression in tail position
       -> ExitifyM CoreExpr

    -- Whatever its shape, an expression that makes no recursive call is a
    -- candidate exit.
    go captured ann_e
      | let fvs = dVarSetToVarSet (freeVarsOf ann_e)
      , disjointVarSet fvs recursive_calls
      = go_exit captured (deAnnotate ann_e) fvs

    -- Case alternatives are in tail position.
    go captured (_, AnnCase scrut bndr ty alts) = do
      alts' <- forM alts $ \(dc, pats, rhs) -> do
        rhs' <- go (captured ++ [bndr] ++ pats) rhs
        return (dc, pats, rhs')
      return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
      -- Non-recursive join point: its body and the let body are tail positions.
      | AnnNonRec j rhs <- ann_bind
      , Just join_arity <- isJoinId_maybe j
      = do let (params, join_body) = collectNAnnBndrs join_arity rhs
           join_body' <- go (captured ++ params) join_body
           let rhs' = mkLams params join_body'
           body' <- go (captured ++ [j]) body
           return $ Let (NonRec j rhs') body'

      -- Recursive join points: every body and the let body are tail positions.
      -- Matching on a non-empty list keeps this guard total.
      | AnnRec join_pairs@((j0, _) : _) <- ann_bind
      , isJoinId j0
      = do let js = map fst join_pairs
           join_pairs' <- forM join_pairs $ \(j, rhs) -> do
             let (params, join_body) = collectNAnnBndrs (idJoinArity j) rhs
             join_body' <- go (captured ++ js ++ params) join_body
             return (j, mkLams params join_body')
           body' <- go (captured ++ js) body
           return $ Let (Rec join_pairs') body'

      -- Ordinary let: only the body is in tail position.
      | otherwise
      = do body' <- go (captured ++ bindersOf bind) body
           return $ Let bind body'
      where
        bind = deAnnBind ann_bind

    -- No tail sub-expressions left to inspect.
    go _ ann_e = return (deAnnotate ann_e)

    ---------------------
    -- Decide whether a recursion-free tail expression is floated out.
    go_exit :: [Var]      -- variables bound between the group and here
            -> CoreExpr   -- tail expression with no recursive calls
            -> VarSet     -- its free variables
            -> ExitifyM CoreExpr
    go_exit captured e fvs
      -- Already a jump to a join point applied only to captured variables,
      -- which is the shape an exit jump would have: leave it alone.
      | (Var f, args) <- collectArgs e
      , isJoinId f
      , all isCapturedVarArg args
      = return e

      | not is_interesting
      = return e

      -- We do not abstract over join points, so this cannot be floated.
      | captures_join_points
      = return e

      | otherwise
      = do let rhs   = mkLams abs_vars e
               avoid = in_scope `extendInScopeSetList` captured
           v <- addExit avoid (length abs_vars) rhs
           return (mkVarApps (Var v) abs_vars)
      where
        captured_set = mkVarSet captured

        isCapturedVarArg (Var v) = v `elemVarSet` captured_set
        isCapturedVarArg _       = False

        -- Mentions some local identifier bound outside the group.
        is_interesting = anyVarSet isLocalId (fvs `minusVarSet` captured_set)

        -- The captured variables free in e: each at most once, in the order
        -- they appear in `captured`.
        abs_vars = snd (foldr pick (fvs, []) captured)
          where
            pick v (fvs', acc)
              | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
              | otherwise           = (fvs', acc)

        -- Abstracted Ids become lambda binders of the exit. Reset their
        -- IdInfo, since info attached at the original binding site may not
        -- hold there.
        zap v | isId v    = setIdInfo v vanillaIdInfo
              | otherwise = v

        captures_join_points = any isJoinId abs_vars
-- | Create a fresh exit join point with the given type and join arity.
--
-- The binder is named "exit". Its unique is chosen with 'uniqAway' so that
-- it clashes neither with the supplied 'InScopeSet' nor with any exit
-- already recorded in the 'ExitifyM' state. The new binder is not added to
-- the state here; 'addExit' does that.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = do
    taken <- gets (map fst)
    -- The template is added as well; presumably this makes uniqAway always
    -- pick a new unique rather than return the template unchanged.
    let avoid = (in_scope `extendInScopeSetList` taken)
                  `extendInScopeSet` exit_id_tmpl
    return (uniqAway avoid exit_id_tmpl)
  where
    exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique Many ty
                     `asJoinId` join_arity
-- | Record a new exit join point and return its binder.
--
-- A fresh binder of the right-hand side's type is made with 'mkExitJoinId'.
-- It is consed, together with @rhs@, onto the front of the 'ExitifyM' state.
addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope join_arity rhs = do
    v <- mkExitJoinId in_scope (exprType rhs) join_arity
    modify ((v, rhs) :)
    return v