module GHC.CmmToAsm.Reg.Graph.SpillCost (
SpillCostRecord,
plusSpillCostRecord,
pprSpillCostRecord,
SpillCostInfo,
zeroSpillCostInfo,
plusSpillCostInfo,
slurpSpillCostInfo,
chooseSpill,
lifeMapFromSpillCostInfo
) where
import GHC.Prelude
import GHC.CmmToAsm.Reg.Liveness
import GHC.CmmToAsm.Instr
import GHC.Platform.Reg.Class
import GHC.Platform.Reg
import GHC.Data.Graph.Base
import GHC.Cmm.Dataflow.Collections (mapLookup)
import GHC.Cmm.Dataflow.Label
import GHC.Cmm
import GHC.Types.Unique.FM
import GHC.Types.Unique.Set
import GHC.Data.Graph.Directed (flattenSCCs)
import GHC.Utils.Outputable
import GHC.Platform
import GHC.Utils.Monad.State
import GHC.CmmToAsm.CFG
import Data.List (nub, minimumBy)
import Data.Maybe
import Control.Monad (join)
-- | Spill statistics gathered for a single virtual register.
--
--   The three counters are filled in by the @incDefs@, @incUses@ and
--   @incLifetime@ helpers of 'slurpSpillCostInfo' respectively.
type SpillCostRecord =
  ( VirtualReg -- the register this record describes
  , Int        -- writes to the register, each weighted by its block's scale
  , Int        -- reads from the register, each weighted by its block's scale
  , Int        -- number of instructions the register was live on entry to
  )
-- | A collection of 'SpillCostRecord's, keyed by the virtual register
--   each record describes.
type SpillCostInfo = UniqFM VirtualReg SpillCostRecord
-- | A state action that threads a 'SpillCostInfo' accumulator and yields
--   no result of its own.
type SpillCostState = State SpillCostInfo ()
-- | A 'SpillCostInfo' holding no records at all (the empty map).
zeroSpillCostInfo :: SpillCostInfo
zeroSpillCostInfo = emptyUFM
-- | Merge two 'SpillCostInfo' maps. A register that appears in both maps
--   has its two records combined with 'plusSpillCostRecord'.
plusSpillCostInfo :: SpillCostInfo -> SpillCostInfo -> SpillCostInfo
plusSpillCostInfo = plusUFM_C plusSpillCostRecord
-- | Add two records for the same register component-wise.
--
--   Calls 'error' if the two records describe different registers.
--   (The message text mentions @plusRegInt@, which is not the name of this
--   function; it is reproduced unchanged.)
plusSpillCostRecord :: SpillCostRecord -> SpillCostRecord -> SpillCostRecord
plusSpillCostRecord (regL, defsL, usesL, lifeL) (regR, defsR, usesR, lifeR) =
  if regL == regR
    then (regL, defsL + defsR, usesL + usesR, lifeL + lifeR)
    else error "RegSpillCost.plusRegInt: regs don't match"
-- | Walk a liveness-annotated top-level declaration and build a
--   'SpillCostRecord' for every virtual register it touches.
--
--   For each instruction that carries liveness information:
--
--   * every virtual register live on entry to the instruction has its
--     lifetime counter bumped by one;
--   * every distinct virtual register read or written has its use or def
--     counter bumped by the block's scale.
--
--   The scale is derived from the block frequencies computed from the
--   optional 'CFG'; without a CFG every block has scale 1.
--
--   Panics on a block with no live-on-entry information, and on a
--   non-meta instruction with no liveness information.
slurpSpillCostInfo :: forall instr statics. (Outputable instr, Instruction instr)
                   => Platform
                   -> Maybe CFG
                   -> LiveCmmDecl statics instr
                   -> SpillCostInfo
slurpSpillCostInfo platform cfg cmm
  = execState (countCmm cmm) zeroSpillCostInfo
  where
    -- Data sections contribute nothing to the counts.
    countCmm CmmData{} = return ()
    countCmm (CmmProc info _ _ sccs)
      = mapM_ (countBlock info freqMap) (flattenSCCs sccs)
      where
        LiveInfo _ entries _ _ = info
        -- 'head entries' is partial, but it is only forced when a CFG
        -- was actually supplied.
        freqMap = fmap (fst . mkGlobalWeights (head entries)) cfg

    -- Count one basic block, starting from the virtual registers that are
    -- live on entry to it.
    countBlock info freqMap (BasicBlock blockId instrs)
      = case mapLookup blockId blockLive of
          Just rsLiveEntry ->
            countLIs (ceiling (blockFreq freqMap blockId))
                     (takeVirtuals rsLiveEntry)
                     instrs
          Nothing ->
            error "RegAlloc.SpillCost.slurpSpillCostInfo: bad block"
      where
        LiveInfo _ _ blockLive _ = info

    -- Count a run of instructions, threading the set of virtual registers
    -- live on entry to the next instruction.
    countLIs :: Int -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState
    countLIs _ _ [] = return ()

    -- Instructions without liveness information must be meta instructions;
    -- those are skipped and leave the live set untouched.
    countLIs scale rsLive (LiveInstr instr Nothing : rest)
      | isMetaInstr instr
      = countLIs scale rsLive rest
      | otherwise
      = pprPanic "RegSpillCost.slurpSpillCostInfo"
          (text "no liveness information on instruction " <> ppr instr)

    countLIs scale liveOnEntry (LiveInstr instr (Just live) : rest) = do
      -- Everything live on entry to this instruction has lived one more step.
      mapM_ incLifetime (nonDetEltsUniqSet liveOnEntry)

      -- Count each distinct virtual register read or written once.
      let RU readRegs writtenRegs = regUsageOfInstr platform instr
      mapM_ (incUses scale) (mapMaybe takeVirtualReg (nub readRegs))
      mapM_ (incDefs scale) (mapMaybe takeVirtualReg (nub writtenRegs))

      -- Work out which virtual registers are live on entry to the next
      -- instruction.
      let diedOnRead  = takeVirtuals (liveDieRead live)
          diedOnWrite = takeVirtuals (liveDieWrite live)
          born        = takeVirtuals (liveBorn live)
          liveAcross  = liveOnEntry `minusUniqSet` diedOnRead
          liveNext    = (liveAcross `unionUniqSets` born)
                          `minusUniqSet` diedOnWrite

      countLIs scale liveNext rest

    -- Each helper adds a one-field increment to the register's record.
    incDefs count reg = bump reg (reg, count, 0, 0)
    incUses count reg = bump reg (reg, 0, count, 0)
    incLifetime reg   = bump reg (reg, 0, 0, 1)

    -- Merge a record into the accumulated state under the given register.
    bump reg record
      = modify $ \s -> addToUFM_C plusSpillCostRecord s reg record

    -- Weight of a block: its frequency times 10000, but never below 1.0.
    -- Falls back to 1.0 when there is no frequency map or the block is
    -- missing from it.
    blockFreq :: Maybe (LabelMap Double) -> Label -> Double
    blockFreq freqs bid
      = case join (mapLookup bid <$> freqs) of
          Just freq -> max 1.0 (10000 * freq)
          Nothing   -> 1.0
