{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM'
-- monad.
module GHC.Stg.Lift.Monad (
    decomposeStgBinding, mkStgBinding,
    Env (..),
    -- * #floats# Handling floats
    -- $floats
    FloatLang (..), collectFloats, -- Exported just for the docs
    -- * Transformation monad
    LiftM, runLiftM,
    -- ** Adding bindings
    startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
    -- ** Substitution and binders
    withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
    -- ** Occurrences
    substOcc, isLifted, formerFreeVars, liftedIdsExpander
  ) where

#include "HsVersions.h"

import GHC.Prelude

import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Driver.Session
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity

import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )

-- | @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding (StgRec pairs) = (Recursive, pairs)
decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])

mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive = StgRec
mkStgBinding NonRecursive = uncurry StgNonRec . head

-- | Environment threaded around in a scoped, @Reader@-like fashion.
data Env
  = Env
  { e_dflags     :: !DynFlags
  -- ^ Read-only.
  , e_subst      :: !Subst
  -- ^ We need to track the renamings of local 'InId's to their lifted 'OutId',
  -- because shadowing might make a closure's free variables unavailable at its
  -- call sites. Consider:
  -- @
  --    let f y = x + y in let x = 4 in f x
  -- @
  -- Here, @f@ can't be lifted to top-level, because its free variable @x@ isn't
  -- available at its call site.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ Lifted 'Id's don't occur as free variables in any closure anymore, because
  -- they are bound at the top-level. Every occurrence must supply the formerly
  -- free variables of the lifted 'Id', so they in turn become free variables of
  -- the call sites. This environment tracks this expansion from lifted 'Id's to
  -- their free variables.
  --
  -- 'InId's to 'OutId's.
  --
  -- Invariant: 'Id's not present in this map won't be substituted.
  }

emptyEnv :: DynFlags -> Env
emptyEnv dflags = Env dflags emptySubst emptyVarEnv


-- Note [Handling floats]
-- ~~~~~~~~~~~~~~~~~~~~~~
-- $floats
-- Consider the following expression:
--
-- @
--     f x =
--       let g y = ... f y ...
--       in g x
-- @
--
-- What happens when we want to lift @g@? Normally, we'd put the lifted @l_g@
-- binding above the binding for @f@:
--
-- @
--     g f y = ... f y ...
--     f x = g f x
-- @
--
-- But this very unnecessarily turns a known call to @f@ into an unknown one, in
-- addition to complicating matters for the analysis.
-- Instead, we'd really like to put both functions in the same recursive group,
-- thereby preserving the known call:
--
-- @
--     Rec {
--       g y = ... f y ...
--       f x = g x
--     }
-- @
--
-- But we don't want this to happen for just /any/ binding. That would create
-- possibly huge recursive groups in the process, calling for an occurrence
-- analyser on STG.
-- So, we need to track when we lift a binding out of a recursive RHS and add
-- the binding to the same recursive group as the enclosing recursive binding
-- (which must have either already been at the top-level or decided to be
-- lifted itself in order to preserve the known call).
--
-- This is done by expressing this kind of nesting structure as a 'Writer' over
-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to
-- 'collectFloats'.
-- API-wise, the analysis will not need to know about the whole 'FloatLang'
-- business and will just manipulate it indirectly through actions in 'LiftM'.

-- | We need to detect when we are lifting something out of the RHS of a
-- recursive binding (c.f. "GHC.Stg.Lift.Monad#floats"), in which case that
-- binding needs to be added to the same top-level recursive group. This
-- requires we detect a certain nesting structure, which is encoded by
-- 'StartBindingGroup' and 'EndBindingGroup'.
--
-- Although 'collectFloats' will only ever care if the current binding to be
-- lifted (through 'LiftedBinding') will occur inside such a binding group or
-- not, e.g. doesn't care about the nesting level as long as its greater than 0.
data FloatLang
  = StartBindingGroup
  | EndBindingGroup
  | PlainTopBinding OutStgTopBinding
  | LiftedBinding OutStgBinding

instance Outputable FloatLang where
  ppr StartBindingGroup = char '('
  ppr EndBindingGroup = char ')'
  ppr (PlainTopBinding StgTopStringLit{}) = text "<str>"
  ppr (PlainTopBinding (StgTopLifted b)) = ppr (LiftedBinding b)
  ppr (LiftedBinding bind) = (if isRec rec then char 'r' else char 'n') <+> ppr (map fst pairs)
    where
      (rec, pairs) = decomposeStgBinding bind

-- | Flattens an expression in @['FloatLang']@ into an STG program, see "GHC.Stg.Lift.Monad#floats".
-- Important pre-conditions: The nesting of opening 'StartBindinGroup's and
-- closing 'EndBindinGroup's is balanced. Also, it is crucial that every binding
-- group has at least one recursive binding inside. Otherwise there's no point
-- in announcing the binding group in the first place and an @ASSERT@ will
-- trigger.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    go 0 [] [] = []
    go _ _ [] = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n+1) binds rest
      EndBindingGroup
        | n == 0 -> pprPanic "collectFloats" (text "no group to end")
        | n == 1 -> StgTopLifted (merge_binds binds) : go 0 [] rest
        | otherwise -> go (n-1) binds rest
      PlainTopBinding top_bind
        | n == 0 -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | n == 0 -> StgTopLifted (rm_cccs bind) : go n binds rest
        | otherwise -> go n (bind:binds) rest

    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding
    rm_cccs = map_rhss removeRhsCCCS
    merge_binds binds = ASSERT( any is_rec binds )
                        StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)
    is_rec StgRec{} = True
    is_rec _ = False

