#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Arrays (
lengthD, splitLenD, splitLengthD,
splitAsD, splitD, joinLengthD, joinD, splitJoinD,
splitSegdD, splitSD,
permuteD, bpermuteD, atomicUpdateD,
Distribution, balanced, unbalanced
) where
import Data.Array.Parallel.Base (
(:*:)(..), fstS, sndS, ST, runST)
import Data.Array.Parallel.Unlifted.Sequential
import Data.Array.Parallel.Unlifted.Distributed.Gang (
Gang, gangSize, seqGang)
import Data.Array.Parallel.Unlifted.Distributed.DistST (
stToDistST)
import Data.Array.Parallel.Unlifted.Distributed.Types (
DT, Dist, indexD, lengthD, newD, writeMD, zipD, unzipD, fstD, sndD,
elementsUSegdD,
checkGangD)
import Data.Array.Parallel.Unlifted.Distributed.Basics
import Data.Array.Parallel.Unlifted.Distributed.Combinators (
mapD, zipWithD, scanD, mapAccumLD,
zipWithDST_, mapDST_)
import Data.Array.Parallel.Unlifted.Distributed.Scalars (
sumD)
-- | Prefix a name with this module's fully qualified name. Used to build
-- the error messages raised by functions in this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Arrays." ++ s
-- | Marker type passed to 'splitD' / 'joinD' to name a distribution
-- strategy.
--
-- The type has no constructors. Its two values below are pure markers:
-- the functions in this module that accept a 'Distribution' ignore it, and
-- forcing either value raises an error.
data Distribution

-- | Marker named for a balanced distribution. Must never be forced.
balanced :: Distribution
balanced = error (here "balanced: touched")

-- | Marker named for an unbalanced distribution. Must never be forced.
unbalanced :: Distribution
unbalanced = error (here "unbalanced: touched")
-- | Per-gang-member chunk lengths for an array: the array's length split
-- across the gang by 'splitLenD'.
splitLengthD :: UA a => Gang -> UArr a -> Dist Int
splitLengthD g arr = splitLenD g (lengthU arr)
-- | Split a length @n@ across the members of gang @g@ as evenly as possible.
--
-- With @p = gangSize g@, member @i@ is assigned @n `div` p + 1@ when
-- @i < n `mod` p@ and @n `div` p@ otherwise. The first @n `mod` p@ members
-- therefore each carry one extra element.
splitLenD :: Gang -> Int -> Dist Int
splitLenD g !n = newD g (\md -> go md 0)
  where
    p             = gangSize g
    (base, extra) = n `divMod` p

    -- Write the share of member i, then continue with member i+1.
    go md i
      | i < extra = writeMD md i (base + 1) >> go md (i + 1)
      | i < p     = writeMD md i base       >> go md (i + 1)
      | otherwise = return ()
-- | Split an array across a gang using explicitly supplied chunk lengths.
--
-- @dlen@ gives each member's chunk length. The per-member start offsets are
-- the first component of @scanD g (+) 0 dlen@, presumably the running
-- (exclusive) prefix sums of @dlen@. Each member gets
-- @sliceU arr offset len@.
splitAsD :: UA a => Gang -> Dist Int -> UArr a -> Dist (UArr a)
splitAsD g dlen !arr = zipWithD (seqGang g) (sliceU arr) offsets dlen
  where
    offsets = fstS (scanD g (+) 0 dlen)
-- | Split an array across a gang in chunks that are as even as possible.
--
-- The 'Distribution' argument is ignored. Chunk lengths come from
-- 'splitLengthD', and the slicing is done by 'splitAsD'. This is the same
-- offset/slice computation, kept in one place.
splitD :: UA a => Gang -> Distribution -> UArr a -> Dist (UArr a)
splitD g _ !arr = splitAsD g (splitLengthD g arr) arr
-- | Combined length of all per-member arrays: 'sumD' over their lengths.
joinLengthD :: UA a => Gang -> Dist (UArr a) -> Int
joinLengthD g darr = sumD g (lengthD darr)
-- | Concatenate the per-member arrays of a gang into a single array.
--
-- @scanD@ over the member lengths yields each member's write offset
-- together with the total length, which sizes the result. Every member then
-- copies its chunk into the result at its offset with 'copyMU'. The whole
-- computation is wrapped in 'checkGangD' (presumably a consistency check
-- between @g@ and @darr@). The 'Distribution' argument is ignored.
joinD :: UA a => Gang -> Distribution -> Dist (UArr a) -> UArr a
joinD g _ !darr =
  checkGangD (here "joinD") g darr
    (newU total (\marr ->
       zipWithDST_ g (\off chunk -> stToDistST (copyMU marr off chunk))
                     offsets darr))
  where
    offsets :*: total = scanD g (+) 0 (lengthD darr)
-- | Split an array across the gang (with 'splitD') and apply a distributed
-- transformation to the chunks. Then join the transformed chunks back into
-- one array (with 'joinD').
splitJoinD :: (UA a, UA b)
           => Gang -> (Dist (UArr a) -> Dist (UArr b)) -> UArr a -> UArr b
