module Hoopl.Dataflow
( C, O, Block
, lastNode, entryLabel
, foldNodesBwdOO
, foldRewriteNodesBwdOO
, DataflowLattice(..), OldFact(..), NewFact(..), JoinedFact(..)
, TransferFun, RewriteFun
, Fact, FactBase
, getFact, mkFactBase
, analyzeCmmFwd, analyzeCmmBwd
, rewriteCmmBwd
, changedIf
, joinOutFacts
, joinFacts
)
where
import GhcPrelude
import Cmm
import UniqSupply
import Data.Array
import Data.Maybe
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Hoopl.Block
import Hoopl.Graph
import Hoopl.Collections
import Hoopl.Label
-- | The fact type attached to a position of the given shape: a closed ('C')
-- position carries a whole 'FactBase', an open ('O') position carries a
-- single fact.
type family Fact x f :: *
type instance Fact C f = FactBase f
type instance Fact O f = f
-- | Tags the first argument of a 'JoinFun'; in 'updateFact' this is the fact
-- already stored in the 'FactBase' for a label.
newtype OldFact a = OldFact a
-- | Tags the second argument of a 'JoinFun'; in 'updateFact' this is the
-- freshly computed fact being merged in.
newtype NewFact a = NewFact a
-- | Outcome of a join: the resulting fact, tagged with whether the join
-- reported a change. 'updateFact' discards a 'NotChanged' result.
data JoinedFact a = Changed !a | NotChanged !a

-- | Extract the fact from a 'JoinedFact', ignoring the change tag.
getJoined :: JoinedFact a -> a
getJoined joined = case joined of
    Changed fact    -> fact
    NotChanged fact -> fact

-- | Tag a fact as 'Changed' when the flag is 'True', else 'NotChanged'.
changedIf :: Bool -> a -> JoinedFact a
changedIf changed = if changed then Changed else NotChanged
-- | Lattice join: combine an old and a new fact, reporting via 'JoinedFact'
-- whether the result counts as a change.
type JoinFun a = OldFact a -> NewFact a -> JoinedFact a
-- | A dataflow lattice: a bottom element and a join operation.
data DataflowLattice a = DataflowLattice
    { fact_bot :: a
      -- ^ Fact used when a label has none ('getFact'), and the starting
      -- value of 'joinFacts' and 'joinOutFacts'.
    , fact_join :: JoinFun a
      -- ^ Join of an old and a new fact.
    }
-- | Analysis direction. 'sortBlocks' visits blocks in 'revPostorderFrom'
-- order for 'Fwd' and in the reverse of that order for 'Bwd'.
data Direction = Fwd | Bwd
-- | Transfer function: given a block and the current fact base, produce the
-- facts the block contributes, keyed by label.
type TransferFun f = CmmBlock -> FactBase f -> FactBase f
-- | Rewriting transfer function: like 'TransferFun', but also returns a
-- (possibly new) block, and runs in 'UniqSM'.
type RewriteFun f = CmmBlock -> FactBase f -> UniqSM (CmmBlock, FactBase f)
-- | Backward analysis of a 'CmmGraph', iterated to a fixpoint from the
-- supplied initial facts.
analyzeCmmBwd
    :: DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmmBwd = analyzeCmm Bwd

-- | Forward analysis of a 'CmmGraph', iterated to a fixpoint from the
-- supplied initial facts.
analyzeCmmFwd
    :: DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmmFwd = analyzeCmm Fwd
-- | Run the worklist analysis over a 'CmmGraph' in the given direction,
-- starting from @initFact@.
analyzeCmm
    :: Direction
    -> DataflowLattice f
    -> TransferFun f
    -> CmmGraph
    -> FactBase f
    -> FactBase f
analyzeCmm dir lattice transfer cmmGraph initFact =
    fixpointAnalysis dir lattice transfer (g_entry cmmGraph) body initFact
  where
    -- The graph's block map, taken from its 'GMany' form.
    body = case g_graph cmmGraph of
        GMany NothingO blocks NothingO -> blocks
