module CFG
( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
, TransitionSource(..)
, addWeightEdge, addEdge
, delEdge, delNode
, addNodesBetween, shortcutWeightMap
, reverseEdges, filterEdges
, addImmediateSuccessor
, mkWeightInfo, adjustEdgeWeight, setEdgeWeight
, infoEdgeList, edgeList
, getSuccessorEdges, getSuccessors
, getSuccEdgesSorted
, getEdgeInfo
, getCfgNodes, hasNode
, loopMembers, loopLevels, loopInfo
--Construction/Misc
, getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg
, optimizeCFG
, mkGlobalWeights
)
where
#include "HsVersions.h"
import GhcPrelude
import BlockId
import Cmm
import CmmUtils
import CmmSwitch
import Hoopl.Collections
import Hoopl.Label
import Hoopl.Block
import qualified Hoopl.Graph as G
import Util
import Digraph
import Maybes
import Unique
import qualified Dominators as Dom
import Data.IntMap.Strict (IntMap)
import Data.IntSet (IntSet)
import qualified Data.IntMap.Strict as IM
import qualified Data.Map as M
import qualified Data.IntSet as IS
import qualified Data.Set as S
import Data.Tree
import Data.Bifunctor
import Outputable
import PprCmm ()
import qualified DynFlags as D
import Data.List (sort, nub, partition)
import Data.STRef.Strict
import Control.Monad.ST
import Data.Array.MArray
import Data.Array.ST
import Data.Array.IArray
import Data.Array.Unsafe (unsafeFreeze)
import Data.Array.Base (unsafeRead, unsafeWrite)
import Control.Monad
-- | A probability, expected to lie in the range [0, 1].
type Prob = Double

-- | A directed edge between two blocks, written @(from, to)@.
type Edge = (BlockId, BlockId)
type Edges = [Edge]

-- | The weight attached to a CFG edge.  A newtype around 'Double' so weights
-- are not confused with other numbers, while the derived numeric instances
-- still allow arithmetic on them.
newtype EdgeWeight
  = EdgeWeight { weightToDouble :: Double }
  deriving (Eq,Ord,Enum,Num,Real,Fractional)

-- | Weights are printed with five digits of precision.
instance Outputable EdgeWeight where
  ppr = doublePrec 5 . weightToDouble

-- | Two-level edge map: source block -> destination block -> edge payload.
type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo)

-- | A control-flow graph.  Blocks are the keys of the outer map; a block
-- without outgoing edges maps to an empty inner map.
type CFG = EdgeInfoMap EdgeInfo
-- | A single CFG edge together with its payload.
data CfgEdge
  = CfgEdge
  { edgeFrom :: !BlockId
  , edgeTo :: !BlockId
  , edgeInfo :: !EdgeInfo
  }

-- | Two edges are equal when they connect the same pair of blocks; the
-- payload (including the weight) is ignored.
instance Eq CfgEdge where
  CfgEdge srcA dstA _ == CfgEdge srcB dstB _ = srcA == srcB && dstA == dstB

-- | Edges are ordered by weight, then by source block, then by destination.
--
-- NOTE(review): unlike 'Eq', this ordering looks at the weight, so two edges
-- that are '==' can still compare as 'LT' or 'GT'.
instance Ord CfgEdge where
  compare (CfgEdge srcA dstA infoA) (CfgEdge srcB dstB infoB) =
    compare (edgeWeight infoA, srcA, dstA) (edgeWeight infoB, srcB, dstB)

instance Outputable CfgEdge where
  ppr (CfgEdge src dst info)
    = parens (ppr src <+> text "-(" <> ppr info <> text ")->" <+> ppr dst)
-- | Where the information about an edge came from.
data TransitionSource
  -- The edge was derived from this Cmm control-flow node.
  = CmmSource { trans_cmmNode :: (CmmNode O C)
              , trans_info :: BranchInfo }
  -- The edge was not derived from a Cmm node (e.g. built by 'mkWeightInfo').
  | AsmCodeGen
  deriving (Eq)

-- | Extra classification of a branch taken from Cmm.
data BranchInfo
  = NoInfo          -- nothing special is known about the branch
  | HeapStackCheck  -- the condition reads SpLim, HpLim or BaseReg
  deriving Eq

instance Outputable BranchInfo where
  ppr branchKind = text $ case branchKind of
    NoInfo -> "regular"
    HeapStackCheck -> "heap/stack"

-- | Was the edge derived from a Cmm branch classified as a heap/stack check?
isHeapOrStackCheck :: TransitionSource -> Bool
isHeapOrStackCheck source =
  case source of
    CmmSource { trans_info = HeapStackCheck } -> True
    _ -> False
-- | The payload stored on each CFG edge.
data EdgeInfo
  = EdgeInfo
  { transitionSource :: !TransitionSource
  , edgeWeight :: !EdgeWeight
  } deriving (Eq)

-- | Only the weight is printed; the transition source is omitted.
instance Outputable EdgeInfo where
  ppr info = text "weight:" <+> ppr (edgeWeight info)

-- | Edge info with the given weight and 'AsmCodeGen' as its source.
mkWeightInfo :: EdgeWeight -> EdgeInfo
mkWeightInfo weight = EdgeInfo AsmCodeGen weight
-- | Apply @f@ to the weight of the edge @from -> to@.
-- If that edge does not exist the CFG is returned unchanged.
adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
                 -> BlockId -> BlockId -> CFG
adjustEdgeWeight cfg f from to =
  case getEdgeInfo from to cfg of
    Nothing -> cfg
    Just info ->
      let !oldWeight = edgeWeight info
          !newWeight = f oldWeight
      in addEdge from to (info { edgeWeight = newWeight }) cfg
-- | Overwrite the weight of the edge @from -> to@; the new weight is forced.
-- If that edge does not exist the CFG is returned unchanged.
setEdgeWeight :: CFG -> EdgeWeight
              -> BlockId -> BlockId -> CFG
setEdgeWeight cfg !weight from to =
  maybe cfg
        (\info -> addEdge from to (info { edgeWeight = weight }) cfg)
        (getEdgeInfo from to cfg)
