{-# LANGUAGE CPP #-}

{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998


************************************************************************

               Static Argument Transformation pass

************************************************************************

May be seen as removing invariants from loops:
Arguments of recursive functions that do not change in recursive
calls are removed from the recursion, which is done locally
and only passes the arguments which effectively change.

Example:
map = /\ ab -> \f -> \xs -> case xs of
                 []       -> []
                 (a:b) -> f a : map f b

as map is recursively called with the same argument f (unmodified)
we transform it to

map = /\ ab -> \f -> \xs -> let map' ys = case ys of
                       []     -> []
                       (a:b) -> f a : map' b
                in map' xs

Notice that for a compiler that uses lambda lifting this is
useless as map' will be transformed back to what map was.

We could possibly do the same for big lambdas, but we don't as
they will eventually be removed in later stages of the compiler,
therefore there is no penalty in keeping them.

We only apply the SAT when the number of static value arguments
is > 1 (static type and coercion arguments are not counted). This
produces few bad cases.  See
                should_transform
in saTransformMaybe.

Here are the headline nofib results:
                  Size    Allocs   Runtime
Min             +0.0%    -13.7%    -21.4%
Max             +0.1%     +0.0%     +5.4%
Geometric Mean  +0.0%     -0.2%     -6.9%

The previous patch, to fix polymorphic floatout demand signatures, is
essential to make this work well!
-}

module GHC.Core.Opt.StaticArgs ( doStaticArgs ) where

import GHC.Prelude

import GHC.Types.Var
import GHC.Core
import GHC.Core.Utils
import GHC.Core.Type
import GHC.Core.Coercion
import GHC.Types.Id
import GHC.Types.Name
import GHC.Types.Var.Env
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Types.Unique.FM
import GHC.Types.Var.Set
import GHC.Types.Unique
import GHC.Types.Unique.Set
import GHC.Utils.Outputable
import GHC.Utils.Panic

import Data.List (mapAccumL)
import GHC.Data.FastString

#include "HsVersions.h"

-- | Run the static argument transformation over every top-level binding
-- of a program.
--
-- The unique supply is threaded through the bindings with 'mapAccumL';
-- the supply left over at the end is discarded.
doStaticArgs :: UniqSupply -> CoreProgram -> CoreProgram
doStaticArgs us binds = snd $ mapAccumL sat_bind_threaded_us us binds
  where
    -- Split the incoming supply: one half is passed on to the remaining
    -- bindings, the other half drives the SAT monad for this binding.
    -- The traversal starts with an empty set of interesting Ids, and the
    -- staticness information returned by 'satBind' is dropped.
    sat_bind_threaded_us :: UniqSupply -> CoreBind -> (UniqSupply, CoreBind)
    sat_bind_threaded_us supply bind =
        let (us1, us2) = splitUniqSupply supply
        in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))

-- We don't bother to SAT recursive groups of more than one binding since
-- it can lead to massive code expansion: see Andre Santos' thesis for
-- details.  This means we only apply the actual SAT to Rec groups of one
-- element, but we want to recurse into the others anyway to discover
-- other binds.

-- | Process one binding group.
--
-- The 'IdSet' holds the "interesting" Ids: the binders of the enclosing
-- singleton recursive groups, whose call sites we are recording.  The
-- result is the (possibly transformed) binding together with the
-- staticness information gathered from it.
satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
satBind (NonRec binder expr) interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
satBind (Rec [(binder, rhs)]) interesting_ids = do
    -- A singleton recursive group: the only candidate for the SAT.
    let interesting_ids' = interesting_ids `addOneToUniqSet` binder
        (rhs_binders, rhs_body) = collectBinders rhs
    (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
    let -- Start by assuming every lambda-bound argument is static, then
        -- merge in what the call sites inside the body tell us.
        sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
        sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body

        -- If the binder was already interesting on the way in, this
        -- binding shadows an outer one; do not let our information about
        -- it leak out to the enclosing scope.
        shadowing = binder `elementOfUniqSet` interesting_ids
        sat_info_rhs'' = if shadowing
                        then sat_info_rhs' `delFromUFM` binder -- For safety
                        else sat_info_rhs'

    -- The decision to transform uses the information *before* the
    -- shadowing entry is deleted.
    bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder)
                              rhs_binders rhs_body'
    return (bind', sat_info_rhs'')
satBind (Rec pairs) interesting_ids = do
    -- A larger recursive group: never transformed itself, but the
    -- right-hand sides are traversed to find nested candidates.
    let (binders, rhss) = unzip pairs
    rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
    let (rhss', sat_info_rhss') = unzip rhss_SATed
    return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')

-- | One argument of an application (or one lambda binder viewed as an
-- argument): a value variable, a type, or a coercion.
data App = VarApp Id | TypeApp Type | CoApp Coercion
-- | Whether an argument position has carried the same thing at every
-- occurrence merged so far ('Static', with that thing) or not.
data Staticness a = Static a | NotStatic

-- | An interesting function that is being applied, paired with the
-- staticness of the arguments collected for this application so far.
type IdAppInfo = (Id, SATInfo)

-- | Staticness of each argument position, in order.
type SATInfo = [Staticness App]
-- | The argument staticness recorded for each function Id.
type IdSATInfo = IdEnv SATInfo
-- | Staticness information that mentions no function at all.
emptyIdSATInfo :: IdSATInfo
emptyIdSATInfo = emptyUFM

{-
pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (Map.toList id_sat_info))
  where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
-}