-- | Apply a transfer function to the blocks of a graph until the fact base
-- stops changing (worklist fixpoint).
--
-- Blocks are numbered by their position in 'sortBlocks' order. The worklist
-- is an 'IntSet' of those indices, always drained smallest index first.
-- Whenever 'updateFact' stores a new or changed fact for a label, the blocks
-- that 'mkDepBlocks' recorded for that label are put back on the worklist.
fixpointAnalysis
    :: forall f.
       Direction
    -> DataflowLattice f
    -> TransferFun f
    -> Label
    -> LabelMap CmmBlock
    -> FactBase f
    -> FactBase f
fixpointAnalysis direction lattice do_block entry blockmap = loop start
  where
    blocks = sortBlocks direction entry blockmap
    num_blocks = length blocks
    -- Array bounds are inclusive, so the last valid index is num_blocks - 1.
    -- With no blocks this gives the empty range (0, -1).
    block_arr = listArray (0, num_blocks - 1) blocks
    -- Initially every block is on the worklist.
    start = IntSet.fromDistinctAscList [0 .. num_blocks - 1]
    dep_blocks = mkDepBlocks direction blocks
    join = fact_join lattice

    loop
        :: IntHeap      -- indices of blocks still to process
        -> FactBase f   -- facts computed so far
        -> FactBase f
    loop todo !fbase1
        | Just (index, todo1) <- IntSet.minView todo =
            let block = block_arr ! index
                out_facts = do_block block fbase1
                (todo2, fbase2) =
                    mapFoldlWithKey
                        (updateFact join dep_blocks) (todo1, fbase1) out_facts
            in loop todo2 fbase2
    -- Worklist empty: the facts are final.
    loop _ !fbase1 = fbase1
-- | Backward rewriting pass: iterates the 'RewriteFun' to a fixpoint and
-- returns the rewritten graph together with the final facts.
rewriteCmmBwd
    :: DataflowLattice f
    -> RewriteFun f
    -> CmmGraph
    -> FactBase f
    -> UniqSM (CmmGraph, FactBase f)
rewriteCmmBwd lattice rwFun = rewriteCmm Bwd lattice rwFun
-- | Run the rewriting fixpoint over a 'CmmGraph' in the given direction and
-- rebuild the graph from the rewritten block map.
rewriteCmm
    :: Direction
    -> DataflowLattice f
    -> RewriteFun f
    -> CmmGraph
    -> FactBase f
    -> UniqSM (CmmGraph, FactBase f)
rewriteCmm dir lattice rwFun cmmGraph initFact = do
    (newBlocks, facts) <-
        fixpointRewrite dir lattice rwFun (g_entry cmmGraph) oldBlocks initFact
    return (cmmGraph { g_graph = GMany NothingO newBlocks NothingO }, facts)
  where
    -- The input block map, taken from the graph's 'GMany' form.
    oldBlocks = case g_graph cmmGraph of
        GMany NothingO bm NothingO -> bm
-- | Worklist fixpoint like 'fixpointAnalysis', except the per-block function
-- may also rewrite the block. Each rewritten block is stored in the result
-- map under its own entry label.
fixpointRewrite
    :: forall f.
       Direction
    -> DataflowLattice f
    -> RewriteFun f
    -> Label
    -> LabelMap CmmBlock
    -> FactBase f
    -> UniqSM (LabelMap CmmBlock, FactBase f)
fixpointRewrite dir lattice do_block entry blockmap = loop start blockmap
  where
    blocks = sortBlocks dir entry blockmap
    num_blocks = length blocks
    -- Array bounds are inclusive, so the last valid index is num_blocks - 1.
    -- With no blocks this gives the empty range (0, -1).
    block_arr = listArray (0, num_blocks - 1) blocks
    -- Initially every block is on the worklist.
    start = IntSet.fromDistinctAscList [0 .. num_blocks - 1]
    dep_blocks = mkDepBlocks dir blocks
    join = fact_join lattice

    loop
        :: IntHeap                  -- indices of blocks still to process
        -> LabelMap CmmBlock        -- blocks rewritten so far
        -> FactBase f               -- facts computed so far
        -> UniqSM (LabelMap CmmBlock, FactBase f)
    loop todo !blocks1 !fbase1
        | Just (index, todo1) <- IntSet.minView todo = do
            -- The block is always read from block_arr (the input block),
            -- not from the previously rewritten version in blocks1.
            let block = block_arr ! index
            (new_block, out_facts) <- do_block block fbase1
            let blocks2 = mapInsert (entryLabel new_block) new_block blocks1
                (todo2, fbase2) =
                    mapFoldlWithKey
                        (updateFact join dep_blocks) (todo1, fbase1) out_facts
            loop todo2 blocks2 fbase2
    -- Worklist empty: return the rewritten blocks and final facts.
    loop _ !blocks1 !fbase1 = return (blocks1, fbase1)