-- | Every block that is a key of the CFG, including blocks whose map of
-- outgoing edges is empty.
getCfgNodes :: CFG -> [BlockId]
getCfgNodes = mapKeys
-- | Is the block a node (a key) of the CFG?
-- The ASSERT additionally checks that a block which is not a key is not the
-- target of any edge either.
hasNode :: CFG -> BlockId -> Bool
hasNode m node =
  ASSERT( isKey || not (any (mapMember node) m))
  isKey
  where
    isKey = mapMember node m
-- | Check that the nodes of the CFG are exactly @blockSet@.
-- Returns 'True' when they match.  Otherwise it panics, printing the
-- symmetric difference, both sets and the caller-supplied @msg@.
sanityCheckCfg :: CFG -> LabelSet -> SDoc -> Bool
sanityCheckCfg m blockSet msg
  | cfgNodes == blockSet = True
  | otherwise =
      pprPanic "Block list and cfg nodes don't match" $
        text "difference:" <+> ppr diff $$
        text "blocks:" <+> ppr blockSet $$
        text "cfg:" <+> pprEdgeWeights m $$
        msg
  where
    cfgNodes = setFromList (getCfgNodes m) :: LabelSet
    -- Symmetric difference: everything in the union but not in both sets.
    diff = setDifference (setUnion cfgNodes blockSet)
                         (setIntersection cfgNodes blockSet) :: LabelSet
-- | Keep only the edges satisfying @f from to info@.
-- Every node is kept, even if all of its outgoing edges are removed.
filterEdges :: (BlockId -> BlockId -> EdgeInfo -> Bool) -> CFG -> CFG
filterEdges f = mapMapWithKey (\from -> mapFilterWithKey (f from))
-- | Update the CFG after blocks have been shortcut.
--
-- The entries of @cuts@ are applied one after another:
--
-- * @(b, Nothing)@: @b@ is removed together with every edge pointing to it.
-- * @(b, Just c)@: @b@ is removed and every edge into @b@ is redirected to
--   @c@, keeping its edge info.  If @c@ is itself a key of @cuts@, its
--   mapping is applied next, so chains @A -> B -> C@ are followed.
--
-- NOTE(review): a cycle in @cuts@ makes this recurse forever.
shortcutWeightMap :: LabelMap (Maybe BlockId) -> CFG -> CFG
shortcutWeightMap cuts cfg = foldl' applyMapping cfg (mapToList cuts)
  where
    applyMapping :: CFG -> (BlockId,Maybe BlockId) -> CFG
    applyMapping m (from, target) =
      case target of
        Nothing -> mapDelete from (fmap (mapDelete from) m)
        Just to ->
          let redirected = fmap (shortcutEdge (from,to)) (mapDelete from m) :: CFG
          in maybe redirected
                   (\next -> applyMapping redirected (to, next))
                   (mapLookup to cuts)

    -- Redirect an edge pointing at @from@, if present, to point at @to@.
    shortcutEdge :: (BlockId, BlockId) -> LabelMap EdgeInfo -> LabelMap EdgeInfo
    shortcutEdge (from, to) m =
      maybe m (\info -> mapInsert to info (mapDelete from m)) (mapLookup from m)
-- | Make @follower@ the only successor of @node@.
--
-- All outgoing edges of @node@ are moved to @follower@ (keeping their edge
-- info), and an edge @node -> follower@ with the unconditional-branch weight
-- from the global DynFlags is added.
addImmediateSuccessor :: BlockId -> BlockId -> CFG -> CFG
addImmediateSuccessor node follower cfg =
  let withLink = addWeightEdge node follower uncondWeight cfg
      -- Successors are those of the *original* CFG.
      withoutOld = foldl' (\m s -> delEdge node s m) withLink successors
  in foldl' (\m (t, info) -> addEdge follower t info m) withoutOld targets
  where
    uncondWeight = fromIntegral . D.uncondWeight .
                   D.cfgWeightInfo $ D.unsafeGlobalDynFlags
    targets = getSuccessorEdges cfg node
    successors = map fst targets :: [BlockId]
-- | Insert (or overwrite) the edge @from -> to@ with the given info.
-- Both endpoints become nodes of the CFG; an existing destination node keeps
-- its own outgoing edges.
addEdge :: BlockId -> BlockId -> EdgeInfo -> CFG -> CFG
addEdge from to info cfg =
  mapAlter (Just . maybe (mapSingleton to info) (mapInsert to info)) from
    (mapAlter (Just . fromMaybe mapEmpty) to cfg)
-- | Add the edge @from -> to@ with the given weight and 'AsmCodeGen' source.
addWeightEdge :: BlockId -> BlockId -> EdgeWeight -> CFG -> CFG
addWeightEdge from to weight = addEdge from to (mkWeightInfo weight)
-- | Remove the edge @from -> to@ if it exists.  No node is removed.
delEdge :: BlockId -> BlockId -> CFG -> CFG
delEdge from to = mapAlter (fmap (mapDelete to)) from
-- | Remove a node together with its outgoing edges and every edge into it.
delNode :: BlockId -> CFG -> CFG
delNode node = fmap (mapDelete node) . mapDelete node
-- | Outgoing edges of a block, heaviest first.
-- A block that is not in the CFG has no successors.
getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccEdgesSorted m bid =
  sortWith (negate . edgeWeight . snd) . mapToList $
    mapFindWithDefault mapEmpty bid m
-- | Outgoing edges of a block with their info.
-- Panics if the block is not a node of the CFG.
getSuccessorEdges :: HasDebugCallStack => CFG -> BlockId -> [(BlockId,EdgeInfo)]
getSuccessorEdges m bid =
  case mapLookup bid m of
    Just targets -> mapToList targets
    Nothing -> pprPanic "getSuccessorEdges: Block does not exist" $
                 ppr bid <+> pprEdgeWeights m
-- | The info of the edge @from -> to@, or 'Nothing' if there is no such edge.
-- The returned info is forced.
getEdgeInfo :: BlockId -> BlockId -> CFG -> Maybe EdgeInfo
getEdgeInfo from to m = do
  targets <- mapLookup from m
  info <- mapLookup to targets
  Just $! info
-- | Weight of the edge @from -> to@; fails if the edge does not exist.
getEdgeWeight :: CFG -> BlockId -> BlockId -> EdgeWeight
getEdgeWeight cfg from to =
  let info = expectJust "Edgeweight for noexisting block"
                        (getEdgeInfo from to cfg)
  in edgeWeight info