-- | Omitting this makes for strange closure allocation schemes that crash the
-- GC.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body)
  | isCurrentCCS ccs
  = StgRhsClosure ext dontCareCCS upd bndrs body
removeRhsCCCS (StgRhsCon ccs con args)
  | isCurrentCCS ccs
  = StgRhsCon dontCareCCS con args
removeRhsCCCS rhs = rhs

-- | The analysis monad consists of the following 'RWST' components:
--
--     * 'Env': Reader-like context. Contains a substitution, info about how
--       how lifted identifiers are to be expanded into applications and details
--       such as 'DynFlags'.
--
--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
--
--     * No pure state component
--
--     * But wrapping around 'UniqSM' for generating fresh lifted binders.
--       (The @uniqAway@ approach could give the same name to two different
--       lifted binders, so this is necessary.)
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)

instance HasDynFlags LiftM where
  getDynFlags = LiftM (RWS.asks e_dflags)

instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
  getUniqueM = LiftM (lift getUniqueM)
  getUniquesM = LiftM (lift getUniquesM)

runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us (LiftM m) = collectFloats (fromOL floats)
  where
    (_, _, floats) = initUs_ us (runRWST m (emptyEnv dflags) ())

-- | Writes a plain 'StgTopStringLit' to the output.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit id = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit id

-- | Starts a recursive binding group. See "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM $ RWS.tell $ unitOL $ StartBindingGroup

-- | Ends a recursive binding group. See "GHC.Stg.Lift.Monad#floats" and 'collectFloats'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM $ RWS.tell $ unitOL $ EndBindingGroup

-- | Lifts a binding to top-level. Depending on whether it's declared inside
-- a recursive RHS (see "GHC.Stg.Lift.Monad#floats" and 'collectFloats'), this might be added to
-- an existing recursive top-level binding group.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding

-- | Takes a binder and a continuation which is called with the substituted
-- binder. The continuation will be evaluated in a 'LiftM' context in which that
-- binder is deemed in scope. Think of it as a 'RWS.local' computation: After
-- the continuation finishes, the new binding won't be in scope anymore.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  subst <- RWS.asks e_subst
  let (bndr', subst') = substBndr bndr subst
  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))

-- | See 'withSubstBndr'.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs = runContT . traverse (ContT . withSubstBndr)

-- | Similarly to 'withSubstBndr', this function takes a set of variables to
-- abstract over, the binder to lift (and generate a fresh, substituted name
-- for) and a continuation in which that fresh, lifted binder is in scope.
--
-- It takes care of all the details involved with copying and adjusting the
-- binder and fresh name generation.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let str = "$l" ++ occNameString (getOccName bndr)
  let ty = mkLamTypes (dVarSetElems abs_ids) (idType bndr)
  let bndr'
        -- See Note [transferPolyIdInfo] in GHC.Types.Id. We need to do this at least
        -- for arity information.
        = transferPolyIdInfo bndr (dVarSetElems abs_ids)
        . mkSysLocal (mkFastString str) uniq Many
        $ ty
  LiftM $ RWS.local
    (\e -> e
      { e_subst = extendSubst bndr bndr' $ extendInScope bndr' $ e_subst e
      , e_expansions = extendVarEnv (e_expansions e) bndr abs_ids
      })
    (unwrapLiftM (inner bndr'))

-- | See 'withLiftedBndr'.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids = runContT . traverse (ContT . withLiftedBndr abs_ids)

-- | Substitutes a binder /occurrence/, which was brought in scope earlier by
-- 'withSubstBndr' \/ 'withLiftedBndr'.
substOcc :: Id -> LiftM Id
substOcc id = LiftM (RWS.asks (lookupIdSubst id . e_subst))

-- | Whether the given binding was decided to be lambda lifted.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM (RWS.asks (elemVarEnv bndr . e_expansions))

-- | Returns an empty list for a binding that was not lifted and the list of all
-- local variables the binding abstracts over (so, exactly the additional
-- arguments at adjusted call sites) otherwise.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure $ case lookupVarEnv expansions f of
    Nothing -> []
    Just fvs -> dVarSetElems fvs

-- | Creates an /expander function/ for the current set of lifted binders.
-- This expander function will replace any 'InId' by their corresponding 'OutId'
-- and, in addition, will expand any lifted binders by the former free variables
-- it abstracts over.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst <- RWS.asks e_subst
  -- We use @noWarnLookupIdSubst@ here in order to suppress "not in scope"
  -- warnings generated by 'lookupIdSubst' due to local bindings within RHS.
  -- These are not in the InScopeSet of @subst@ and extending the InScopeSet in
  -- @goodToLift@/@closureGrowth@ before passing it on to @expander@ is too much
  -- trouble.
  let go set fv = case lookupVarEnv expansions fv of
        Nothing -> extendDVarSet set (noWarnLookupIdSubst fv subst) -- Not lifted
        Just fvs' -> unionDVarSet set fvs'
  let expander fvs = foldl' go emptyDVarSet (dVarSetElems fvs)
  pure expander