-- | Keep only the virtual registers of a register set, dropping every
--   other kind of register.
takeVirtuals :: UniqSet Reg -> UniqSet VirtualReg
takeVirtuals set = mkUniqSet (concatMap virtualOnly (nonDetEltsUniqSet set))
  where
    virtualOnly (RegVirtual vr) = [vr]
    virtualOnly _               = []
-- | Pick the register to spill: the graph node whose 'spillCost_length'
--   is smallest.
--
--   Uses 'minimumBy' over the graph's nodes, so it fails if the graph has
--   no nodes.
chooseSpill
  :: SpillCostInfo
  -> Graph VirtualReg RegClass RealReg
  -> VirtualReg
chooseSpill info graph = nodeId cheapest
  where
    costOf   = spillCost_length info graph . nodeId
    cheapest = minimumBy (\a b -> compare (costOf a) (costOf b))
                         (nonDetEltsUFM (graphMap graph))
-- | Spill cost based purely on recorded lifetime: the reciprocal of the
--   register's lifetime, so longer-lived registers are cheaper to spill.
--
--   A register with no record, or with a lifetime of at most one, gets an
--   infinite cost. The graph argument is ignored.
spillCost_length
  :: SpillCostInfo
  -> Graph VirtualReg RegClass RealReg
  -> VirtualReg
  -> Float
spillCost_length info _ reg =
  case lookupUFM info reg of
    Just (_, _, _, lifetime)
      | lifetime > 1 -> 1 / fromIntegral lifetime
    _                -> 1 / 0
-- | Project a 'SpillCostInfo' down to a map from each register to the
--   pair of that register and its recorded lifetime.
lifeMapFromSpillCostInfo :: SpillCostInfo -> UniqFM VirtualReg (VirtualReg, Int)
lifeMapFromSpillCostInfo info =
  listToUFM [ (r, (r, life)) | (r, _, _, life) <- nonDetEltsUFM info ]
-- | Degree of a register's node in the conflict graph: the number of
--   conflicting registers of the same class, plus the size of the node's
--   exclusion set.
--
--   Returns 0 for a register that has no node in the graph.
nodeDegree
  :: (VirtualReg -> RegClass)
  -> Graph VirtualReg RegClass RealReg
  -> VirtualReg
  -> Int
nodeDegree classOfVirtualReg graph reg =
  case lookupUFM (graphMap graph) reg of
    Nothing   -> 0
    Just node ->
      let sameClass other = classOfVirtualReg other == classOfVirtualReg reg
          virtConflicts   = length
                              (filter sameClass
                                 (nonDetEltsUniqSet (nodeConflicts node)))
      in  virtConflicts + sizeUniqSet (nodeExclusions node)
-- | Render a 'SpillCostRecord' as one space-separated line.
--
--   The columns are: the register, the second and third record fields
--   (which 'slurpSpillCostInfo' fills with the write and read counts, in
--   that order), the lifetime, the node degree, and the ratio of
--   (writes + reads) to the degree.
--
--   The degree is not checked for zero; in that case the 'Float' division
--   shows up as @Infinity@ or @NaN@ in the output.
pprSpillCostRecord
  :: (VirtualReg -> RegClass)
  -> (Reg -> SDoc)
  -> Graph VirtualReg RegClass RealReg
  -> SpillCostRecord
  -> SDoc
pprSpillCostRecord regClass pprReg graph (reg, nWrites, nReads, life) =
  hsep
    [ pprReg (RegVirtual reg)
    , ppr nWrites
    , ppr nReads
    , ppr life
    , ppr degree
    , text (show ratio)
    ]
  where
    degree = nodeDegree regClass graph reg

    ratio :: Float
    ratio = fromIntegral (nWrites + nReads) / fromIntegral degree