-- | Source of the edge @from -> to@; fails if the edge does not exist.
getTransitionSource :: BlockId -> BlockId -> CFG -> TransitionSource
getTransitionSource from to cfg =
  let info = expectJust "Source info for noexisting block"
                        (getEdgeInfo from to cfg)
  in transitionSource info
-- | Flip the direction of every edge, keeping its info.
-- Every node of the input is also a node of the result, even one that has
-- no incoming edges in the input.
reverseEdges :: CFG -> CFG
reverseEdges cfg = mapFoldlWithKey flipSource mapEmpty cfg
  where
    flipSource :: CFG -> BlockId -> LabelMap EdgeInfo -> CFG
    flipSource acc from targets =
      mapFoldlWithKey (\acc' to info -> addEdge to from info acc')
                      (mapInsertWith mapUnion from mapEmpty acc)
                      targets
-- | All edges of the CFG with their info.
-- Edges are prepended while walking the map in key order, so the result is
-- in reverse key order; callers that need a particular order should sort.
infoEdgeList :: CFG -> [CfgEdge]
infoEdgeList m = mapFoldlWithKey collect [] m
  where
    collect :: [CfgEdge] -> BlockId -> LabelMap EdgeInfo -> [CfgEdge]
    collect acc from targets =
      mapFoldlWithKey (\acc' to info -> CfgEdge from to info : acc') acc targets
-- | All edges of the CFG as @(from, to)@ pairs, in the same (reverse key)
-- order as 'infoEdgeList'.
edgeList :: CFG -> [Edge]
edgeList m = mapFoldlWithKey collect [] m
  where
    collect :: [Edge] -> BlockId -> LabelMap EdgeInfo -> [Edge]
    collect acc from targets =
      mapFoldlWithKey (\acc' to _ -> (from, to) : acc') acc targets
-- | Successor blocks of a block.
-- Panics if the block is not a node of the CFG.
getSuccessors :: HasDebugCallStack => CFG -> BlockId -> [BlockId]
getSuccessors m bid =
  maybe lookupError mapKeys (mapLookup bid m)
  where
    lookupError = pprPanic "getSuccessors: Block does not exist" $
                    ppr bid <+> pprEdgeWeights m
-- | Render the CFG in Graphviz @dot@ syntax.
-- Edges are sorted with the 'CfgEdge' 'Ord' instance and labelled with their
-- weight.  Nodes without outgoing edges that appear in no edge are listed
-- on their own so they still show up in the graph.
pprEdgeWeights :: CFG -> SDoc
pprEdgeWeights m =
    text "digraph {\n" <>
    foldl' (<>) empty (map printEdge edges) <>
    foldl' (<>) empty (map printNode isolated) <>
    text "}\n"
  where
    edges = sort (infoEdgeList m) :: [CfgEdge]

    printEdge (CfgEdge from to (EdgeInfo { edgeWeight = weight }))
      = text "\t" <> ppr from <+> text "->" <+> ppr to <>
        text "[label=\"" <> ppr weight <> text "\",weight=\"" <>
        ppr weight <> text "\"];\n"

    printNode node
      = text "\t" <> ppr node <> text ";\n"

    edgeNodes = setFromList [ n | CfgEdge from to _ <- edges, n <- [from, to] ] :: LabelSet

    isolated = [ n | n <- mapKeys (mapFilter null m), not (setMember n edgeNodes) ]
-- | Apply @f@ to the weight of an existing edge.
-- Panics if the edge is not in the CFG.
updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
updateEdgeWeight f (from, to) cfg =
  case getEdgeInfo from to cfg of
    Just oldInfo ->
      let !newWeight = f $! edgeWeight oldInfo
      in addEdge from to (oldInfo { edgeWeight = newWeight }) cfg
    Nothing -> panic "Trying to update invalid edge"
-- | Replace every edge weight @w@ of edge @from -> to@ by @f from to w@.
mapWeights :: (BlockId -> BlockId -> EdgeWeight -> EdgeWeight) -> CFG -> CFG
mapWeights f cfg = foldl' reweigh cfg (infoEdgeList cfg)
  where
    reweigh acc (CfgEdge from to info) =
      addEdge from to (info { edgeWeight = f from to (edgeWeight info) }) acc
-- | For each triple @(from, between, old)@, replace the edge @from -> old@
-- by the path @from -> between -> old@.
--
-- The new edge @from -> between@ keeps the info of the removed edge, and
-- @between -> old@ gets the unconditional-branch weight from the global
-- DynFlags.  Edge infos are looked up in the input CFG; it panics if some
-- edge @from -> old@ is missing there.
addNodesBetween :: CFG -> [(BlockId,BlockId,BlockId)] -> CFG
addNodesBetween m updates =
  foldl' updateWeight m (map getWeight updates)
  where
    weight = fromIntegral . D.uncondWeight .
             D.cfgWeightInfo $ D.unsafeGlobalDynFlags

    getWeight :: (BlockId,BlockId,BlockId) -> (BlockId,BlockId,BlockId,EdgeInfo)
    getWeight triple@(from,between,old) =
      case getEdgeInfo from old m of
        Just info -> (from,between,old,info)
        Nothing ->
          pprPanic "Can't find weight for edge that should have one" (
            text "triple" <+> ppr triple $$
            text "updates" <+> ppr updates $$
            text "cfg:" <+> pprEdgeWeights m )

    updateWeight :: CFG -> (BlockId,BlockId,BlockId,EdgeInfo) -> CFG
    updateWeight cfg (from,between,old,info) =
      addEdge from between info $
      addWeightEdge between old weight $
      delEdge from old cfg
-- | The CFG of a procedure; data declarations have an empty CFG.
getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
getCfgProc weights decl =
  case decl of
    CmmProc _info _lab _live graph -> getCfg weights graph
    CmmData {} -> mapEmpty
