{-
(c) The GRASP/AQUA Project, Glasgow University, 1993-1998

\section[SimplStg]{Driver for simplifying @STG@ programs}
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module GHC.Stg.Pipeline ( stg2stg ) where

#include "HsVersions.h"

import GHC.Prelude

import GHC.Stg.Syntax

import GHC.Stg.Lint     ( lintStgTopBindings )
import GHC.Stg.Stats    ( showStgStats )
import GHC.Stg.DepAnal  ( depSortStgPgm )
import GHC.Stg.Unarise  ( unarise )
import GHC.Stg.CSE      ( stgCse )
import GHC.Stg.Lift     ( stgLiftLams )
import GHC.Unit.Module ( Module )

import GHC.Driver.Session
import GHC.Utils.Error
import GHC.Types.Unique.Supply
import GHC.Utils.Outputable
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State.Strict

-- | The monad in which the STG-to-STG passes run: 'IO' layered under a
-- 'StateT' whose 'Char' state is read as the mask argument when fresh
-- uniques are requested (see the 'MonadUnique' instance).
--
-- The instances are obtained from the underlying 'StateT' via
-- @GeneralizedNewtypeDeriving@.
newtype StgM a = StgM { _unStgM :: StateT Char IO a }
  deriving (Functor, Applicative, Monad, MonadIO)

-- | Both methods read the mask character from the 'StgM' state and hand it
-- to the corresponding "GHC.Types.Unique.Supply" primitive in 'IO'.
instance MonadUnique StgM where
  -- Fetch the mask, then build a unique supply from it with
  -- 'mkSplitUniqSupply'. The IO action is forced with '$!' before lifting.
  getUniqueSupplyM = StgM $ do { mask <- get
                               ; liftIO $! mkSplitUniqSupply mask}
  -- Fetch the mask, then ask 'uniqFromMask' for a single unique.
  getUniqueM = StgM $ do { mask <- get
                         ; liftIO $! uniqFromMask mask}

-- | Run an 'StgM' computation in 'IO'.
--
-- The 'Char' argument becomes the initial state, i.e. the mask that the
-- 'MonadUnique' instance reads when generating uniques. Only the result of
-- the computation is returned; the final state is dropped ('evalStateT').
runStgM :: Char -> StgM a -> IO a
runStgM mask (StgM m) = evalStateT m mask

-- | Driver for the STG-to-STG pipeline.
--
-- Dumps the incoming program (under @Opt_D_dump_stg@), runs every pass
-- returned by 'getStgToDo' in order inside 'StgM', and finally
-- dependency-sorts the resulting top-level bindings.
stg2stg :: DynFlags                  -- includes spec of what stg-to-stg passes to do
        -> Module                    -- module being compiled
        -> [StgTopBinding]           -- input program
        -> IO [StgTopBinding]        -- output program
stg2stg dflags this_mod binds
  = do  { dump_when Opt_D_dump_stg "STG:" binds
        ; showPass dflags "Stg2Stg"
        -- Do the main business!
        ; binds' <- runStgM 'g' $
            foldM do_stg_pass binds (getStgToDo dflags)

          -- Dependency sort the program as last thing. The program needs to be
          -- in dependency order for the SRT algorithm to work (see
          -- CmmBuildInfoTables, which also includes a detailed description of
          -- the algorithm), and we don't guarantee that the program is already
          -- sorted at this point. #16192 is for simplifier not preserving
          -- dependency order. We also don't guarantee that StgLiftLams will
          -- preserve the order or only create minimal recursive groups, so a
          -- sorting pass is necessary.
        ; let binds_sorted = depSortStgPgm this_mod binds'
        ; return binds_sorted
   }

  where
    -- Lint the program when Opt_DoStgLinting is set; otherwise a no-op.
    -- The Bool says whether the program has already been unarised; the
    -- String names the pass that produced the bindings.
    stg_linter :: Bool -> String -> [StgTopBinding] -> IO ()
    stg_linter unarised
      | gopt Opt_DoStgLinting dflags
      = lintStgTopBindings dflags this_mod unarised
      | otherwise
      = \ _whodunnit _binds -> return ()

    -------------------------------------------
    -- Run a single pass over the program.
    do_stg_pass :: [StgTopBinding] -> StgToDo -> StgM [StgTopBinding]
    do_stg_pass binds to_do
      = case to_do of
          StgDoNothing ->
            return binds

          -- Statistics are emitted through 'trace'; the program is unchanged.
          StgStats ->
            trace (showStgStats binds) (return binds)

          StgCSE -> do
            let binds' = {-# SCC "StgCse" #-} stgCse binds
            end_pass "StgCse" binds'

          StgLiftLams -> do
            us <- getUniqueSupplyM
            let binds' = {-# SCC "StgLiftLams" #-} stgLiftLams dflags us binds
            end_pass "StgLiftLams" binds'

          -- Unarise is linted both before (as not-yet-unarised) and after
          -- (as unarised), and has its own dump flag, so it does not go
          -- through 'end_pass'.
          StgUnarise -> do
            us <- getUniqueSupplyM
            liftIO (stg_linter False "Pre-unarise" binds)
            let binds' = unarise us binds
            liftIO (dump_when Opt_D_dump_stg_unarised "Unarised STG:" binds')
            liftIO (stg_linter True "Unarise" binds')
            return binds'

    -- Pretty-printing options derived from the session flags.
    opts = initStgPprOpts dflags

    -- Dump the given bindings under the given flag and header.
    dump_when flag header binds
      = dumpIfSet_dyn dflags flag header FormatSTG (pprStgTopBindings opts binds)

    -- Common epilogue for a pass: verbose dump, lint, return the bindings.
    end_pass :: String -> [StgTopBinding] -> StgM [StgTopBinding]
    end_pass what binds2
      = liftIO $ do -- report verbosely, if required
          dumpIfSet_dyn dflags Opt_D_verbose_stg2stg what
            FormatSTG (vcat (map (pprStgTopBinding opts) binds2))
          stg_linter False what binds2
          return binds2

-- -----------------------------------------------------------------------------
-- StgToDo:  abstraction of stg-to-stg passes to run.

-- | Optional Stg-to-Stg passes.
data StgToDo
  = StgCSE
  -- ^ Common subexpression elimination
  | StgLiftLams
  -- ^ Lambda lifting closure variables, trading stack/register allocation for
  -- heap allocation
  | StgStats
  -- ^ Report statistics about the program; leaves the bindings unchanged
  | StgUnarise
  -- ^ Mandatory unarise pass, desugaring unboxed tuple and sum binders
  | StgDoNothing
  -- ^ Useful for building up 'getStgToDo'
  deriving Eq

-- | Which Stg-to-Stg passes to run. Depends on flags, ways etc.
--
-- Each optional pass is replaced by 'StgDoNothing' when its flag is off;
-- those placeholders are then filtered out, so the result contains only
-- passes that will actually run, in the order listed here.
getStgToDo :: DynFlags -> [StgToDo]
getStgToDo dflags =
  filter (/= StgDoNothing)
    [ mandatory StgUnarise
    -- Important that unarisation comes first
    -- See Note [StgCse after unarisation] in GHC.Stg.CSE
    , optional Opt_StgCSE StgCSE
    , optional Opt_StgLiftLams StgLiftLams
    , optional Opt_StgStats StgStats
    ] where
      -- Keep the pass only when the corresponding general flag is set.
      optional opt = runWhen (gopt opt dflags)
      -- Purely a readability marker for passes that always run.
      mandatory = id

-- | Keep the given pass when the condition holds; otherwise substitute
-- 'StgDoNothing'.
runWhen :: Bool -> StgToDo -> StgToDo
runWhen True todo = todo
runWhen _    _    = StgDoNothing