-- | Render a 'SATInfo' as the concatenation (no separators) of the
-- two-letter codes produced by 'pprStaticness'.
pprSATInfo :: SATInfo -> SDoc
pprSATInfo staticness = hcat $ map pprStaticness staticness

-- | Two-letter code for one argument position: @SV@ (static value
-- variable), @ST@ (static type), @SC@ (static coercion) or @NS@ (not
-- static).
pprStaticness :: Staticness App -> SDoc
pprStaticness (Static (VarApp _))  = text "SV"
pprStaticness (Static (TypeApp _)) = text "ST"
pprStaticness (Static (CoApp _))   = text "SC"
pprStaticness NotStatic            = text "NS"


-- | Combine two views of the same function's arguments, position by
-- position.
--
-- A position stays 'Static' only if both sides are 'Static' with equal
-- contents.  Because this is a 'zipWith', the result is as long as the
-- shorter input: positions present on only one side are dropped.
--
-- Panics (via 'pprPanic') if the two sides are both static at some
-- position but disagree about the /kind/ of argument (value, type or
-- coercion).
mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
mergeSATInfo l r = zipWith mergeSA l r
  where
    mergeSA :: Staticness App -> Staticness App -> Staticness App
    mergeSA NotStatic _ = NotStatic
    mergeSA _ NotStatic = NotStatic
    mergeSA (Static (VarApp v)) (Static (VarApp v'))
      | v == v'   = Static (VarApp v)
      | otherwise = NotStatic
    mergeSA (Static (TypeApp t)) (Static (TypeApp t'))
      | t `eqType` t' = Static (TypeApp t)
      | otherwise     = NotStatic
    mergeSA (Static (CoApp c)) (Static (CoApp c'))
      | c `eqCoercion` c' = Static (CoApp c)
      | otherwise         = NotStatic
    -- Both static, but of different argument kinds.
    mergeSA _ _  = pprPanic "mergeSATInfo" $
                          text "Left:"
                       <> pprSATInfo l <> text ", "
                       <> text "Right:"
                       <> pprSATInfo r

-- | Union two staticness maps, using 'mergeSATInfo' to combine the
-- entries of any function that occurs in both.
mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
mergeIdSATInfo = plusUFM_C mergeSATInfo

-- | Merge a list of staticness maps with 'mergeIdSATInfo', starting
-- from 'emptyIdSATInfo' (a strict left fold).
mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo

-- | The initial staticness of a function's lambda binders: every binder
-- is assumed 'Static', carrying itself as the argument.
--
-- The argument kind chosen here has to agree with what 'satExpr' records
-- at call sites ('TypeApp' for @Type@ arguments, 'CoApp' for @Coercion@
-- arguments, 'VarApp' for @Var@ arguments); otherwise 'mergeSATInfo'
-- panics.
bindersToSATInfo :: [Id] -> SATInfo
bindersToSATInfo vs = map (Static . binderToApp) vs
    where
      -- Classify type and coercion variables first and treat whatever
      -- is left as an ordinary value binder.
      -- NOTE(review): a coercion variable presumably also satisfies
      -- 'isId' (see GHC.Types.Var), so testing 'isId' first would file
      -- it under 'VarApp' and never reach the 'CoApp' case -- confirm.
      binderToApp :: Id -> App
      binderToApp v | isTyVar v = TypeApp $ mkTyVarTy v
                    | isCoVar v = CoApp $ mkCoVarCo v
                    | otherwise = VarApp v

-- | Record a completed application in the staticness map.
--
-- With 'Nothing' there is no pending application and the map is returned
-- unchanged.  With @Just (v, info)@ the argument staticness of this call
-- is merged (by 'mergeSATInfo') into whatever is already known about
-- @v@, or stored as is when @v@ has no entry yet.
finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
finalizeApp Nothing id_sat_info = id_sat_info
finalizeApp (Just (v, sat_info')) id_sat_info =
    let sat_info'' = case lookupUFM id_sat_info v of
                        Nothing       -> sat_info'
                        Just sat_info -> mergeSATInfo sat_info sat_info'
    in extendVarEnv id_sat_info v sat_info''

-- | Like 'satExpr', but for an expression that is not itself the
-- function part of an enclosing application: any application still
-- pending at the end is finalised into the staticness map straight away.
satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
satTopLevelExpr expr interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (expr', finalizeApp expr_app sat_info_expr)

-- | Traverse an expression, transforming nested bindings and collecting
-- staticness information about calls to the interesting Ids.
--
-- Returns the rewritten expression, the staticness map for all
-- /finalised/ applications inside it, and, if the expression is an
-- interesting Id applied to zero or more arguments, the still-pending
-- application ('IdAppInfo') so that an enclosing 'App' can extend it.
satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
satExpr var@(Var v) interesting_ids = do
    -- An interesting Id starts a pending application with no arguments.
    let app_info = if v `elementOfUniqSet` interesting_ids
                   then Just (v, [])
                   else Nothing
    return (var, emptyIdSATInfo, app_info)

satExpr lit@(Lit _) _ =
    return (lit, emptyIdSATInfo, Nothing)

satExpr (Lam binders body) interesting_ids = do
    (body', sat_info, this_app) <- satExpr body interesting_ids
    -- A lambda ends any application pending in its body.
    return (Lam binders body', finalizeApp this_app sat_info, Nothing)

satExpr (App fn arg) interesting_ids = do
    (fn', sat_info_fn, fn_app) <- satExpr fn interesting_ids
    let satRemainder = boring fn' sat_info_fn
    case fn_app of
        -- The head is not an interesting Id: nothing to record.
        Nothing -> satRemainder Nothing
        -- The head is an interesting Id: append this argument's
        -- staticness to the pending application.
        Just (fn_id, fn_app_info) ->
            -- TODO: remove this use of append somehow (use a data structure with O(1) append but a left-to-right kind of interface)
            let satRemainderWithStaticness arg_staticness =
                    satRemainder $ Just (fn_id, fn_app_info ++ [arg_staticness])
            in case arg of
                Type t     -> satRemainderWithStaticness $ Static (TypeApp t)
                Coercion c -> satRemainderWithStaticness $ Static (CoApp c)
                Var v      -> satRemainderWithStaticness $ Static (VarApp v)
                -- Any other argument shape can never be static.
                _          -> satRemainderWithStaticness $ NotStatic
  where
    -- Traverse the argument, finalise whatever application it contains,
    -- merge its information with that of the function part, and pass the
    -- given pending application through unchanged.
    boring :: CoreExpr -> IdSATInfo -> Maybe IdAppInfo -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
    boring fn' sat_info_fn app_info =
        do (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
           let sat_info_arg' = finalizeApp arg_app sat_info_arg
               sat_info = mergeIdSATInfo sat_info_fn sat_info_arg'
           return (App fn' arg', sat_info, app_info)

satExpr (Case expr bndr ty alts) interesting_ids = do
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    let sat_info_expr' = finalizeApp expr_app sat_info_expr

    zipped_alts' <- mapM satAlt alts
    let (alts', sat_infos_alts) = unzip zipped_alts'
    return ( Case expr' bndr ty alts'
           , mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts)
           , Nothing )
  where
    -- Each alternative's right-hand side is finalised on its own.
    satAlt :: Alt Id -> UniqSM (Alt Id, IdSATInfo)
    satAlt (Alt con bndrs alt_rhs) = do
        (alt_rhs', sat_info_alt) <- satTopLevelExpr alt_rhs interesting_ids
        return (Alt con bndrs alt_rhs', sat_info_alt)

satExpr (Let bind body) interesting_ids = do
    -- The body is traversed before the binding; its pending application
    -- (if any) is handed on to the caller.
    (body', sat_info_body, body_app) <- satExpr body interesting_ids
    (bind', sat_info_bind) <- satBind bind interesting_ids
    return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)

satExpr (Tick tickish expr) interesting_ids = do
    -- Ticks are transparent: the pending application passes through.
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (Tick tickish expr', sat_info_expr, expr_app)

satExpr ty@(Type _) _ =
    return (ty, emptyIdSATInfo, Nothing)

satExpr co@(Coercion _) _ =
    return (co, emptyIdSATInfo, Nothing)

satExpr (Cast expr coercion) interesting_ids = do
    -- Casts are transparent too.
    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
    return (Cast expr' coercion, sat_info_expr, expr_app)

{-
************************************************************************

                Static Argument Transformation Monad

************************************************************************
-}

-- | The monad the transformation runs in: a plain synonym for 'UniqSM',
-- i.e. all it provides is a supply of fresh uniques.
type SatM result = UniqSM result

-- | Run a 'SatM' computation with the given unique supply and return
-- its result (this is 'initUs_').
runSAT :: UniqSupply -> SatM a -> a
runSAT = initUs_

-- | Draw a fresh 'Unique' from the monad's supply.
newUnique :: SatM Unique
newUnique = getUniqueM

{-
************************************************************************

                Static Argument Transformation

************************************************************************

To do the transformation, the game plan is to:

1. Create a small nonrecursive RHS that takes the
   original arguments to the function but discards
   the ones that are static and makes a call to the
   SATed version with the remainder. We intend that
   this will be inlined later, removing the overhead

2. Bind this nonrecursive RHS over the original body
   WITH THE SAME UNIQUE as the original body so that
   any recursive calls to the original now go via
   the small wrapper

3. Rebind the original function to a new one which contains
   our SATed function and just makes a call to it:
   we call the thing making this call the local body

Example: transform this

    map :: forall a b. (a->b) -> [a] -> [b]
    map = /\ab. \(f:a->b) (as:[a]) -> body[map]
to
    map :: forall a b. (a->b) -> [a] -> [b]
    map = /\ab. \(f:a->b) (as:[a]) ->
         letrec map' :: [a] -> [b]
                    -- The "worker function"
                map' = \(as:[a]) ->
                         let map :: forall a' b'. (a -> b) -> [a] -> [b]
                                -- The "shadow function"
                             map = /\a'b'. \(f':(a->b)) (as:[a]).
                                   map' as
                         in body[map]
         in map' as

Note [Shadow binding]
~~~~~~~~~~~~~~~~~~~~~
The calls to the inner map inside body[map] should get inlined
by the local re-binding of 'map'.  We call this the "shadow binding".

But we can't use the original binder 'map' unchanged, because
it might be exported, in which case the shadow binding won't be
discarded as dead code after it is inlined.

So we use a hack: we make a new SysLocal binder with the *same* unique
as binder.  (Another alternative would be to reset the export flag.)

Note [Binder type capture]
~~~~~~~~~~~~~~~~~~~~~~~~~~
Notice that in the inner map (the "shadow function"), the static arguments
are discarded -- it's as if they were underscores.  Instead, mentions
of these arguments (notably in the types of dynamic arguments) are bound
by the *outer* lambdas of the main function.  So we must make up fresh
names for the static arguments so that they do not capture variables
mentioned in the types of dynamic args.

In the map example, the shadow function must clone the static type
argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
is bound by the outer forall.  We clone f' too for consistency, but
that doesn't matter either way because static Id arguments aren't
mentioned in the shadow binding at all.

If we don't we get something like this:

[Exported]
[Arity 3]
GHC.Base.until =
  \ (@ a_aiK)
    (p_a6T :: a_aiK -> GHC.Types.Bool)
    (f_a6V :: a_aiK -> a_aiK)
    (x_a6X :: a_aiK) ->
    letrec {
      sat_worker_s1aU :: a_aiK -> a_aiK
      []
      sat_worker_s1aU =
        \ (x_a6X :: a_aiK) ->
          let {
            sat_shadow_r17 :: forall a_a3O.
                              (a_a3O -> GHC.Types.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
            []
            sat_shadow_r17 =
              \ (@ a_aiK)
                (p_a6T :: a_aiK -> GHC.Types.Bool)
                (f_a6V :: a_aiK -> a_aiK)
                (x_a6X :: a_aiK) ->
                sat_worker_s1aU x_a6X } in
          case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
            GHC.Types.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
            GHC.Types.True -> x_a6X
          }; } in
    sat_worker_s1aU x_a6X

Where sat_shadow has captured the type variables of x_a6X etc as it has a a_aiK
type argument. This is bad because it means the application sat_worker_s1aU x_a6X
is not well typed.
-}

-- | Decide whether a singleton recursive binding is worth transforming,
-- and transform it if so.
--
-- The binding is transformed only when staticness information is
-- available for it and more than one argument position is a static
-- /value/ variable (static type and coercion arguments do not count; see
-- 'isStaticValue').  Otherwise the original recursive binding is rebuilt
-- from its binders and body.
saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
  | Just arg_staticness <- maybe_arg_staticness
  , should_transform arg_staticness
  = saTransform binder arg_staticness rhs_binders rhs_body
  | otherwise
  = return (Rec [(binder, mkLams rhs_binders rhs_body)])
  where
    should_transform :: SATInfo -> Bool
    should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
      where
        n_static_args = count isStaticValue staticness

-- | Perform the static argument transformation on one self-recursive
-- function, given its binder, the staticness of its arguments, its
-- lambda binders and its body.
--
-- The result is a /non-recursive/ binding of the original binder whose
-- right-hand side contains a local recursive worker that abstracts only
-- over the non-static arguments, plus a shadow binding that redirects
-- recursive calls in the body to that worker.
saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
saTransform binder arg_staticness rhs_binders rhs_body
  = do  { shadow_lam_bndrs <- mapM clone binders_w_staticness
        ; uniq             <- newUnique
        ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
  where
    -- Running example: foldr
    -- foldr \alpha \beta c n xs = e, for some e
    -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NotStatic]
    -- rhs_binders = [\alpha, \beta, c, n, xs]
    -- rhs_body = e

    binders_w_staticness :: [(Id, Staticness App)]
    binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
                                        -- Any extra args are assumed NotStatic

    non_static_args :: [Var]
            -- non_static_args = [xs]
    non_static_args = [v | (v, NotStatic) <- binders_w_staticness]

    -- Give every static binder a fresh unique; non-static binders are
    -- kept as they are.  See Note [Binder type capture].
    --      shadow_lam_bndrs = [\alpha', \beta', c', n', xs]
    clone :: (Id, Staticness a) -> UniqSM Id
    clone (bndr, NotStatic) = return bndr
    clone (bndr, _        ) = do { uniq <- newUnique
                                 ; return (setVarUnique bndr uniq) }

    -- new_rhs = \alpha beta c n xs ->
    --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c' n' xs ->
    --                                       sat_worker xs
    --                                   in e
    --           in sat_worker xs
    --
    -- The local definitions below refer to one another cyclically
    -- (rec_body_bndr's type comes from rec_body, which mentions
    -- rec_body_bndr via local_body), so they rely on lazy evaluation.
    mk_new_rhs :: Unique -> [Id] -> CoreExpr
    mk_new_rhs uniq shadow_lam_bndrs
        = mkLams rhs_binders $
          Let (Rec [(rec_body_bndr, rec_body)])
          local_body
        where
          -- local_body = sat_worker xs
          local_body = mkVarApps (Var rec_body_bndr) non_static_args

          -- rec_body = \xs -> let sat_shadow = shadow_rhs in e
          rec_body = mkLams non_static_args $
                     Let (NonRec shadow_bndr shadow_rhs) rhs_body

            -- See Note [Binder type capture]
          shadow_rhs = mkLams shadow_lam_bndrs local_body
            -- shadow_rhs = \alpha' beta' c' n' xs -> sat_worker xs

          rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq Many (exprType rec_body)
            -- rec_body_bndr = sat_worker

            -- See Note [Shadow binding]; make a SysLocal with the same
            -- occurrence name and the same unique as the original binder
          shadow_bndr = mkSysLocal (occNameFS (getOccName binder))
                                   (idUnique binder)
                                   Many
                                   (exprType shadow_rhs)

-- | True only for a static /value/ argument, i.e. @Static (VarApp _)@;
-- static type and coercion arguments, and non-static ones, yield False.
isStaticValue :: Staticness App -> Bool
isStaticValue (Static (VarApp _)) = True
isStaticValue _                   = False