-- | Build the CFG of a Cmm graph using the given branch weights.
--
-- Every block of the graph (in reverse postorder) becomes a node, even if it
-- has no successors.  An edge is then added for each successor of the
-- block's last node:
--
-- * unconditional branch: 'D.uncondWeight';
-- * conditional branch: 'D.condBranchWeight' for both targets, or
--   'D.likelyCondWeight' / 'D.unlikelyCondWeight' when the branch carries a
--   likelihood.  Such edges are tagged 'HeapStackCheck' when the condition
--   reads SpLim, HpLim or BaseReg;
-- * switch: 'D.switchWeight' for every target, or -1 when the switch has
--   more than 10 targets, so that very wide switches are not considered for
--   layout;
-- * call / foreign call: 'D.callWeight' to the continuation; a call without
--   a continuation has no edges.
getCfg :: D.CfgWeights -> CmmGraph -> CFG
getCfg weights graph =
  foldl' insertEdge edgelessCfg $ concatMap getBlockEdges blocks
  where
    D.CFGWeights
      { D.uncondWeight = uncondWeight
      , D.condBranchWeight = condBranchWeight
      , D.switchWeight = switchWeight
      , D.callWeight = callWeight
      , D.likelyCondWeight = likelyCondWeight
      , D.unlikelyCondWeight = unlikelyCondWeight
      } = weights

    -- Add every block up front so blocks without successors are nodes too.
    edgelessCfg = mapFromList $ zip (map G.entryLabel blocks) (repeat mapEmpty)

    insertEdge :: CFG -> ((BlockId,BlockId),EdgeInfo) -> CFG
    insertEdge m ((from,to),info) =
      mapAlter (Just . maybe (mapSingleton to info) (mapInsert to info)) from m

    getBlockEdges :: CmmBlock -> [((BlockId,BlockId),EdgeInfo)]
    getBlockEdges block =
      case branch of
        CmmBranch dest -> [mkEdge dest uncondWeight]

        CmmCondBranch cond t f l ->
          case l of
            Nothing    -> [mkCondEdge f condBranchWeight,   mkCondEdge t condBranchWeight]
            Just True  -> [mkCondEdge f unlikelyCondWeight, mkCondEdge t likelyCondWeight]
            Just False -> [mkCondEdge f likelyCondWeight,   mkCondEdge t unlikelyCondWeight]
          where
            mkCondEdge target weight =
              ((bid,target), EdgeInfo (CmmSource branch branchInfo) (fromIntegral weight))
            branchInfo =
              foldRegsUsed
                (panic "foldRegsDynFlags")
                (\info r -> if r == SpLim || r == HpLim || r == BaseReg
                              then HeapStackCheck else info)
                NoInfo cond

        CmmSwitch _e ids ->
          let switchTargets = switchTargetsToList ids
              -- Compiler performance hack: for very wide switches, give the
              -- targets a negative weight so they are not considered for
              -- layout.
              adjustedWeight =
                if length switchTargets > 10 then -1 else switchWeight
          in map (\x -> mkEdge x adjustedWeight) switchTargets

        CmmCall { cml_cont = Just cont }   -> [mkEdge cont callWeight]
        CmmForeignCall { Cmm.succ = cont } -> [mkEdge cont callWeight]
        CmmCall { cml_cont = Nothing }     -> []

        -- NOTE(review): expected to be unreachable if every constructor of
        -- CmmNode O C is matched above; kept as a safety net.
        other ->
          pprPanic "getCfg: Unknown successor cause"
                   (ppr branch <+> text "=>" <> ppr (G.successors other))
      where
        bid = G.entryLabel block
        mkEdge target weight =
          ((bid,target), EdgeInfo (CmmSource branch NoInfo) (fromIntegral weight))
        branch = lastNode block :: CmmNode O C

    blocks = revPostorder graph :: [CmmBlock]
-- | The edges that 'classifyEdges' labels 'Backward' when the graph is
-- searched from @root@.
findBackEdges :: HasDebugCallStack => BlockId -> CFG -> Edges
findBackEdges root cfg =
  [ edge
  | (edge, Backward) <- classifyEdges root (getSuccessors cfg) (edgeList cfg)
  ]
-- | Adjust the edge weights of a procedure's CFG before block layout.
--
-- Three passes run in this order:
--
-- 1. @increaseBackEdgeWeight@: back edges found by 'findBackEdges' from the
--    graph entry get 'D.backEdgeBonus' added if their weight is positive;
--    back edges with weight <= 0 are set to 0.
-- 2. @penalizeInfoTables@: every edge whose target is a key of the
--    procedure's info-table map has 'D.infoTablePenalty' subtracted.
-- 3. @favourFewerPreds@: for a block with exactly two successors of equal
--    weight, the successor with fewer predecessors gets +1 and the other -1.
--    Predecessors are counted only over edges accepted by
--    @fallthroughTarget@.
--
-- Data declarations are returned unchanged.
optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
optimizeCFG _ (CmmData {}) cfg = cfg
optimizeCFG weights (CmmProc info _lab _live graph) cfg =
    favourFewerPreds .
    penalizeInfoTables info .
    increaseBackEdgeWeight (g_entry graph) $ cfg
  where
    increaseBackEdgeWeight :: BlockId -> CFG -> CFG
    increaseBackEdgeWeight root cfg0 =
      let backedges = findBackEdges root cfg0
          update weight
            -- Keep irrelevant (non-positive) edges irrelevant.
            | weight <= 0 = 0
            | otherwise
            = weight + fromIntegral (D.backEdgeBonus weights)
      in foldl' (\acc edge -> updateEdgeWeight update edge acc) cfg0 backedges

    penalizeInfoTables :: LabelMap a -> CFG -> CFG
    penalizeInfoTables infos cfg0 =
      mapWeights fupdate cfg0
      where
        fupdate :: BlockId -> BlockId -> EdgeWeight -> EdgeWeight
        fupdate _ to weight
          | mapMember to infos
          = weight - fromIntegral (D.infoTablePenalty weights)
          | otherwise = weight

    favourFewerPreds :: CFG -> CFG
    favourFewerPreds cfg0 =
      let
        revCfg =
          reverseEdges $ filterEdges (\_from -> fallthroughTarget) cfg0

        predCount n = length $ getSuccessorEdges revCfg n
        nodes = getCfgNodes cfg0

        -- Favour the successor with fewer predecessors: it gains what the
        -- other one loses.
        modifiers :: Int -> Int -> (EdgeWeight, EdgeWeight)
        modifiers preds1 preds2
          | preds1 <  preds2 = ( 1, -1)
          | preds1 == preds2 = ( 0,  0)
          | otherwise        = (-1,  1)

        update :: CFG -> BlockId -> CFG
        update acc node
          | [(s1,e1),(s2,e2)] <- getSuccessorEdges acc node
          , !w1 <- edgeWeight e1
          , !w2 <- edgeWeight e2
          -- Only change the weights if there is no ordering yet.
          , w1 == w2
          , (mod1,mod2) <- modifiers (predCount s1) (predCount s2)
          = adjustEdgeWeight (adjustEdgeWeight acc (+mod1) node s1) (+mod2) node s2
          | otherwise
          = acc
      in foldl' update cfg0 nodes
      where
        -- Edges that could become fall-throughs: not into a block with an
        -- info table, and created by the code generator or by a Cmm
        -- (conditional) branch.
        fallthroughTarget :: BlockId -> EdgeInfo -> Bool
        fallthroughTarget to (EdgeInfo source _weight)
          | mapMember to info = False
          | AsmCodeGen <- source = True
          | CmmSource { trans_cmmNode = CmmBranch {} } <- source = True
          | CmmSource { trans_cmmNode = CmmCondBranch {} } <- source = True
          | otherwise = False
