#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Arrays (
lengthD, splitLenD, splitLengthD, splitD, joinLengthD, joinD, splitJoinD,
splitSD, joinSD, splitJoinSD,
permuteD, bpermuteD, atomicUpdateD, bpermuteSD',
Distribution, balanced, unbalanced
) where
import Data.Array.Parallel.Base (
(:*:)(..), fstS, ST, runST)
import Data.Array.Parallel.Unlifted.Sequential
import Data.Array.Parallel.Unlifted.Distributed.Gang (
Gang, gangSize, seqGang)
import Data.Array.Parallel.Unlifted.Distributed.DistST (
stToDistST)
import Data.Array.Parallel.Unlifted.Distributed.Types (
DT, Dist, indexD, lengthD, newD, writeMD, zipD, unzipD,
segdSD, concatSD,
checkGangD)
import Data.Array.Parallel.Unlifted.Distributed.Basics
import Data.Array.Parallel.Unlifted.Distributed.Combinators (
mapD, zipWithD, scanD,
zipWithDST_, mapDST_)
import Data.Array.Parallel.Unlifted.Distributed.Scalars (
sumD)
-- | Prefix an identifier with this module's fully qualified name, for use in
-- error and gang-check messages.
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Distributed.Arrays." ++)
-- | Tag type naming a distribution strategy. It has no constructors; the only
-- values of this type defined in this module are 'balanced' and 'unbalanced',
-- both of which raise an error if they are ever forced.
data Distribution
-- | 'Distribution' token. It is an error thunk: it may be passed around as an
-- argument but must never be evaluated.
balanced :: Distribution
balanced = error (here "balanced: touched")
-- | 'Distribution' token. It is an error thunk: it may be passed around as an
-- argument but must never be evaluated.
unbalanced :: Distribution
unbalanced = error (here "unbalanced: touched")
-- | Distribute the length of an array over a gang: the result of 'splitLenD'
-- applied to @lengthU arr@.
splitLengthD :: UA a => Gang -> UArr a -> Dist Int
splitLengthD g arr = splitLenD g (lengthU arr)
-- | Distribute a length @n@ over the gang as evenly as possible.
--
-- With @p = gangSize g@, the first @n `mod` p@ threads receive
-- @n `div` p + 1@ and the remaining threads receive @n `div` p@.
splitLenD :: Gang -> Int -> Dist Int
splitLenD g n = newD g (\md -> go md 0)
  where
    threads        = gangSize g
    (base, extras) = n `divMod` threads

    -- Walk over the thread indices in order, writing each thread's share.
    go md i
      | i < extras  = put (base + 1)
      | i < threads = put base
      | otherwise   = return ()
      where
        put len = do
          writeMD md i len
          go md (i + 1)
-- | Split an array into per-thread slices whose lengths are given by @dlen@.
--
-- The first component of @scanD g (+) 0 dlen@ supplies each thread's start
-- index; the slicing itself is mapped with the sequential gang.
splitAsD :: UA a => Gang -> Dist Int -> UArr a -> Dist (UArr a)
splitAsD g dlen !arr = zipWithD (seqGang g) slice starts dlen
  where
    slice  = sliceU arr
    starts = fstS (scanD g (+) 0 dlen)
-- | Split an array over the gang, using the lengths computed by
-- 'splitLengthD'. The 'Distribution' argument is ignored.
splitD :: UA a => Gang -> Distribution -> UArr a -> Dist (UArr a)
splitD g _ !arr = zipWithD (seqGang g) slice starts chunkLens
  where
    slice     = sliceU arr
    chunkLens = splitLengthD g arr
    -- start index of every thread's slice
    starts    = fstS (scanD g (+) 0 chunkLens)
