{-# LANGUAGE GADTs #-}
module GHC.Cmm.Switch.Implement
( cmmImplementSwitchPlans
)
where
import GHC.Prelude
import GHC.Platform
import GHC.Cmm.Dataflow.Block
import GHC.Cmm.BlockId
import GHC.Cmm
import GHC.Cmm.Utils
import GHC.Cmm.Switch
import GHC.Types.Unique.Supply
import GHC.Driver.Session
import GHC.Utils.Monad (concatMapM)
-- | Replace every 'CmmSwitch' in the graph by code implementing the
-- 'SwitchPlan' that 'createSwitchPlan' chooses for it.
--
-- If 'targetSupportsSwitch' holds for the configured 'hscTarget', the graph is
-- returned unchanged (presumably that backend lowers switches itself).
-- Otherwise every block is passed through 'visitSwitches', which may expand a
-- single block into several; the resulting blocks are reassembled into a graph
-- that keeps the original entry label.
cmmImplementSwitchPlans :: DynFlags -> CmmGraph -> UniqSM CmmGraph
cmmImplementSwitchPlans dflags g
  | targetSupportsSwitch (hscTarget dflags) = return g
  | otherwise = do
      blocks' <- concatMapM (visitSwitches (targetPlatform dflags)) (toBlockList g)
      return $ ofBlockList (g_entry g) blocks'
-- | Lower the switch terminating a single block, if there is one.
--
-- A block of the shape @entry; middle; CmmSwitch expr targets@ is rewritten as
--
--   * the original entry node and middle nodes,
--   * followed by the (possibly empty) assignment produced by
--     'floatSwitchExpr', which replaces a non-register scrutinee by a value
--     bound with 'cmmMkAssign',
--   * followed by the tail produced by 'implementSwitchPlan' for the plan
--     returned by 'createSwitchPlan'.
--
-- The result lists the rewritten block first, then any freshly created blocks
-- the plan needed.  Blocks that do not end in a 'CmmSwitch' are returned
-- unchanged as a singleton list.
visitSwitches :: Platform -> CmmBlock -> UniqSM [CmmBlock]
visitSwitches platform block
  | (entry@(CmmEntry _ scope), middle, CmmSwitch vanillaExpr ids) <- blockSplit block
  = do
      let plan = createSwitchPlan ids
      -- Float the scrutinee first so the comparisons emitted for the plan
      -- refer to 'simpleExpr' rather than the original expression.
      (assignSimple, simpleExpr) <- floatSwitchExpr platform vanillaExpr
      (newTail, newBlocks) <- implementSwitchPlan platform scope simpleExpr plan
      -- Grouping is explicit so correctness does not hinge on operator fixity:
      -- head+middle (C O), then the assignment (O O), then the new tail (O C).
      let block' = ((entry `blockJoinHead` middle)
                      `blockAppend` assignSimple)
                      `blockAppend` newTail
      return (block' : newBlocks)
  | otherwise
  = return [block]
-- | Prepare a switch scrutinee for repeated use.
--
-- A 'CmmReg' expression is returned as-is, paired with an empty block.  Any
-- other expression is bound with 'cmmMkAssign' under a freshly drawn unique;
-- the resulting assignment node is returned as a one-node middle block, and
-- the second component is the expression callers should use in place of the
-- original.
floatSwitchExpr :: Platform -> CmmExpr -> UniqSM (Block CmmNode O O, CmmExpr)
floatSwitchExpr _ reg@(CmmReg {}) = return (emptyBlock, reg)
floatSwitchExpr platform expr = do
  (assign, expr') <- cmmMkAssign platform expr <$> getUniqueM
  return (BMiddle assign, expr')
-- | Emit Cmm code for a 'SwitchPlan' that scrutinises @expr@.
--
-- Returns the open-closed tail that should terminate the block holding the
-- switch, together with any extra blocks that had to be created for nested
-- sub-plans.  Every newly created block is entered with tick scope @scope@.
--
--   * 'Unconditionally' @l@: a plain branch to @l@.
--   * 'JumpTable' @ids@: keeps a 'CmmSwitch' on @expr@ with those targets.
--   * 'IfLT' @signed i p1 p2@: branches to @p1@ when @expr < i@ (signed or
--     unsigned word comparison, per @signed@), and to @p2@ otherwise.
--   * 'IfEqual' @i l p@: tests @expr /= i@; when unequal control continues in
--     @p@, when equal it goes directly to label @l@.
implementSwitchPlan
  :: Platform
  -> CmmTickScope
  -> CmmExpr
  -> SwitchPlan
  -> UniqSM (Block CmmNode O C, [CmmBlock])
implementSwitchPlan platform scope expr = go
  where
    -- Produce only a tail; the caller splices it onto an existing block.
    go :: SwitchPlan -> UniqSM (Block CmmNode O C, [CmmBlock])
    go (Unconditionally l) = return (closeWith (CmmBranch l), [])
    go (JumpTable ids) = return (closeWith (CmmSwitch expr ids), [])
    go (IfLT signed i ids1 ids2) = do
      (bid1, newBlocks1) <- go' ids1
      (bid2, newBlocks2) <- go' ids2
      let lt | signed    = cmmSLtWord
             | otherwise = cmmULtWord
          scrut = lt platform expr (wordLit i)
      return ( closeWith (CmmCondBranch scrut bid1 bid2 Nothing)
             , newBlocks1 ++ newBlocks2 )
    go (IfEqual i l ids2) = do
      (bid2, newBlocks2) <- go' ids2
      -- True branch is the "not equal" case, so it continues with the sub-plan.
      let scrut = cmmNeWord platform expr (wordLit i)
      return (closeWith (CmmCondBranch scrut bid2 l Nothing), newBlocks2)

    -- Like 'go', but yields a label to branch to.  An unconditional plan
    -- already names its target, so no new block is needed; any other plan is
    -- wrapped in a fresh block labelled with a new 'BlockId'.
    go' :: SwitchPlan -> UniqSM (BlockId, [CmmBlock])
    go' (Unconditionally l) = return (l, [])
    go' p = do
      bid <- mkBlockId <$> getUniqueM
      (tailBlock, newBlocks) <- go p
      let block = CmmEntry bid scope `blockJoinHead` tailBlock
      return (bid, block : newBlocks)

    -- A tail consisting solely of the given last node.
    closeWith :: CmmNode O C -> Block CmmNode O C
    closeWith node = emptyBlock `blockJoinTail` node

    -- A word-sized literal for the target platform.
    wordLit :: Integer -> CmmExpr
    wordLit = CmmLit . mkWordCLit platform