-- | For every CFG node, whether it belongs to a cyclic strongly connected
-- component of the CFG.
loopMembers :: HasDebugCallStack => CFG -> LabelMap Bool
loopMembers cfg = foldl' markScc mapEmpty sccs
  where
    sccs = stronglyConnCompFromEdgedVerticesOrd
             [ DigraphNode bid bid (getSuccessors cfg bid) | bid <- getCfgNodes cfg ]

    markScc :: LabelMap Bool -> SCC BlockId -> LabelMap Bool
    markScc acc (AcyclicSCC bid) = mapInsert bid False acc
    markScc acc (CyclicSCC bids) = foldl' (\m k -> mapInsert k True m) acc bids
-- | The loop level of every CFG node, as computed by 'loopInfo'.
loopLevels :: CFG -> BlockId -> LabelMap Int
loopLevels cfg root = liLevels (loopInfo cfg root)
-- | Loop structure of a CFG, as produced by 'loopInfo'.
data LoopInfo = LoopInfo
  { liBackEdges :: [Edge]              -- back edges of the CFG
  , liLevels    :: LabelMap Int        -- loop level of each node
  , liLoops     :: [(Edge, LabelSet)]  -- (back edge, loop body incl. header)
  }

instance Outputable LoopInfo where
  ppr (LoopInfo { liLoops = loops }) =
    text "Loops:(backEdge, bodyNodes)" $$
    vcat (map ppr loops)
-- | Compute loop information for the CFG rooted at @root@.
--
-- * 'liBackEdges': the edges @(from, to)@ for which @to@ is in the dominator
--   set that 'mkDomMap' builds for @from@.
-- * 'liLoops': for every back edge, the set of blocks in its loop body
--   (the header included).
-- * 'liLevels': for every CFG node, the number of distinct loop headers
--   whose loop body contains that node.
loopInfo :: HasDebugCallStack => CFG -> BlockId -> LoopInfo
loopInfo cfg root = LoopInfo  { liBackEdges = backEdges
                              , liLevels = mapFromList loopCounts
                              , liLoops = loopBodies }
  where
    -- Predecessor graph, used to walk loop bodies backwards.
    revCfg = reverseEdges cfg

    -- Successor set of every node, the input shape expected by Dom.domTree.
    graph =
      fmap (setFromList . mapKeys ) cfg :: LabelMap LabelSet

    rooted = ( fromBlockId root
             , toIntMap $ fmap toIntSet graph) :: (Int, IntMap IntSet)
    tree = fmap toBlockId $ Dom.domTree rooted :: Tree BlockId

    -- Per-node dominator sets as built by mkDomMap from the dominator tree.
    domMap :: LabelMap LabelSet
    domMap = mkDomMap tree

    edges = edgeList cfg :: [(BlockId, BlockId)]
    -- Nodes are taken from the map keys, not from the edges, so blocks
    -- without any edges are included as well.
    nodes = getCfgNodes cfg :: [BlockId]

    -- An edge is a back edge if its target is in the source's dominator set.
    isBackEdge (from,to)
      | Just doms <- mapLookup from domMap
      , setMember to doms
      = True
      | otherwise = False

    -- Body of the loop closed by the back edge (tail -> head): the header
    -- plus every node from which tail is reached without going through the
    -- header.  Note that tail/head shadow the Prelude functions here.
    findBody edge@(tail, head)
      = ( edge, setInsert head $ go (setSingleton tail) (setSingleton tail) )
      where
        -- Reversed graph with the header removed, so the backwards search
        -- stops at the header.
        cfg' = delNode head revCfg

        -- Worklist search: @found@ is everything reached so far, @current@
        -- the nodes discovered in the previous step.
        go :: LabelSet -> LabelSet -> LabelSet
        go found current
          | setNull current = found
          | otherwise = go (setUnion newSuccessors found)
                           newSuccessors
          where
            newSuccessors = setFilter (\n -> not $ setMember n found) successors :: LabelSet
            -- "Successors" in the reversed graph, i.e. predecessors of the
            -- current frontier; the header is skipped since it is not a node
            -- of cfg'.
            successors = setFromList $ concatMap
                                         (getSuccessors cfg')
                                         (filter (/= head) $ setElems current) :: LabelSet

    backEdges = filter isBackEdge edges
    loopBodies = map findBody backEdges :: [(Edge, LabelSet)]

    -- Loop level of a node: the number of distinct loop headers whose body
    -- contains it.
    loopCounts =
      let bodies = map (first snd) loopBodies -- [(header, body)]
          loopCount n = length $ nub . map fst . filter (setMember n . snd) $ bodies
      in  map (\n -> (n, loopCount n)) $ nodes :: [(BlockId, Int)]
-- | Convert a label set into an 'IntSet' of the labels' unique keys.
toIntSet :: LabelSet -> IntSet
toIntSet = IS.fromList . map fromBlockId . setElems

-- | Convert a label map into an 'IntMap' keyed by the labels' unique keys.
toIntMap :: LabelMap a -> IntMap a
toIntMap = IM.fromList . map (first fromBlockId) . mapToList
-- | Map every node of a dominator tree to the set of its dominators.
--
-- A node's set contains all of its ancestors in the tree, the tree root
-- included, and the node itself.  Both matter for back-edge detection in
-- 'loopInfo': without the root, edges back to the entry block are not seen
-- as back edges; without the node itself, self-loops @n -> n@ are missed.
mkDomMap :: Tree BlockId -> LabelMap LabelSet
mkDomMap root = mapFromList $ go setEmpty root
  where
    -- @ancestors@ holds the labels of all strict ancestors of the node.
    go :: LabelSet -> Tree BlockId -> [(Label,LabelSet)]
    go ancestors (Node lbl children) =
      let doms = setInsert lbl ancestors
      in (lbl, doms) : concatMap (go doms) children
-- | The integer key of a block's unique (the vertex used with 'Dom.domTree').
fromBlockId :: BlockId -> Int
fromBlockId blockId = getKey (getUnique blockId)

-- | Rebuild a 'BlockId' from a key as produced by 'fromBlockId'.
toBlockId :: Int -> BlockId
toBlockId key = mkBlockId (mkUniqueGrimily key)
-- | A minimal Hoopl "block" holding only a label and its successors, so that
-- Hoopl graph traversals (see 'revPostorderFrom') can run on a 'CFG'.
newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])

