module GHC.Cmm.Switch.Implement
( cmmImplementSwitchPlans
)
where
import GHC.Prelude
import GHC.Driver.Backend
import GHC.Platform
import GHC.Cmm.Dataflow.Block
import GHC.Cmm.BlockId
import GHC.Cmm
import GHC.Cmm.Utils
import GHC.Cmm.Switch
import GHC.Types.Unique.Supply
import GHC.Utils.Monad (concatMapM)
-- | Rewrite every block of the graph that ends in a 'CmmSwitch' using the
-- 'SwitchPlan' machinery (see 'visitSwitches').
--
-- If the backend reports that it lowers 'CmmSwitch' itself
-- ('backendSupportsSwitch'), the graph is handed back unchanged.
cmmImplementSwitchPlans :: Backend -> Platform -> CmmGraph -> UniqSM CmmGraph
cmmImplementSwitchPlans backend platform g
  | backendSupportsSwitch backend = pure g
  | otherwise =
      -- Each block may expand into several; the entry label is unchanged.
      ofBlockList (g_entry g) <$> concatMapM (visitSwitches platform) (toBlockList g)
-- | Lower the 'CmmSwitch' terminating a single block, if there is one.
--
-- For a block ending in a switch, the result is the rewritten block (same
-- entry node and middle nodes, followed by the optional scrutinee assignment
-- and the new terminator) followed by any extra blocks created for the
-- comparison tree.  Any other block comes back as a singleton list.
visitSwitches :: Platform -> CmmBlock -> UniqSM [CmmBlock]
visitSwitches platform block =
  case blockSplit block of
    (entryNode@(CmmEntry _ tickScope), mids, CmmSwitch scrutinee targets) -> do
      -- The scrutinee is evaluated once into a register, because the lowered
      -- code may mention it in many comparisons (see 'floatSwitchExpr').
      (floated, simpleScrut) <- floatSwitchExpr platform scrutinee
      (tailBlock, extraBlocks) <-
        implementSwitchPlan platform tickScope simpleScrut (createSwitchPlan targets)
      let rebuilt =
            blockAppend (blockAppend (blockJoinHead entryNode mids) floated) tailBlock
      return (rebuilt : extraBlocks)
    _ -> return [block]
-- | Make a switch scrutinee cheap to mention repeatedly.
--
-- A register expression is returned as-is, with an empty block.  Any other
-- expression is assigned (via 'cmmMkAssign') to a fresh local register.  The
-- assignment node is returned as a one-node block, together with the
-- register expression to use in its place.  This avoids re-evaluating the
-- original expression in every comparison that 'implementSwitchPlan' emits.
floatSwitchExpr :: Platform -> CmmExpr -> UniqSM (Block CmmNode O O, CmmExpr)
floatSwitchExpr platform expr =
  case expr of
    CmmReg {} -> return (emptyBlock, expr)
    _ -> do
      uniq <- getUniqueM
      case cmmMkAssign platform expr uniq of
        (assignNode, regExpr) -> return (BMiddle assignNode, regExpr)
-- | Turn a 'SwitchPlan' into Cmm.
--
-- The result has two parts:
--
--   * the open-closed tail (a branch, a 'CmmSwitch' jump table, or a
--     conditional branch) that replaces the original switch terminator;
--   * the freshly labelled blocks created for the inner nodes of the
--     comparison tree.
--
-- All comparisons and literals are built at the width of the scrutinee.
-- The old code always compared at the native word width, which produced
-- ill-typed Cmm for sub-word scrutinees (a narrow operand compared against a
-- word-sized literal with a word-sized 'MachOp').  For word-sized scrutinees
-- the generated code is unchanged.
implementSwitchPlan
  :: Platform
  -> CmmTickScope  -- ^ tick scope attached to every newly created block
  -> CmmExpr       -- ^ scrutinee; mentioned in every comparison, so it
                   --   should be cheap (see 'floatSwitchExpr')
  -> SwitchPlan
  -> UniqSM (Block CmmNode O C, [CmmBlock])
implementSwitchPlan platform scope expr = go
  where
    -- Width every comparison and literal must agree with.
    width :: Width
    width = typeWidth (cmmExprType platform expr)

    -- A literal of the scrutinee's width.
    lit :: Integer -> CmmExpr
    lit i = CmmLit (CmmInt i width)

    -- Produce the tail for a plan node plus any extra blocks it needed.
    go :: SwitchPlan -> UniqSM (Block CmmNode O C, [CmmBlock])
    go (Unconditionally l)
      = return (emptyBlock `blockJoinTail` CmmBranch l, [])
    go (JumpTable ids)
      = return (emptyBlock `blockJoinTail` CmmSwitch expr ids, [])
    go (IfLT signed i ids1 ids2)
      = do
        (bid1, newBlocks1) <- go' ids1
        (bid2, newBlocks2) <- go' ids2
        let lt | signed    = MO_S_Lt
               | otherwise = MO_U_Lt
            -- expr < i  ==> bid1, otherwise bid2
            scrut = CmmMachOp (lt width) [expr, lit i]
            lastNode = CmmCondBranch scrut bid1 bid2 Nothing
            lastBlock = emptyBlock `blockJoinTail` lastNode
        return (lastBlock, newBlocks1 ++ newBlocks2)
    go (IfEqual i l ids2)
      = do
        (bid2, newBlocks2) <- go' ids2
        -- expr /= i  ==> continue with the rest of the plan, otherwise l
        let scrut = CmmMachOp (MO_Ne width) [expr, lit i]
            lastNode = CmmCondBranch scrut bid2 l Nothing
            lastBlock = emptyBlock `blockJoinTail` lastNode
        return (lastBlock, newBlocks2)

    -- Obtain a label to jump to for a sub-plan.  An unconditional plan just
    -- reuses its target; anything else gets a new block with a fresh label.
    go' :: SwitchPlan -> UniqSM (BlockId, [CmmBlock])
    go' (Unconditionally l)
      = return (l, [])
    go' p
      = do
        bid <- mkBlockId `fmap` getUniqueM
        (tailBlock, newBlocks) <- go p
        let block = CmmEntry bid scope `blockJoinHead` tailBlock
        return (bid, block : newBlocks)