-- | Order the blocks of a graph for processing: 'revPostorderFrom' the entry
-- for 'Fwd', and the reverse of that order for 'Bwd'.
sortBlocks
    :: NonLocal n
    => Direction -> Label -> LabelMap (Block n C C) -> [Block n C C]
sortBlocks direction entry blockmap = orient (revPostorderFrom blockmap entry)
  where
    orient = case direction of
        Fwd -> id
        Bwd -> reverse
-- | For each label, the indices (positions in the given list) of the blocks
-- to revisit when that label's fact changes.
--
-- * 'Fwd': a label maps to the single block whose entry label it is.
-- * 'Bwd': a label maps to every block that lists it among its 'successors'.
mkDepBlocks :: Direction -> [CmmBlock] -> LabelMap IntSet
mkDepBlocks dir blocks = foldl' step mapEmpty (zip [0 ..] blocks)
  where
    step :: LabelMap IntSet -> (Int, CmmBlock) -> LabelMap IntSet
    step = case dir of
        Fwd -> \deps (n, b) -> mapInsert (entryLabel b) (IntSet.singleton n) deps
        Bwd -> \deps (n, b) -> foldl' (addDep n) deps (successors b)

    -- Record block index n as depending on label l.
    addDep :: Int -> LabelMap IntSet -> Label -> LabelMap IntSet
    addDep n deps l = mapInsertWith IntSet.union l (IntSet.singleton n) deps
-- | Merge one outgoing fact for @lbl@ into the worklist state.
--
-- If @lbl@ has no fact yet, @new_fact@ is stored. Otherwise it is joined
-- with the stored fact and stored only if the join reports 'Changed'.
-- Whenever a fact is stored, the blocks recorded for @lbl@ in @dep_blocks@
-- are added to the worklist. A 'NotChanged' join leaves both untouched.
updateFact
    :: JoinFun f
    -> LabelMap IntSet
    -> (IntHeap, FactBase f)
    -> Label
    -> f
    -> (IntHeap, FactBase f)
updateFact fact_join dep_blocks (todo, fbase) lbl new_fact =
    case merged of
        Nothing   -> (todo, fbase)
        Just fact -> let !fbase' = mapInsert lbl fact fbase in (todo', fbase')
  where
    -- Just the fact to store, or Nothing when nothing changed.
    merged = case lookupFact lbl fbase of
        Nothing -> Just new_fact
        Just old_fact -> case fact_join (OldFact old_fact) (NewFact new_fact) of
            NotChanged _   -> Nothing
            Changed joined -> Just joined
    todo' = IntSet.union todo (mapFindWithDefault IntSet.empty lbl dep_blocks)
-- | Look up the fact for a label, defaulting to the lattice's 'fact_bot'.
getFact :: DataflowLattice f -> Label -> FactBase f -> f
getFact lat l fb = fromMaybe (fact_bot lat) (lookupFact l fb)
-- | Join the facts recorded for the successors of a block-closing node
-- (exit shape 'C'). Successors with no entry in the fact base are skipped.
--
-- The fold starts from 'fact_bot'. Each successor's fact is passed to
-- 'fact_join' as the 'OldFact', and the running result as the 'NewFact'.
joinOutFacts :: (NonLocal n) => DataflowLattice f -> n e C -> FactBase f -> f
joinOutFacts lattice nonLocal fact_base = foldl' join (fact_bot lattice) facts
  where
    join acc fact = getJoined $ fact_join lattice (OldFact fact) (NewFact acc)
    -- mapMaybe drops missing facts without the partial fromJust/isJust pair.
    facts = mapMaybe (\s -> lookupFact s fact_base) (successors nonLocal)