instance G.NonLocal BlockNode where
  entryLabel (BN (label, _)) = label
  successors (BN (_, targets)) = targets
-- | Blocks of the CFG in the reverse postorder computed by Hoopl's
-- 'G.revPostorderFrom' starting at @root@ (presumably only blocks reachable
-- from @root@ -- confirm against Hoopl).
revPostorderFrom :: HasDebugCallStack => CFG -> BlockId -> [BlockId]
revPostorderFrom cfg root =
  [ bid | BN (bid, _) <- G.revPostorderFrom hooplGraph root ]
  where
    hooplGraph = mapFromList [ (n, toNode n) | n <- getCfgNodes cfg ]

    toNode :: BlockId -> BlockNode C C
    toNode bid = BN (bid, getSuccessors cfg bid)
-- | Estimate global block and edge frequencies for the CFG rooted at @root@.
--
-- Blocks are numbered by their position in reverse postorder.  The CFG is
-- first annotated with static branch predictions ('staticBranchPrediction'),
-- turned into edge probabilities ('cfgEdgeProbabilities') and then
-- propagated by 'calcFreqs'.  Returns the per-block frequencies and, per
-- source block, the frequencies of its outgoing edges.  Panics on an empty
-- CFG.
mkGlobalWeights :: HasDebugCallStack => BlockId -> CFG -> (LabelMap Double, LabelMap (LabelMap Double))
mkGlobalWeights root localCfg
  | null localCfg = panic "Error - Empty CFG"
  | otherwise
  = (blockFreqs', edgeFreqs')
  where
    (blockFreqs, edgeFreqs) = calcFreqs nodeProbs backEdges' bodies' revOrder'
    blockFreqs' = mapFromList $ map (first fromVertex) (assocs blockFreqs) :: LabelMap Double
    edgeFreqs' = fmap fromVertexMap $ fromVertexMap edgeFreqs

    fromVertexMap :: IM.IntMap x -> LabelMap x
    fromVertexMap m = mapFromList . map (first fromVertex) $ IM.toList m

    revOrder = revPostorderFrom localCfg root :: [BlockId]
    loopResults@(LoopInfo backedges _levels bodies) = loopInfo localCfg root

    revOrder' = map toVertex revOrder
    backEdges' = map (bimap toVertex toVertex) backedges
    bodies' = map calcBody bodies

    estimatedCfg = staticBranchPrediction root loopResults localCfg
    nodeProbs = cfgEdgeProbabilities estimatedCfg toVertex

    -- Vertices are numbered in reverse postorder, so sorting a set of
    -- vertices puts it into reverse postorder as well.
    calcBody (backedge, blocks) =
      (toVertex $ snd backedge, sort . map toVertex $ setElems blocks)

    vertexMapping = mapFromList $ zip revOrder [0..] :: LabelMap Int
    -- Vertices are 0 .. n-1, so the upper array bound is n - 1.
    blockMapping = listArray (0, mapSize vertexMapping - 1) revOrder :: Array Int BlockId

    toVertex :: BlockId -> Int
    toVertex blockId = expectJust "mkGlobalWeights" $ mapLookup blockId vertexMapping

    fromVertex :: Int -> BlockId
    fromVertex vertex = blockMapping ! vertex
-- | A successor of a block together with the info of the edge leading to it.
type TargetNodeInfo = (BlockId, EdgeInfo)

