{-# LANGUAGE GADTs #-}

-----------------------------------------------------------------------------
--
-- Code generator utilities; mostly monadic
--
-- (c) The University of Glasgow 2004-2006
--
-----------------------------------------------------------------------------

module GHC.StgToCmm.CgUtils (
        fixStgRegisters,
        baseRegOffset,
        get_Regtable_addr_from_offset,
        regTableOffset,
        get_GlobalReg_addr,
  ) where

import GHC.Prelude

import GHC.Platform.Regs
import GHC.Cmm
import GHC.Cmm.Dataflow.Block
import GHC.Cmm.Dataflow.Graph
import GHC.Cmm.Utils
import GHC.Cmm.CLabel
import GHC.Driver.Session
import GHC.Utils.Outputable

-- -----------------------------------------------------------------------------
-- Information about global registers

-- | Offset of a global register's slot in the register table, i.e. the
-- offset that 'get_Regtable_addr_from_offset' adds to the table's base.
--
-- Numbered register classes only have a fixed number of slots (R1-R10,
-- F1-F6, D1-D6, XMM1-XMM6, YMM1-YMM6, ZMM1-ZMM6, L1).  Asking for a register
-- outside that range panics, as does asking for a register that has no slot
-- at all ('BaseReg', 'PicBaseReg', 'MachSp', 'UnwindReturnReg').
baseRegOffset :: DynFlags -> GlobalReg -> Int
baseRegOffset dflags reg = case reg of
    VanillaReg n _ -> numbered "R" n
      [ oFFSET_StgRegTable_rR1, oFFSET_StgRegTable_rR2
      , oFFSET_StgRegTable_rR3, oFFSET_StgRegTable_rR4
      , oFFSET_StgRegTable_rR5, oFFSET_StgRegTable_rR6
      , oFFSET_StgRegTable_rR7, oFFSET_StgRegTable_rR8
      , oFFSET_StgRegTable_rR9, oFFSET_StgRegTable_rR10 ]
    FloatReg n -> numbered "F" n
      [ oFFSET_StgRegTable_rF1, oFFSET_StgRegTable_rF2
      , oFFSET_StgRegTable_rF3, oFFSET_StgRegTable_rF4
      , oFFSET_StgRegTable_rF5, oFFSET_StgRegTable_rF6 ]
    DoubleReg n -> numbered "D" n
      [ oFFSET_StgRegTable_rD1, oFFSET_StgRegTable_rD2
      , oFFSET_StgRegTable_rD3, oFFSET_StgRegTable_rD4
      , oFFSET_StgRegTable_rD5, oFFSET_StgRegTable_rD6 ]
    XmmReg n -> numbered "XMM" n
      [ oFFSET_StgRegTable_rXMM1, oFFSET_StgRegTable_rXMM2
      , oFFSET_StgRegTable_rXMM3, oFFSET_StgRegTable_rXMM4
      , oFFSET_StgRegTable_rXMM5, oFFSET_StgRegTable_rXMM6 ]
    YmmReg n -> numbered "YMM" n
      [ oFFSET_StgRegTable_rYMM1, oFFSET_StgRegTable_rYMM2
      , oFFSET_StgRegTable_rYMM3, oFFSET_StgRegTable_rYMM4
      , oFFSET_StgRegTable_rYMM5, oFFSET_StgRegTable_rYMM6 ]
    ZmmReg n -> numbered "ZMM" n
      [ oFFSET_StgRegTable_rZMM1, oFFSET_StgRegTable_rZMM2
      , oFFSET_StgRegTable_rZMM3, oFFSET_StgRegTable_rZMM4
      , oFFSET_StgRegTable_rZMM5, oFFSET_StgRegTable_rZMM6 ]
    LongReg n -> numbered "L" n
      [ oFFSET_StgRegTable_rL1 ]

    Sp                 -> oFFSET_StgRegTable_rSp dflags
    SpLim              -> oFFSET_StgRegTable_rSpLim dflags
    Hp                 -> oFFSET_StgRegTable_rHp dflags
    HpLim              -> oFFSET_StgRegTable_rHpLim dflags
    CCCS               -> oFFSET_StgRegTable_rCCCS dflags
    CurrentTSO         -> oFFSET_StgRegTable_rCurrentTSO dflags
    CurrentNursery     -> oFFSET_StgRegTable_rCurrentNursery dflags
    HpAlloc            -> oFFSET_StgRegTable_rHpAlloc dflags
    EagerBlackholeInfo -> oFFSET_stgEagerBlackholeInfo dflags
    GCEnter1           -> oFFSET_stgGCEnter1 dflags
    GCFun              -> oFFSET_stgGCFun dflags

    -- These registers have no slot in the table.
    BaseReg            -> panic "CgUtils.baseRegOffset:BaseReg"
    PicBaseReg         -> panic "CgUtils.baseRegOffset:PicBaseReg"
    MachSp             -> panic "CgUtils.baseRegOffset:MachSp"
    UnwindReturnReg    -> panic "CgUtils.baseRegOffset:UnwindReturnReg"
  where
    -- Offset of the n-th (1-based) register of a numbered class whose
    -- per-slot offset accessors are listed in order.  Any n outside
    -- 1..length slots panics, naming the class and its last supported
    -- register (e.g. "R10").
    numbered :: String -> Int -> [DynFlags -> Int] -> Int
    numbered cls n slots =
      case drop (n - 1) slots of
        (slot : _) | n >= 1 -> slot dflags
        _ -> panic ("Registers above " ++ cls ++ show (length slots)
                    ++ " are not supported (tried to use " ++ cls ++ show n
                    ++ ")")


-- -----------------------------------------------------------------------------
--
-- STG/Cmm GlobalReg
--
-- -----------------------------------------------------------------------------

-- | Address of a global register's home in the register table.
--
-- STG registers either live in real machine registers or are kept at an
-- offset from 'BaseReg'.  Regardless of which applies on the current
-- target, this always yields the register-table address of the given
-- 'GlobalReg'.  'BaseReg' itself maps to offset 0 of the table (via
-- 'regTableOffset').  Every other register is located through its
-- 'baseRegOffset'.
get_GlobalReg_addr :: DynFlags -> GlobalReg -> CmmExpr
get_GlobalReg_addr dflags reg = case reg of
    BaseReg -> regTableOffset dflags 0
    _       -> get_Regtable_addr_from_offset dflags (baseRegOffset dflags reg)

-- | A literal address @n@ into the register table: the main capability's
-- label offset by the position of the register table within the
-- capability ('oFFSET_Capability_r') plus @n@.
--
-- Used when there is no real 'BaseReg' to offset from.
regTableOffset :: DynFlags -> Int -> CmmExpr
regTableOffset dflags n = CmmLit (CmmLabelOff mkMainCapabilityLabel slotOffset)
  where
    slotOffset = oFFSET_Capability_r dflags + n

-- | Address of the register-table slot at the given offset.
--
-- If the target platform satisfies 'haveRegBase', the address is expressed
-- relative to 'baseReg'.  Otherwise it is the literal address produced by
-- 'regTableOffset'.
get_Regtable_addr_from_offset :: DynFlags -> Int -> CmmExpr
get_Regtable_addr_from_offset dflags offset
  | haveRegBase (targetPlatform dflags) = cmmRegOff baseReg offset
  | otherwise                           = regTableOffset dflags offset

-- | Fixup global registers so that they assign to locations within the
-- RegTable if they aren't pinned for the current target.
--
-- Data declarations are returned unchanged.  For a procedure, every block
-- of its body graph is rewritten with 'fixStgRegBlock'.  The info table,
-- label and live-register list are kept as they are.
fixStgRegisters :: DynFlags -> RawCmmDecl -> RawCmmDecl
fixStgRegisters dflags decl = case decl of
    CmmData {} -> decl
    CmmProc info lbl live graph ->
        CmmProc info lbl live
                (modifyGraph (mapGraphBlocks (fixStgRegBlock dflags)) graph)

-- | Apply 'fixStgRegStmt' to every node of a block, preserving the block's
-- entry/exit shape.
fixStgRegBlock :: DynFlags -> Block CmmNode e x -> Block CmmNode e x
fixStgRegBlock dflags = mapBlock (fixStgRegStmt dflags)

-- | Rewrite a single Cmm node so that every STG global register that is
-- not one of the target's 'activeStgRegs' is accessed through its slot in
-- the register table instead:
--
--  * an assignment to such a register becomes a 'CmmStore' to the address
--    given by 'get_GlobalReg_addr';
--  * a read of such a register becomes a 'CmmLoad' from that address.
--    The exception is 'BaseReg', which is replaced by the address itself,
--    i.e. the register table in MainCapability;
--  * a @CmmRegOff@ on such a register is expanded into an explicit
--    addition, whose register operand is then rewritten as a read.
--
-- Registers in 'activeStgRegs' are left unchanged.  So is 'MachSp', which
-- is not an STG register and is only present for tracking unwind
-- information.
fixStgRegStmt :: DynFlags -> CmmNode e x -> CmmNode e x
fixStgRegStmt dflags stmt = fixAssign $ mapExpDeep fixExpr stmt
  where
    platform = targetPlatform dflags

    -- Is this global register mapped to a real machine register on the
    -- current target?
    isActive :: GlobalReg -> Bool
    isActive reg = reg `elem` activeStgRegs platform

    -- Turn an assignment to an unpinned global register into a store to
    -- its register-table slot.  All other nodes are returned unchanged.
    fixAssign :: CmmNode e1 x1 -> CmmNode e1 x1
    fixAssign node =
      case node of
        CmmAssign (CmmGlobal reg) src
          -- MachSp isn't an STG register; it's merely here for tracking
          -- unwind information
          | reg == MachSp -> node
          | isActive reg  -> node
          | otherwise     -> CmmStore (get_GlobalReg_addr dflags reg) src
        other_node -> other_node

    fixExpr :: CmmExpr -> CmmExpr
    fixExpr expr = case expr of
        -- MachSp isn't an STG register; it's merely here for tracking
        -- unwind information
        CmmReg (CmmGlobal MachSp) -> expr

        CmmReg (CmmGlobal reg)
          -- Registers that live in a machine register on this target are
          -- left as they are.
          | isActive reg   -> expr
          -- BaseReg stands for the address of the register table itself,
          -- so it is replaced by that address rather than loaded from it.
          | reg == BaseReg -> get_GlobalReg_addr dflags reg
          -- Everything else is an indirection to its register-table slot.
          | otherwise      ->
              CmmLoad (get_GlobalReg_addr dflags reg) (globalRegType platform reg)

        -- CmmRegOff is just a shorthand form.  If the register maps to a
        -- real machine register we keep the shorthand.  Otherwise we expand
        -- it and defer to the CmmReg case above for the register operand.
        CmmRegOff (CmmGlobal reg) offset
          | isActive reg -> expr
          | otherwise    ->
              CmmMachOp (MO_Add width)
                [ fixExpr (CmmReg (CmmGlobal reg))
                , CmmLit (CmmInt (fromIntegral offset) width) ]
            where
              -- @CmmRegOff r i@ means an addition at the width of @r@'s own
              -- type, not the machine word.  Use that width for both the add
              -- and the literal so the expansion stays well typed for
              -- registers that are not word-sized.
              width = cmmRegWidth platform (CmmGlobal reg)

        other_expr -> other_expr