{-# LANGUAGE GADTs #-}
module GHC.Cmm.Switch.Implement
  ( cmmImplementSwitchPlans
  )
where

import GHC.Prelude

import GHC.Driver.Backend
import GHC.Platform
import GHC.Cmm.Dataflow.Block
import GHC.Cmm.BlockId
import GHC.Cmm
import GHC.Cmm.Utils
import GHC.Cmm.Switch
import GHC.Types.Unique.Supply
import GHC.Utils.Monad (concatMapM)

--
-- This module replaces Switch statements as generated by the Stg -> Cmm
-- transformation, which might be huge and sparse and hence unsuitable for
-- assembly code, by proper constructs (if-then-else trees, dense jump tables).
--
-- The actual, abstract strategy is determined by createSwitchPlan in
-- GHC.Cmm.Switch and returned as a SwitchPlan; this module merely implements
-- that plan in terms of Cmm code. See Note [Cmm Switches, the general plan] in
-- GHC.Cmm.Switch.
--
-- The code is split across two modules both to separate concerns clearly and
-- because createSwitchPlan needs access to the constructors of SwitchTargets,
-- a data type that GHC.Cmm.Switch exports abstractly.
--

-- | Walk every block of the 'CmmGraph' and lower each 'CmmSwitch' into
-- constructs suitable for assembly code generation (if-then-else trees and
-- dense jump tables), following the plan computed by 'createSwitchPlan'.
--
-- When 'backendSupportsSwitch' holds for the backend, the graph is returned
-- untouched and the backend generates the switch itself.
cmmImplementSwitchPlans :: Backend -> Platform -> CmmGraph -> UniqSM CmmGraph
cmmImplementSwitchPlans backend platform g
    -- Switch generation done by backend (LLVM/C)
    | backendSupportsSwitch backend = pure g
    | otherwise =
        ofBlockList entry <$> concatMapM (visitSwitches platform) (toBlockList g)
  where
    entry = g_entry g

-- | Rewrite a single closed block.
--
-- If the block ends in a 'CmmSwitch':
--
--   * its scrutinee is floated into a register
--     (see Note [Floating switch expressions]), and
--   * the switch is replaced by the tail produced by 'implementSwitchPlan'.
--
-- The rewritten block comes first in the result, followed by any auxiliary
-- blocks the plan needed. Any other block is returned unchanged as a
-- singleton list.
visitSwitches :: Platform -> CmmBlock -> UniqSM [CmmBlock]
visitSwitches platform block =
  case blockSplit block of
    (entry@(CmmEntry _ scope), middle, CmmSwitch vanillaExpr ids) -> do
      -- See Note [Floating switch expressions]
      (assignSimple, simpleExpr) <- floatSwitchExpr platform vanillaExpr
      (newTail, newBlocks) <-
        implementSwitchPlan platform scope simpleExpr (createSwitchPlan ids)
      -- Same (left-associated) shape as: entry, then middle, then the
      -- scrutinee assignment, then the new control-flow tail.
      let rewritten =
            blockAppend
              (blockAppend (blockJoinHead entry middle) assignSimple)
              newTail
      pure (rewritten : newBlocks)
    _ -> pure [block]

-- Note [Floating switch expressions]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-- When we translate a sparse switch into a search tree we would like
-- to compute the value we compare against only once.

-- For this purpose we assign the switch expression to a local register
-- and then use this register when constructing the actual binary tree.

-- This is important as the expression could contain expensive code like
-- memory loads or divisions which we REALLY don't want to duplicate.

-- Such duplication occurred in parts of the hand-written RTS Cmm code; see
-- also #16933.

-- | Make a switch scrutinee safe to mention many times; see
-- Note [Floating switch expressions].
--
-- A register expression is already cheap, so it is returned as-is with an
-- empty block. Any other expression is assigned to a fresh local register
-- via 'cmmMkAssign'. The result pairs that assignment (as a one-node middle
-- block) with the expression to use in place of the original.
floatSwitchExpr :: Platform -> CmmExpr -> UniqSM (Block CmmNode O O, CmmExpr)
floatSwitchExpr platform expr = case expr of
  CmmReg {} -> pure (emptyBlock, expr)
  _ -> do
    uniq <- getUniqueM
    -- Force the pair here, as the original do-bind pattern match did.
    case cmmMkAssign platform expr uniq of
      (assign, expr') -> pure (BMiddle assign, expr')


-- | Lower a 'SwitchPlan' to Cmm.
--
-- Returns:
--
--   * the open/closed tail that replaces the original 'CmmSwitch' in the
--     current block, and
--   * every freshly created block the comparison tree branches to. Each such
--     block is labelled with a new 'BlockId' and entered in the given tick
--     scope.
--
-- @expr@ is the scrutinee. It is compared against repeatedly, so it is
-- expected to be cheap to duplicate; 'visitSwitches' floats it into a
-- register first.
implementSwitchPlan :: Platform -> CmmTickScope -> CmmExpr -> SwitchPlan -> UniqSM (Block CmmNode O C, [CmmBlock])
implementSwitchPlan platform scope expr = lowerTail
  where
    -- An otherwise empty open block closed by a single control-flow node.
    closeWith :: CmmNode O C -> Block CmmNode O C
    closeWith node = emptyBlock `blockJoinTail` node

    -- Compare the scrutinee against the word-sized literal @i@.
    literal i = CmmLit (mkWordCLit platform i)

    -- Produce the tail for a plan, plus any auxiliary blocks it needs.
    lowerTail plan = case plan of
      Unconditionally target ->
        pure (closeWith (CmmBranch target), [])
      JumpTable targets ->
        pure (closeWith (CmmSwitch expr targets), [])
      IfLT signed i below atOrAbove -> do
        -- The 'below' label is created before the 'atOrAbove' one.
        (belowLabel, belowBlocks) <- toLabel below
        (aboveLabel, aboveBlocks) <- toLabel atOrAbove
        let lessThan = if signed then cmmSLtWord else cmmULtWord
            cond     = lessThan platform expr (literal i)
        pure ( closeWith (CmmCondBranch cond belowLabel aboveLabel Nothing)
             , belowBlocks ++ aboveBlocks )
      IfEqual i hit rest -> do
        (restLabel, restBlocks) <- toLabel rest
        -- The test is "not equal": on true fall into 'rest', on false
        -- (i.e. expr == i) jump to 'hit'.
        let cond = cmmNeWord platform expr (literal i)
        pure (closeWith (CmmCondBranch cond restLabel hit Nothing), restBlocks)

    -- Like 'lowerTail', but yields a label to branch to. An unconditional
    -- plan needs no new block: branch straight to its target.
    toLabel (Unconditionally target) = pure (target, [])
    toLabel plan = do
      label <- mkBlockId <$> getUniqueM
      (body, extra) <- lowerTail plan
      pure (label, blockJoinHead (CmmEntry label scope) body : extra)