-- | Replace the outgoing edge weights of every node by static estimates of
-- the branch probabilities.
--
-- For each node:
--
-- * If some but not all of its successors are reached via back edges (and no
--   edge is a heap/stack check), the back edges share @pred_LBH@ equally and
--   the other edges share the remaining @1 - pred_LBH@.
-- * Otherwise, nodes with other than two successors, or with only back-edge
--   successors, keep their weights.
-- * A two-way branch without heap/stack checks gets its weights clamped at 0
--   and normalised to sum to 1 (0.5 each if both are 0).  Then every
--   heuristic that makes a prediction rescales the two weights in turn.
-- * Anything else keeps its weights.
staticBranchPrediction :: BlockId -> LoopInfo -> CFG -> CFG
staticBranchPrediction _root (LoopInfo l_backEdges loopLevels l_loops) cfg =
    foldl' update cfg nodes
  where
    nodes = getCfgNodes cfg
    backedges = S.fromList $ l_backEdges
    -- Loops keyed by their back edge.
    loops = M.fromList $ l_loops :: M.Map Edge LabelSet
    loopHeads = S.fromList $ map snd $ M.keys loops

    update :: CFG -> BlockId -> CFG
    update cfg node
      -- No successors, nothing to do.
      | null successors = cfg

      -- Mix of back edges and other edges: predict the back edges, unless a
      -- heap/stack check is involved.
      | not (null m) && length m < length successors
      , not $ any (isHeapOrStackCheck . transitionSource . snd) successors
      = let loopChance = repeat $! pred_LBH / fromIntegral (length m)
            exitChance = repeat $! (1 - pred_LBH) / fromIntegral (length not_m)
            updates = zip (map fst m) loopChance ++ zip (map fst not_m) exitChance
        in foldl' (\acc (to,weight) -> setEdgeWeight acc weight node to) cfg updates

      -- Non-binary branches keep the weights they were given.
      | length successors /= 2
      = cfg

      -- Only back edges: nothing to adjust.
      | not (null m)
      = cfg

      -- A regular binary branch: normalise, then apply the heuristics.
      | [(s1,s1_info),(s2,s2_info)] <- successors
      , not $ any (isHeapOrStackCheck . transitionSource . snd) successors
      =
        let !w1 = max (edgeWeight s1_info) 0
            !w2 = max (edgeWeight s2_info) 0
            -- If both weights are <= 0, split evenly.
            normalizeWeight w = if w1 + w2 == 0 then 0.5 else w/(w1+w2)
            !cfg'  = setEdgeWeight cfg  (normalizeWeight w1) node s1
            !cfg'' = setEdgeWeight cfg' (normalizeWeight w2) node s2

            heuristics = map ($ ((s1,s1_info),(s2,s2_info)))
                           [ lehPredicts, phPredicts, ohPredicts, ghPredicts
                           , lhhPredicts, chPredicts, shPredicts, rhPredicts ]

            -- Fold a heuristic's prediction (the probability of taking s1)
            -- into the current weights of s1 and s2.
            applyHeuristic :: CFG -> Maybe Prob -> CFG
            applyHeuristic acc Nothing = acc
            applyHeuristic acc (Just (s1_pred :: Double))
              | s1_old == 0 || s2_old == 0 ||
                isHeapOrStackCheck (transitionSource s1_info) ||
                isHeapOrStackCheck (transitionSource s2_info)
              = acc
              | otherwise =
                let s1_prob = EdgeWeight s1_pred :: EdgeWeight
                    s2_prob = 1.0 - s1_prob
                    d = (s1_old * s1_prob) + (s2_old * s2_prob) :: EdgeWeight
                    s1_prob' = s1_old * s1_prob / d
                    !s2_prob' = s2_old * s2_prob / d
                    !acc_s1 = setEdgeWeight acc s1_prob' node s1
                in setEdgeWeight acc_s1 s2_prob' node s2
              where
                s1_old = getEdgeWeight acc node s1
                s2_old = getEdgeWeight acc node s2
        in foldl' applyHeuristic cfg'' heuristics

      -- Branch involving a heap/stack check.
      | otherwise = cfg
      where
        -- Chance that a loop back edge is taken.
        pred_LBH :: EdgeWeight
        pred_LBH = 0.875

        successors = getSuccessorEdges cfg node
        -- Split successors into those reached via a back edge (m) and the rest.
        (m,not_m) = partition (\s -> S.member (node, fst s) backedges) successors

        -- Heuristics return Nothing when they have no opinion, or
        -- Just p where p is the probability that s1 is taken.

        -- Loop exit heuristic: unless a successor is a loop header, predict
        -- against the successor with the lower loop level.
        pred_LEH :: Prob
        pred_LEH = 0.75

        lehPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
        lehPredicts ((s1,_s1_info),(s2,_s2_info))
          | S.member s1 loopHeads || S.member s2 loopHeads
          = Nothing
          | otherwise
          = case compare s1Level s2Level of
              EQ -> Nothing
              LT -> Just (1 - pred_LEH)  -- s1 has the lower loop level
              GT -> Just pred_LEH        -- s1 has the higher loop level
          where
            s1Level = mapLookup s1 loopLevels
            s2Level = mapLookup s2 loopLevels

        -- Equality comparison against a literal is predicted to be false.
        ohPredicts (s1,_s2)
          | CmmSource { trans_cmmNode = src1 } <- getTransitionSource node (fst s1) cfg
          , CmmCondBranch cond ltrue _lfalse likely <- src1
          , likely == Nothing
          , CmmMachOp mop args <- cond
          , MO_Eq {} <- mop
          , not (null [x | x@CmmLit{} <- args])
          = if fst s1 == ltrue then Just 0.3 else Just 0.7
          | otherwise
          = Nothing

        -- Not implemented yet: these never make a prediction.
        phPredicts = const Nothing
        ghPredicts = const Nothing
        lhhPredicts = const Nothing
        chPredicts = const Nothing
        shPredicts = const Nothing
        rhPredicts = const Nothing
-- | Turn edge weights into branch probabilities, keyed by vertex number.
--
-- For each node, the weights of its outgoing edges are clamped at 0 and
-- normalised to sum to 1.  A node with at most one successor gives it
-- probability 1.0, and if all clamped weights are 0 the probability is split
-- evenly.
cfgEdgeProbabilities :: CFG -> (BlockId -> Int) -> IM.IntMap (IM.IntMap Prob)
cfgEdgeProbabilities cfg toVertex =
    mapFoldlWithKey addSource IM.empty cfg
  where
    addSource acc from targets = IM.insert (toVertex from) (normalize targets) acc

    normalize :: (LabelMap EdgeInfo) -> (IM.IntMap Prob)
    normalize weightMap = mapFoldlWithKey addTarget IM.empty weightMap
      where
        edgeCount = mapSize weightMap
        clamped = fmap (\info -> max (weightToDouble (edgeWeight info)) 0) weightMap
        totalWeight = sum clamped

        addTarget acc target _
          | edgeCount <= 1 = IM.insert (toVertex target) 1.0 acc
          | otherwise      = IM.insert (toVertex target) (normalWeight target) acc

        normalWeight :: BlockId -> Prob
        normalWeight bid
          | totalWeight == 0 = 1.0 / fromIntegral edgeCount
          | Just w <- mapLookup bid clamped = w / totalWeight
          | otherwise = panic "impossible"
