{-|
  Prepare the STG for bytecode generation:

   - Ensure that all breakpoints are directly under
        a let-binding, introducing a new binding for
        those that aren't already.

   - Protect not-necessarily-lifted join points; see
        Note [Not-necessarily-lifted join points]

 -}

module GHC.Stg.BcPrep ( bcPrep ) where

import GHC.Prelude

import GHC.Types.Id.Make
import GHC.Types.Id
import GHC.Core.Type
import GHC.Builtin.Types ( unboxedUnitTy )
import GHC.Builtin.Types.Prim
import GHC.Types.Unique
import GHC.Data.FastString
import GHC.Utils.Panic.Plain
import GHC.Types.Tickish
import GHC.Types.Unique.Supply
import qualified GHC.Types.CostCentre as CC
import GHC.Stg.Syntax
import GHC.Utils.Monad.State.Strict

-- | State threaded through the bytecode-preparation pass.
data BcPrepM_State = BcPrepM_State
  { prepUniqSupply :: !UniqSupply
    -- ^ supply of uniques for the fresh binders this pass creates
  }

-- | The preparation monad: a strict state monad over 'BcPrepM_State'.
type BcPrepM a = State BcPrepM_State a

-- | Prepare a right-hand side.
--
-- A closure's body is prepared recursively. A 'Breakpoint' tick sitting
-- directly under the closure is kept where it is: the closure itself already
-- serves as the binding the breakpoint needs, so only the expression beneath
-- the tick is prepared. Constructor applications are returned unchanged.
bcPrepRHS :: StgRhs -> BcPrepM StgRhs
-- explicitly match all constructors so we get a warning if we miss any
bcPrepRHS rhs = case rhs of
    StgRhsClosure ext ccs upd_flag bndrs body res_ty ->
      (\body' -> StgRhsClosure ext ccs upd_flag bndrs body' res_ty)
        <$> prepBody body
    StgRhsCon{} ->
      pure rhs
  where
    -- No new binding is introduced for a breakpoint directly under a closure.
    prepBody :: StgExpr -> BcPrepM StgExpr
    prepBody (StgTick tick@Breakpoint{} inner) = StgTick tick <$> bcPrepExpr inner
    prepBody other                             = bcPrepExpr other