-- | Join a list of facts, starting from 'fact_bot'. Each list element is
-- passed to 'fact_join' as the 'OldFact', and the running result as the
-- 'NewFact'.
joinFacts :: DataflowLattice f -> [f] -> f
joinFacts lattice = foldl' step (fact_bot lattice)
  where
    step acc fact = getJoined (fact_join lattice (OldFact fact) (NewFact acc))
-- | Build a 'FactBase' from label/fact pairs. Repeated labels are combined
-- with 'fact_join': the incoming fact is passed as the 'OldFact' and the
-- fact already accumulated as the 'NewFact'.
mkFactBase :: DataflowLattice f -> [(Label, f)] -> FactBase f
mkFactBase lattice = foldl' insertFact mapEmpty
  where
    insertFact acc (lbl, incoming) =
        let !merged =
                maybe incoming
                      (\existing -> getJoined (fact_join lattice (OldFact incoming) (NewFact existing)))
                      (mapLookup lbl acc)
        in mapInsert lbl merged acc
-- | Fold a function over the nodes of an open/open block from last node to
-- first. Intermediate accumulators are forced to WHNF with '$!'.
foldNodesBwdOO
    :: forall f. (CmmNode O O -> f -> f) -> Block CmmNode O O -> f -> f
foldNodesBwdOO funOO = walk
  where
    walk :: Block CmmNode O O -> f -> f
    walk BNil acc              = acc
    walk (BMiddle node) acc    = funOO node acc
    walk (BCons node rest) acc = funOO node $! walk rest acc
    walk (BSnoc rest node) acc = walk rest $! funOO node acc
    walk (BCat front back) acc = walk front $! walk back acc
-- | Backward rewriting fold over an open/open block.
--
-- @rewriteOO@ is applied to each node from last to first, threading the
-- fact returned by each call into the next. The replacement blocks are
-- reassembled in the original node order with 'joinBlocksOO'.
foldRewriteNodesBwdOO
    :: forall f.
       (CmmNode O O -> f -> UniqSM (Block CmmNode O O, f))
    -> Block CmmNode O O
    -> f
    -> UniqSM (Block CmmNode O O, f)
foldRewriteNodesBwdOO rewriteOO initBlock initFacts = go initBlock initFacts
  where
    -- In every case the right-hand part is rewritten first (see comp), so
    -- nodes are visited back to front.
    go (BCons node1 block1) !fact1 = (rewriteOO node1 `comp` go block1) fact1
    go (BSnoc block1 node1) !fact1 = (go block1 `comp` rewriteOO node1) fact1
    go (BCat blockA1 blockB1) !fact1 = (go blockA1 `comp` go blockB1) fact1
    go (BMiddle node) !fact1 = rewriteOO node fact1
    go BNil !fact = return (BNil, fact)
    -- rew1 `comp` rew2 runs rew2 first and feeds its fact to rew1. The
    -- result block is rew1's block followed by rew2's block; the final fact
    -- and the joined block are both forced to WHNF.
    comp rew1 rew2 = \f1 -> do
        (b, f2) <- rew2 f1
        (a, !f3) <- rew1 f2
        let !c = joinBlocksOO a b
        return (c, f3)
-- | Concatenate two open/open blocks. Empty sides are dropped, a
-- single-node side is attached with 'blockCons' / 'blockSnoc', and
-- otherwise the blocks are combined with 'BCat'.
joinBlocksOO :: Block n O O -> Block n O O -> Block n O O
joinBlocksOO front back = case (front, back) of
    (BNil, _)      -> back
    (_, BNil)      -> front
    (BMiddle n, _) -> blockCons n back
    (_, BMiddle n) -> blockSnoc front n
    _              -> BCat front back
-- | Worklist of block indices. The fixpoint loops pop the smallest index
-- first with 'IntSet.minView'.
type IntHeap = IntSet