-- | Propagate branch probabilities into global block and edge frequencies.
--
-- Arguments:
--
-- * @graph@: for every vertex, the probability of each outgoing edge;
-- * @backEdges@: the back edges of the graph;
-- * @loops@: @(header, body)@ pairs, the body including the header;
-- * @revPostOrder@: vertices in reverse postorder, the entry first.
--
-- Returns the frequency of each vertex (an array indexed by vertex) and, for
-- each edge, its frequency (source frequency times edge probability).
--
-- Each loop is processed first, in the given order, with only its body
-- marked unvisited.  This records the frequency flowing along edges back to
-- the loop header.  Then all vertices are processed once from the entry,
-- using those back-edge frequencies to scale blocks that are part of a
-- cycle.  The cycle probability is capped at @1 - 1/512@.
calcFreqs :: IM.IntMap (IM.IntMap Prob) -> [(Int,Int)] -> [(Int, [Int])] -> [Int]
          -> (Array Int Double, IM.IntMap (IM.IntMap Prob))
calcFreqs graph backEdges loops revPostOrder = runST $ do
    visitedNodes <- newArray (0, nodeCount - 1) False :: ST s (STUArray s Int Bool)
    blockFreqs <- newArray (0, nodeCount - 1) 0.0 :: ST s (STUArray s Int Double)
    edgeProbs <- newSTRef graph
    edgeBackProbs <- newSTRef graph

    let
      visited b = unsafeRead visitedNodes b
      getFreq b = unsafeRead blockFreqs b
      setFreq b f = unsafeWrite blockFreqs b f
      setVisited b = unsafeWrite visitedNodes b True

      -- Read / write the value stored for edge b1 -> b2 in a mutable edge map.
      getProb' arr b1 b2 = do
        g <- readSTRef arr
        return $ fromMaybe (error "getFreq 1") $
          IM.lookup b2 (fromMaybe (error "getFreq 2") (IM.lookup b1 g))
      setProb' arr b1 b2 prob = do
        g <- readSTRef arr
        let !m = fromMaybe (error "Foo") $ IM.lookup b1 g
            !m' = IM.insert b2 prob m
        writeSTRef arr $! IM.insert b1 m' g

      getEdgeFreq b1 b2 = getProb' edgeProbs b1 b2
      setEdgeFreq b1 b2 = setProb' edgeProbs b1 b2
      -- Static probability of edge b1 -> b2 from the input graph.
      getProb b1 b2 = fromMaybe (error "getProb") $ do
        m' <- IM.lookup b1 graph
        IM.lookup b2 m'

      getBackProb b1 b2 = getProb' edgeBackProbs b1 b2
      setBackProb b1 b2 = setProb' edgeBackProbs b1 b2

    let
      -- Set the frequency of every outgoing edge of @block@ from the block's
      -- frequency; an edge back to @bhead@ is also recorded as a back-edge
      -- frequency.
      calcOutFreqs bhead block = do
        !f <- getFreq block
        forM_ (successors block) $ \bi -> do
          let !prob = getProb block bi
          let !succFreq = f * prob
          setEdgeFreq block bi succFreq
          when (bi == bhead) $ setBackProb block bi succFreq

    let
      -- Compute the frequency of @block@ relative to the loop header @hdr@.
      -- Already visited blocks and irreducible entries keep their frequency,
      -- but their outgoing edge frequencies are still (re)computed.
      propFreq block hdr = do
        !v <- visited block
        if v
          then return ()
          else if block == hdr
            -- The loop header has frequency 1 relative to itself.
            then setFreq block 1.0
            else do
              let preds = IS.elems $ predecessors block
              -- Irreducible if some forward predecessor is not yet visited.
              irreducible <- fmap or $ forM preds $ \bp -> do
                !bp_visited <- visited bp
                let bp_backedge = isBackEdge bp block
                return (not bp_visited && not bp_backedge)
              if irreducible
                then return ()
                else do
                  setFreq block 0
                  -- Sum forward-edge frequencies into the block and collect
                  -- the probability of coming back along back edges.
                  !cycleProb <- sum <$> forM preds (\p ->
                    if isBackEdge p block
                      then getBackProb p block
                      else do
                        !f <- getFreq block
                        !prob <- getEdgeFreq p block
                        setFreq block $! f + prob
                        return 0)
                  let limit = 1 - 1/512
                      !cappedCycleProb = min cycleProb limit
                  !f <- getFreq block
                  setFreq block (f / (1.0 - cappedCycleProb))
        setVisited block
        calcOutFreqs hdr block

    -- Process each loop with only its body unvisited.
    forM_ loops $ \(hdr, body) -> do
      forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i True)
      forM_ body (\i -> unsafeWrite visitedNodes i False)
      forM_ body $ \block -> propFreq block hdr

    forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i False)

    -- Then the whole graph, relative to the entry (the first vertex).
    case revPostOrder of
      [] -> return ()
      (entry:_) -> forM_ revPostOrder $ \block -> propFreq block entry

    graph' <- readSTRef edgeProbs
    freqs' <- unsafeFreeze blockFreqs

    return (freqs', graph')
  where
    predecessors :: Int -> IS.IntSet
    predecessors b = fromMaybe IS.empty $ IM.lookup b revGraph

    successors :: Int -> [Int]
    successors b = fromMaybe (lookupError "succ" b graph) $ IM.keys <$> IM.lookup b graph

    lookupError s b g = pprPanic ("Lookup error " ++ s) $
                          ( text "node" <+> ppr b $$
                            text "graph" <+>
                            vcat (map (\(k,m) -> ppr (k,m :: IM.IntMap Double)) $ IM.toList g)
                          )

    -- Number of array slots: all sources plus one per edge to a vertex that
    -- is not a source.
    nodeCount = IM.foldl' (\count toMap -> IM.foldlWithKey' countTargets count toMap) (IM.size graph) graph
      where
        countTargets = \count k _ -> countNode k + count
        countNode n = if IM.member n graph then 0 else 1

    isBackEdge from to = S.member (from,to) backEdgeSet
    backEdgeSet = S.fromList backEdges

    -- Predecessor sets, built by reversing every edge of @graph@.
    revGraph :: IntMap IntSet
    revGraph = IM.foldlWithKey' (\m from toMap -> addEdges m from toMap) IM.empty graph
      where
        addEdges m0 from toMap = IM.foldlWithKey' (\m k _ -> addEdge m from k) m0 toMap
        addEdge m0 from to = IM.insertWith IS.union to (IS.singleton from) m0