-- | Overall length of a distributed array: the sum over the gang of the
-- local chunk lengths.
joinLengthD :: UA a => Gang -> Dist (UArr a) -> Int
joinLengthD g darr = sumD g (lengthD darr)
-- | Join a distributed array into a single array. The 'Distribution'
-- argument is ignored.
--
-- A new array of the combined length is allocated and every gang member
-- copies its local chunk into it at the offset given by the scan of the
-- chunk lengths. The gang is checked against @darr@ with 'checkGangD' first.
joinD :: UA a => Gang -> Distribution -> Dist (UArr a) -> UArr a
joinD g _ !darr =
    checkGangD (here "joinD") g darr
      $ newU total (\marr -> zipWithDST_ g (copyInto marr) offsets darr)
  where
    offsets :*: total = scanD g (+) 0 (lengthD darr)

    -- copy one local chunk into the target array at the given offset
    copyInto marr off chunk = stToDistST (copyMU marr off chunk)
-- | Split an array over the gang, apply a function to the distributed
-- array, and join the result back into a single array.
splitJoinD :: (UA a, UA b)
           => Gang -> (Dist (UArr a) -> Dist (UArr b)) -> UArr a -> UArr b
splitJoinD g f !xs = joinD g unbalanced processed
  where
    processed = f (splitD g unbalanced xs)
-- | Join a distributed array into a freshly allocated mutable array.
--
-- Works like 'joinD', but runs in 'ST' and returns the mutable array
-- without freezing it.
joinDM :: UA a => Gang -> Dist (UArr a) -> ST s (MUArr a s)
joinDM g darr =
    checkGangD (here "joinDM") g darr $ do
      target <- newMU total
      zipWithDST_ g (copyInto target) offsets darr
      return target
  where
    offsets :*: total = scanD g (+) 0 (lengthD darr)

    -- copy one local chunk into the target array at the given offset
    copyInto marr off chunk = stToDistST (copyMU marr off chunk)
-- | Permute a distributed array into a single array.
--
-- The result has the joined length of @darr@. Every gang member runs
-- 'permuteMU' on the shared target with its local chunk of @darr@ and its
-- local chunk of indices from @dis@.
permuteD :: UA a => Gang -> Dist (UArr a) -> Dist (UArr Int) -> UArr a
permuteD g darr dis =
    newU (joinLengthD g darr)
         (\marr -> zipWithDST_ g (scatter marr) darr dis)
  where
    scatter marr chunk idxs = stToDistST (permuteMU marr chunk idxs)
-- | Backpermute: every gang member applies 'bpermuteU' to the whole source
-- array @as@ and its local chunk of indices.
bpermuteD :: UA a => Gang -> UArr a -> Dist (UArr Int) -> Dist (UArr a)
bpermuteD g !as ds = mapD g gather ds
  where
    gather = bpermuteU as
-- | Join a distributed array and apply distributed (index, value) updates.
--
-- The chunks of @darr@ are first joined into one mutable array ('joinDM').
-- Every gang member then applies 'atomicUpdateMU' with its local chunk of
-- @upd@, and the array is finally frozen with 'unsafeFreezeAllMU'.
atomicUpdateD :: UA a
              => Gang -> Dist (UArr a) -> Dist (UArr (Int :*: a)) -> UArr a
atomicUpdateD g darr upd =
    runST (joinDM g darr >>= \marr ->
             mapDST_ g (applyUpdates marr) upd >>
             unsafeFreezeAllMU marr)
  where
    applyUpdates marr chunk = stToDistST (atomicUpdateMU marr chunk)