splitJoinD g f !xs = joinD g unbalanced $ f $ splitD g unbalanced xs
-- | Mutable counterpart of 'joinD'.
--
-- Allocates a fresh mutable array of the combined length and copies every
-- member's chunk into it at its offset. Returns the array unfrozen.
joinDM :: UA a => Gang -> Dist (UArr a) -> ST s (MUArr a s)
joinDM g darr =
  checkGangD (here "joinDM") g darr
    (newMU total >>= \marr ->
       zipWithDST_ g (\off chunk -> stToDistST (copyMU marr off chunk))
                     offsets darr
         >> return marr)
  where
    offsets :*: total = scanD g (+) 0 (lengthD darr)
-- | Scatter the gang's distributed chunks into one result array.
--
-- The result has the combined length of @darr@ ('joinLengthD'). Each
-- member passes its chunk of @darr@ and its chunk of @dis@ to 'permuteMU',
-- presumably writing each value at the corresponding index.
permuteD :: UA a => Gang -> Dist (UArr a) -> Dist (UArr Int) -> UArr a
permuteD g darr dis = newU len (\marr -> zipWithDST_ g (scatter marr) darr dis)
  where
    len = joinLengthD g darr
    scatter marr chunk idxs = stToDistST (permuteMU marr chunk idxs)
-- | Backwards permutation across a gang.
--
-- Each member applies 'bpermuteU' to the whole source array @as@ and its own
-- chunk of indices from @ds@. @as@ is forced before the work is distributed.
bpermuteD :: UA a => Gang -> UArr a -> Dist (UArr Int) -> Dist (UArr a)
bpermuteD g !as ds = mapD g (\idxs -> bpermuteU as idxs) ds
-- | Apply distributed updates to a distributed array, producing one array.
--
-- The chunks of @darr@ are first joined into a single mutable array
-- ('joinDM'). Every gang member then applies its chunk of @upd@ (pairs of
-- type @Int :*: a@, presumably index and new value) with 'atomicUpdateMU'.
-- Finally the array is frozen.
atomicUpdateD :: UA a
              => Gang -> Dist (UArr a) -> Dist (UArr (Int :*: a)) -> UArr a
atomicUpdateD g darr upd =
  runST (joinDM g darr >>= \marr ->
           mapDST_ g (apply marr) upd >> unsafeFreezeAllMU marr)
  where
    apply marr chunk = stToDistST (atomicUpdateMU marr chunk)
-- | Split a segment descriptor across a gang so that each member receives a
-- contiguous run of whole segments.
--
-- 1. The total element count is split evenly with 'splitLenD'; this gives
--    each member's element quota.
-- 2. Using 'mapAccumLD' (presumably threading the segment index from member
--    to member), each member takes segments until its quota is used up.
--    A segment is never split, so a member may end up with more or fewer
--    elements than its quota.
-- 3. The resulting per-member segment counts slice the lengths array
--    ('splitAsD'), and each slice is turned back into a descriptor.
splitSegdD :: Gang -> USegd -> Dist USegd
splitSegdD g !segd = mapD g lengthsToUSegd
                   $ splitAsD g d lens
  where
    -- Number of segments assigned to each gang member.
    d = sndS
      . mapAccumLD g chunk 0
      . splitLenD g
      $ elementsUSegd segd

    n    = lengthUSegd segd
    lens = lengthsUSegd segd

    -- Starting at segment i with element quota k, return the index of the
    -- first segment not taken, together with the number of segments taken.
    chunk i k = let j = go i k
                in j :*: (j - i)

    -- Advance over segments from i until the quota is exhausted or all n
    -- segments are consumed. Empty segments are checked before the quota,
    -- so any empty segment reached is taken; k <= 0 only stops before a
    -- non-empty segment. Guard order also ensures lens !: i is evaluated
    -- only when i < n.
    go i k | i >= n    = i
           | m == 0    = go (i + 1) k
           | k <= 0    = i
           | otherwise = go (i + 1) (k - m)
      where
        m = lens !: i
-- | Combine per-member segment descriptors into a single descriptor.
--
-- The per-member segment-length arrays are concatenated with 'joinD', and
-- a descriptor is rebuilt from them with 'lengthsToUSegd'.
joinSegD :: Gang -> Dist USegd -> USegd
joinSegD g dsegd =
  lengthsToUSegd (joinD g unbalanced (mapD (seqGang g) lengthsUSegd dsegd))
-- | Split a data array across a gang to line up with an already-distributed
-- segment descriptor.
--
-- Each member's chunk length is taken from 'elementsUSegdD' of @dsegd@,
-- presumably the element count of that member's segments.
splitSD :: UA a => Gang -> Dist USegd -> UArr a -> Dist (UArr a)
splitSD g dsegd xs = splitAsD g chunkLens xs
  where
    chunkLens = elementsUSegdD dsegd