-- | Prepare an expression for bytecode generation.
--
-- * A 'Breakpoint' tick reaching this function gets its own let-binding.
--   ('bcPrepRHS' handles a breakpoint sitting directly under a closure itself,
--   without sending the tick here.) The prepared ticked expression becomes the
--   body of a fresh re-entrant closure, which the let body then applies:
--
--     - if the ticked expression's type has a lifted kind, the closure takes
--       no arguments and is entered with no arguments;
--     - otherwise the closure takes 'voidArgId' and is applied to
--       'realWorldPrimId', so that binding it does not evaluate it.
--
-- * Let-no-escape bindings are rewritten into ordinary 'StgLet's.
--
-- * Nullary occurrences of not-necessarily-lifted join points are applied to
--   'voidPrimId'; see Note [Not-necessarily-lifted join points], step 3.
--
-- * Everything else is traversed structurally.
bcPrepExpr :: StgExpr -> BcPrepM StgExpr
-- explicitly match all constructors so we get a warning if we miss any
bcPrepExpr (StgTick bp@(Breakpoint tick_ty _ _) rhs)
  | isLiftedTypeKind (typeKind tick_ty)
    -- Lifted: a nullary closure of the ticked expression's own type.
  = bindBreakpoint tick_ty [] []
  | otherwise
    -- Not (necessarily) lifted: wrap the body in a function of one void
    -- argument and apply it at the use site.
  = bindBreakpoint (mkVisFunTyMany realWorldStatePrimTy tick_ty)
                   [voidArgId]
                   [StgVarArg realWorldPrimId]
  where
    -- Bind the prepared, still-ticked expression to a fresh Id of type
    -- @bndr_ty@ as a re-entrant closure with binders @lam_args@ and result
    -- type @tick_ty@, then apply that Id to @app_args@ in the let body.
    -- NOTE(review): in the unlifted case the Id's type takes a State# argument
    -- (realWorldStatePrimTy) while the lambda binder is voidArgId; presumably
    -- fine because both are void-represented, but confirm.
    bindBreakpoint :: Type -> [Id] -> [StgArg] -> BcPrepM StgExpr
    bindBreakpoint bndr_ty lam_args app_args = do
      -- The fresh Id is allocated before the body is prepared; this fixes the
      -- order in which uniques are drawn from the supply.
      bndr <- newId bndr_ty
      rhs' <- bcPrepExpr rhs
      let closure = StgRhsClosure noExtFieldSilent
                                  CC.dontCareCCS
                                  ReEntrant
                                  lam_args
                                  (StgTick bp rhs')
                                  tick_ty
      pure (StgLet noExtFieldSilent (StgNonRec bndr closure) (StgApp bndr app_args))
bcPrepExpr (StgTick tick rhs) =
  StgTick tick <$> bcPrepExpr rhs
bcPrepExpr (StgLet xlet bnds expr) =
  StgLet xlet <$> bcPrepBind bnds
              <*> bcPrepExpr expr
bcPrepExpr (StgLetNoEscape xlne bnds expr) =
  -- Rewritten to an ordinary let: this bytecode generator treats join points
  -- as ordinary variables (see Note [Not-necessarily-lifted join points]).
  StgLet xlne <$> bcPrepBind bnds
              <*> bcPrepExpr expr
bcPrepExpr (StgCase expr bndr alt_type alts) =
  StgCase <$> bcPrepExpr expr
          <*> pure bndr
          <*> pure alt_type
          <*> mapM bcPrepAlt alts
bcPrepExpr lit@StgLit{} = pure lit
-- See Note [Not-necessarily-lifted join points], step 3.
bcPrepExpr (StgApp x [])
  | isNNLJoinPoint x = pure $
      StgApp (protectNNLJoinPointId x) [StgVarArg voidPrimId]
bcPrepExpr app@StgApp{} = pure app
bcPrepExpr app@StgConApp{} = pure app
bcPrepExpr app@StgOpApp{} = pure app

-- | Prepare the right-hand side of a case alternative. The constructor and
-- the pattern binders are kept unchanged.
bcPrepAlt :: StgAlt -> BcPrepM StgAlt
bcPrepAlt (GenStgAlt con bndrs rhs) = do
  rhs' <- bcPrepExpr rhs
  pure (GenStgAlt con bndrs rhs')

-- | Prepare a binding group. Each (binder, RHS) pair is first passed through
-- 'bcPrepSingleBind' (which may protect a not-necessarily-lifted join point),
-- and then its RHS is prepared with 'bcPrepRHS'.
bcPrepBind :: StgBinding -> BcPrepM StgBinding
-- explicitly match all constructors so we get a warning if we miss any
bcPrepBind bind = case bind of
    StgNonRec bndr rhs -> uncurry StgNonRec <$> prepPair (bndr, rhs)
    StgRec pairs       -> StgRec <$> mapM prepPair pairs
  where
    -- Protect the binder if needed, then prepare its right-hand side.
    prepPair :: (Id, StgRhs) -> BcPrepM (Id, StgRhs)
    prepPair pair =
      let (bndr', rhs') = bcPrepSingleBind pair
      in  (,) bndr' <$> bcPrepRHS rhs'

-- | If the binder is a not-necessarily-lifted join point bound to a closure,
-- protect it. The Id's type gains a leading @(# #) ->@ (via
-- 'protectNNLJoinPointId'), and the closure gains a trailing 'voidArgId'
-- binder. Any other binding is returned as-is.
-- See Note [Not-necessarily-lifted join points], step 2.
bcPrepSingleBind :: (Id, StgRhs) -> (Id, StgRhs)
bcPrepSingleBind bnd@(x, rhs)
  | StgRhsClosure ext cc upd_flag args body typ <- rhs
  , isNNLJoinPoint x
  = ( protectNNLJoinPointId x
    , StgRhsClosure ext cc upd_flag (args ++ [voidArgId]) body typ )
  | otherwise
  = bnd

-- | Prepare one top-level binding. Top-level string literals need no
-- preparation and are returned unchanged.
bcPrepTopLvl :: StgTopBinding -> BcPrepM StgTopBinding
bcPrepTopLvl top = case top of
  StgTopLifted bnd  -> StgTopLifted <$> bcPrepBind bnd
  StgTopStringLit{} -> pure top

-- | Entry point: prepare all top-level bindings for bytecode generation.
-- The given supply provides the uniques for any fresh binders the pass
-- introduces.
bcPrep :: UniqSupply -> [InStgTopBinding] -> [OutStgTopBinding]
bcPrep us bnds = evalState prepared initialState
  where
    prepared     = mapM bcPrepTopLvl bnds
    initialState = BcPrepM_State { prepUniqSupply = us }

-- | Is this Id a not-necessarily-lifted join point? That is, is it a join
-- point whose type might be unlifted?
-- See Note [Not-necessarily-lifted join points], step 1
isNNLJoinPoint :: Id -> Bool
isNNLJoinPoint x
  | isJoinId x = mightBeUnliftedType (idType x)
  | otherwise  = False

-- | Update an Id's type to take a (# #) argument, so that @ty@ becomes
-- @(# #) -> ty@. The Id's multiplicity is left untouched.
-- Precondition (asserted): the Id is a not-necessarily-lifted join point.
-- See Note [Not-necessarily-lifted join points]
protectNNLJoinPointId :: Id -> Id
protectNNLJoinPointId x
  = assert (isNNLJoinPoint x) $
    updateIdTypeButNotMult (mkVisFunTyMany unboxedUnitTy) x

-- | Take a fresh 'Unique' from the supply held in the pass state.
newUnique :: BcPrepM Unique
newUnique = state advance
  where
    -- Split one unique off the supply and store the remainder back.
    advance st = case takeUniqFromSupply (prepUniqSupply st) of
      (uniq, rest) -> (uniq, st { prepUniqSupply = rest })

-- | Make a fresh system-local Id of the given type. It is named with
-- 'prepFS', uses a new unique, and has multiplicity 'ManyTy'.
newId :: Type -> BcPrepM Id
newId ty = fresh <$> newUnique
  where
    fresh uniq = mkSysLocal prepFS uniq ManyTy ty

-- | Name given to every Id introduced by this pass (@bcprep@).
prepFS :: FastString
prepFS = fsLit "bcprep"

{-

Note [Not-necessarily-lifted join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A join point variable is essentially a goto-label: it is, for example,
never used as an argument to another function, and it is called only
in tail position. See Note [Join points] and Note [Invariants on join points],
both in GHC.Core. Because join points do not compile to true, red-blooded
variables (with, e.g., registers allocated to them), they are allowed
to be representation-polymorphic.
(See invariant #6 in Note [Invariants on join points] in GHC.Core.)

However, in this byte-code generator, join points *are* treated just as
ordinary variables. There is no check whether a binding is for a join point
or not; they are all treated uniformly. (Perhaps there is a missed optimization
opportunity here, but that is beyond the scope of my (Richard E's) Thursday.)

We thus must have *some* strategy for dealing with representation-polymorphic
and unlifted join points. Representation-polymorphic variables are generally
not allowed (though representation-polymorphic join points *are*; see
Note [Invariants on join points] in GHC.Core, point 6), and we don't wish to
evaluate unlifted join points eagerly.
The questionable join points are *not-necessarily-lifted join points*
(NNLJPs). (Not having such a strategy led to #16509, which panicked in the
isUnliftedType check in the AnnVar case of schemeE.) Here is the strategy:

1. Detect NNLJPs. This is done in isNNLJoinPoint.

2. When binding an NNLJP, add a `\ (_ :: (# #)) ->` to its RHS, and modify the
   type to tack on a `(# #) ->`.
   Note that functions are never representation-polymorphic, so this
   transformation changes an NNLJP to a non-representation-polymorphic
   join point. This is done in bcPrepSingleBind.

3. At an occurrence of an NNLJP, add an application to void# (called voidPrimId),
   being careful to note the new type of the NNLJP. This is done in the StgApp
   case of bcPrepExpr, with help from protectNNLJoinPointId.

Here is an example. Suppose we have

  f = \(r :: RuntimeRep) (a :: TYPE r) (x :: T).
      join j :: a
           j = error @r @a "bloop"
      in case x of
           A -> j
           B -> j
           C -> error @r @a "blurp"

Our plan is to behave as if the code were

  f = \(r :: RuntimeRep) (a :: TYPE r) (x :: T).
      let j :: ((# #) -> a)
          j = \ _ -> error @r @a "bloop"
      in case x of
           A -> j void#
           B -> j void#
           C -> error @r @a "blurp"

It's a bit hacky, but it works well in practice and is local. I suspect the
Right Fix is to take advantage of join points as goto-labels.

-}