-- | Decide how the segments of a segmented array are distributed over the
-- gang.
--
-- @n@ is the total number of elements and @lens@ holds the segment lengths.
-- For every gang member the result is a pair
-- @(number of segments :*: number of elements)@. Segments are assigned
-- whole and in order: a thread keeps taking segments until it holds at
-- least its 'splitLenD' share of the @n@ elements, or until no segments are
-- left. A thread whose share is zero takes all remaining segments.
--
-- The last gang member always takes every remaining segment. Without this,
-- trailing zero-length segments were assigned to no thread whenever the
-- last thread reached its share exactly (e.g. one thread and
-- @lens = [2,0]@ gave @1 :*: 2@), so they disappeared from the distributed
-- segment descriptor. As long as @n@ equals the sum of @lens@, element
-- counts are unaffected by this rule.
splitSegdLengthsD :: Gang -> Int -> UArr Int -> Dist (Int :*: Int)
splitSegdLengthsD g n !lens = newD g (\md -> fill md 0 0 0 0)
  where
    m     = lengthU lens
    p     = gangSize g
    dlens = splitLenD g n

    -- fill md i j k l
    --   i: current thread, j: next unassigned segment,
    --   k: segments and l: elements collected so far for thread i
    fill md i j k l
      | i == p             = return ()
      | wantsMore && j < m = fill md i (j + 1)
                                       (k + 1)
                                       (l + lens !: j)
      | otherwise          = do
          writeMD md i (k :*: l)
          fill md (i + 1) j 0 0
      where
        -- element share of thread i; only demanded when i < p
        e         = dlens `indexD` i
        isLast    = i == p - 1
        wantsMore = l < e || e == 0 || isLast
-- | Split a segment descriptor over the gang.
--
-- Every gang member receives a local segment descriptor built from its
-- share of the segment lengths, paired with the number of elements assigned
-- to it by 'splitSegdLengthsD'.
splitSegdD' :: Gang -> Int -> USegd -> Dist (USegd :*: Int)
splitSegdD' g n !segd = zipD dsegd elemCounts
  where
    lens                     = lengthsUSegd segd
    segCounts :*: elemCounts = unzipD (splitSegdLengthsD g n lens)
    dsegd                    = mapD g lengthsToUSegd
                                 (splitAsD g segCounts lens)
-- | Join a distributed segment descriptor: extract the local segment
-- lengths, join them into one array and rebuild a descriptor from it.
joinSegD :: Gang -> Dist USegd -> USegd
joinSegD g dsegd = lengthsToUSegd (joinD g unbalanced localLens)
  where
    localLens = mapD (seqGang g) lengthsUSegd dsegd
-- | Split a segmented array over the gang. The 'Distribution' argument is
-- ignored.
--
-- The segment descriptor is split with 'splitSegdD''; the flat data is then
-- split according to the per-thread element counts it returns, and both
-- parts are recombined on every gang member with '>:'.
splitSD :: UA a => Gang -> Distribution -> SUArr a -> Dist (SUArr a)
splitSD g _ !sarr = zipWithD g (>:) dsegd dflat
  where
    flat           = concatSU sarr
    dsegd :*: dlen = unzipD (splitSegdD' g (lengthU flat) (segdSU sarr))
    dflat          = splitAsD g dlen flat
-- | Join a distributed segmented array. The 'Distribution' argument is
-- ignored.
--
-- The local segment descriptors and the local flat data are joined
-- separately and recombined with '>:'.
joinSD :: UA a => Gang -> Distribution -> Dist (SUArr a) -> SUArr a
joinSD g _ !darr = segd >: flat
  where
    segd = joinSegD g (segdSD darr)
    flat = joinD g unbalanced (concatSD darr)
-- | Split a segmented array over the gang, apply a function to the
-- distributed array, and join the result back into a segmented array.
splitJoinSD :: (UA a, UA b)
            => Gang -> (Dist (SUArr a) -> Dist (SUArr b)) -> SUArr a -> SUArr b
splitJoinSD g f !xs = joinSD g unbalanced processed
  where
    processed = f (splitSD g unbalanced xs)
-- | Segmented backpermute: every gang member applies 'bpermuteSU'' to the
-- whole source array @as@ and its local segmented array of indices.
bpermuteSD' :: UA a => Gang -> UArr a -> Dist (SUArr Int) -> Dist (SUArr a)
bpermuteSD' g as = mapD g gather
  where
    gather = bpermuteSU' as