{-# LANGUAGE GADTs #-}
module GHC.Cmm.Switch.Implement
  ( cmmImplementSwitchPlans
  )
where

import GHC.Prelude

import GHC.Driver.Backend
import GHC.Platform
import GHC.Cmm.Dataflow.Block
import GHC.Cmm.BlockId
import GHC.Cmm
import GHC.Cmm.Utils
import GHC.Cmm.Switch
import GHC.Types.Unique.Supply
import GHC.Utils.Monad (concatMapM)

--
-- This module replaces Switch statements as generated by the Stg -> Cmm
-- transformation, which might be huge and sparse and hence unsuitable for
-- assembly code, by proper constructs (if-then-else trees, dense jump tables).
--
-- The actual, abstract strategy is determined by createSwitchPlan in
-- GHC.Cmm.Switch and returned as a SwitchPlan; this module merely implements
-- that plan in terms of Cmm code. See Note [Cmm Switches, the general plan] in
-- GHC.Cmm.Switch.
--
-- This division into different modules serves both to separate concerns
-- clearly and because createSwitchPlan needs access to the constructors of
-- SwitchTargets, a data type exported abstractly by GHC.Cmm.Switch.
--

-- | Traverses the 'CmmGraph', making sure that 'CmmSwitch' are suitable for
-- code generation.
--
-- If the backend implements switches itself ('backendSupportsSwitch'), the
-- graph is returned untouched. Otherwise every block is passed through
-- 'visitSwitches', which rewrites blocks ending in a 'CmmSwitch' and may
-- produce additional blocks. The entry label of the graph is preserved.
cmmImplementSwitchPlans :: Backend -> Platform -> CmmGraph -> UniqSM CmmGraph
cmmImplementSwitchPlans backend platform g
    -- Switch generation done by backend (LLVM/C)
    | backendSupportsSwitch backend = return g
    | otherwise = do
        blocks' <- concatMapM (visitSwitches platform) (toBlockList g)
        return $ ofBlockList (g_entry g) blocks'

-- | Rewrite a single block if (and only if) it ends in a 'CmmSwitch'.
--
-- The switch is replaced by the code implementing the 'SwitchPlan' produced
-- by 'createSwitchPlan'. The scrutinee is first passed through
-- 'floatSwitchExpr' so that the comparison code refers to a simple
-- expression (see Note [Floating switch expressions]).
--
-- The rewritten block comes first in the result, followed by any auxiliary
-- blocks created while implementing the plan. A block that does not end in
-- a switch is returned unchanged as a singleton list.
visitSwitches :: Platform -> CmmBlock -> UniqSM [CmmBlock]
visitSwitches platform block
  | (entry@(CmmEntry _ scope), middle, CmmSwitch vanillaExpr ids)
      <- blockSplit block
  = do
    let plan = createSwitchPlan ids
    -- See Note [Floating switch expressions]
    (assignSimple, simpleExpr) <- floatSwitchExpr platform vanillaExpr

    (newTail, newBlocks) <- implementSwitchPlan platform scope simpleExpr plan

    -- Reassemble the block: original entry and middle nodes, then the
    -- scrutinee assignment (possibly empty), then the new control-flow tail.
    let block' = entry `blockJoinHead` middle
                   `blockAppend` assignSimple
                   `blockAppend` newTail

    return $ block' : newBlocks

  | otherwise
  = return [block]

-- Note [Floating switch expressions]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-- When we translate a sparse switch into a search tree we would like
-- to compute the value we compare against only once.

-- For this purpose we assign the switch expression to a local register
-- and then use this register when constructing the actual binary tree.

-- This is important as the expression could contain expensive code like
-- memory loads or divisions which we REALLY don't want to duplicate.

-- This happened in parts of the hand-written RTS Cmm code; see also #16933.

-- | Make the switch scrutinee cheap to reuse.
--
-- A register is returned as-is with an empty block. Any other expression is
-- assigned to a fresh local register (via 'cmmMkAssign'). The result pairs
-- that single assignment with the register expression, which is what the
-- comparison code should then use.
--
-- See Note [Floating switch expressions]
floatSwitchExpr :: Platform -> CmmExpr -> UniqSM (Block CmmNode O O, CmmExpr)
floatSwitchExpr _        reg@(CmmReg {}) = return (emptyBlock, reg)
floatSwitchExpr platform expr            = do
  (assign, expr') <- cmmMkAssign platform expr <$> getUniqueM
  return (BMiddle assign, expr')


-- | Implement a 'SwitchPlan' as Cmm code, comparing against @expr@.
--
-- Returns the tail block (ending in a branch, a conditional branch or a
-- jump-table 'CmmSwitch') that replaces the original switch. It also returns
-- the list of extra blocks created for the inner nodes of the decision tree.
-- Every new block is labelled with a fresh 'BlockId' and uses the given tick
-- @scope@.
--
-- @expr@ appears in every comparison, so callers should pass a cheap
-- expression (see Note [Floating switch expressions]).
implementSwitchPlan
  :: Platform
  -> CmmTickScope
  -> CmmExpr
  -> SwitchPlan
  -> UniqSM (Block CmmNode O C, [CmmBlock])
implementSwitchPlan platform scope expr = go
  where
    -- Build the tail block for a plan node.
    go :: SwitchPlan -> UniqSM (Block CmmNode O C, [CmmBlock])
    go (Unconditionally l)
      = return (emptyBlock `blockJoinTail` CmmBranch l, [])
    go (JumpTable ids)
      = return (emptyBlock `blockJoinTail` CmmSwitch expr ids, [])
    go (IfLT signed i ids1 ids2)
      = do
        (bid1, newBlocks1) <- go' ids1
        (bid2, newBlocks2) <- go' ids2

        let lt | signed    = cmmSLtWord
               | otherwise = cmmULtWord
            scrut = lt platform expr $ CmmLit $ mkWordCLit platform i
            -- expr < i  ==> ids1, otherwise ==> ids2
            lastNode = CmmCondBranch scrut bid1 bid2 Nothing
            lastBlock = emptyBlock `blockJoinTail` lastNode
        return (lastBlock, newBlocks1 ++ newBlocks2)
    go (IfEqual i l ids2)
      = do
        (bid2, newBlocks2) <- go' ids2

        let scrut = cmmNeWord platform expr $ CmmLit $ mkWordCLit platform i
            -- expr /= i ==> ids2, otherwise ==> l
            lastNode = CmmCondBranch scrut bid2 l Nothing
            lastBlock = emptyBlock `blockJoinTail` lastNode
        return (lastBlock, newBlocks2)

    -- Same as 'go', but returning a label to branch to. An unconditional plan
    -- reuses its target label directly instead of creating a new block.
    go' :: SwitchPlan -> UniqSM (BlockId, [CmmBlock])
    go' (Unconditionally l)
      = return (l, [])
    go' p
      = do
        bid <- mkBlockId `fmap` getUniqueM
        (tailBlock, newBlocks) <- go p
        let block = CmmEntry bid scope `blockJoinHead` tailBlock
        return (bid, block : newBlocks)