module Data.Array.Parallel.Unlifted.Distributed.Types (
DT, Dist, MDist,
indexD, unitD, zipD, unzipD, fstD, sndD, lengthD,
newD, segdSD, concatSD,
newMD, readMD, writeMD, unsafeFreezeMD,
checkGangD, checkGangMD,
sizeD, sizeMD
) where
import Data.Array.Parallel.Unlifted.Distributed.Gang (
Gang, gangSize )
import Data.Array.Parallel.Arr
import Data.Array.Parallel.Unlifted.Sequential
import Data.Array.Parallel.Base
import Data.Word (Word8)
import Control.Monad (liftM, liftM2)
-- Fixity of 'indexD' when written infix: left-associative, precedence 9.
infixl 9 `indexD`
-- | Prefix a function name with this module's qualified name. The result is
-- passed as the location argument to the checking functions used below.
here :: String -> String
here name = modulePrefix ++ name
  where
    modulePrefix = "Data.Array.Parallel.Unlifted.Distributed.Types."
-- | Class of types that have a distributed representation: an immutable
-- 'Dist' and a mutable 'MDist' that lives in the 'ST' monad.
class DT a where
  -- | Extract the element stored at the given index of a distributed value.
  indexD         :: Dist a -> Int -> a
  -- | Create a new mutable distributed value for the given gang.
  newMD          :: Gang -> ST s (MDist a s)
  -- | Read the element at the given index of a mutable distributed value.
  readMD         :: MDist a s -> Int -> ST s a
  -- | Write an element at the given index of a mutable distributed value.
  writeMD        :: MDist a s -> Int -> a -> ST s ()
  -- | Turn a mutable distributed value into an immutable one.
  -- NOTE(review): presumably done without copying, so the mutable value
  -- should not be modified afterwards - confirm against the array library.
  unsafeFreezeMD :: MDist a s -> ST s (Dist a)
-- | Immutable distributed values. Each constructor gives the representation
-- for one shape of element type.
data Dist a where
  -- | Distributed unit; only the size is stored.
  DUnit  :: !Int -> Dist ()
  -- | Distributed value backed directly by a 'BUArr'.
  DPrim  :: !(BUArr a) -> Dist a
  -- | Distributed strict pair, stored as a pair of distributed values.
  DProd  :: !(Dist a) -> !(Dist b) -> Dist (a :*: b)
  -- | Distributed 'MaybeS': a distributed flag (True selects 'JustS')
  -- together with the distributed payload.
  DMaybe :: !(Dist Bool) -> !(Dist a) -> Dist (MaybeS a)
  -- | Distributed arrays: the distributed lengths together with the boxed
  -- array holding the arrays themselves.
  DUArr  :: !(Dist Int) -> !(BBArr (UArr a)) -> Dist (UArr a)
  -- | Distributed segment descriptors, stored as distributed arrays of
  -- 'Int' pairs.
  DUSegd :: !(Dist (UArr (Int :*: Int)))
         -> Dist USegd
  -- | Distributed segmented arrays: distributed segment descriptors plus
  -- the distributed flat data arrays.
  DSUArr :: !(Dist USegd) -> !(Dist (UArr a)) -> Dist (SUArr a)
-- | Mutable distributed values, indexed by the 'ST' state thread @s@. The
-- constructors mirror those of 'Dist'.
data MDist a s where
  -- | Mutable distributed unit; only the size is stored.
  MDUnit  :: !Int -> MDist () s
  -- | Mutable distributed value backed directly by an 'MBUArr'.
  MDPrim  :: !(MBUArr s a) -> MDist a s
  -- | Mutable distributed strict pair.
  MDProd  :: !(MDist a s) -> !(MDist b s) -> MDist (a :*: b) s
  -- | Mutable distributed 'MaybeS': flag plus payload.
  MDMaybe :: !(MDist Bool s) -> !(MDist a s) -> MDist (MaybeS a) s
  -- | Mutable distributed arrays: lengths plus the boxed array of arrays.
  MDUArr  :: !(MDist Int s) -> !(MBBArr s (UArr a)) -> MDist (UArr a) s
  -- | Mutable distributed segment descriptors, stored as arrays of pairs.
  MDUSegd :: !(MDist (UArr (Int :*: Int)) s)
          -> MDist USegd s
  -- | Mutable distributed segmented arrays: descriptors plus flat data.
  MDSUArr :: !(MDist USegd s) -> !(MDist (UArr a) s) -> MDist (SUArr a) s
-- | Unwrap the 'BUArr' underlying a primitive distributed value.
-- Only the 'DPrim' constructor is handled.
unDPrim :: UAE a => Dist a -> BUArr a
unDPrim (DPrim arr) = arr
-- | Unwrap the 'MBUArr' underlying a primitive mutable distributed value.
-- Only the 'MDPrim' constructor is handled.
unMDPrim :: UAE a => MDist a s -> MBUArr s a
unMDPrim (MDPrim marr) = marr
-- 'Dist a' is an instance of 'HS' whenever its element type is. No methods
-- are defined here.
-- NOTE(review): 'HS' is declared elsewhere; presumably a marker class -
-- confirm that it has no methods without defaults.
instance (HS a, DT a) => HS (Dist a)
-- | Number of elements in a distributed value.
--
-- For composite representations the size is taken from the first
-- component; for distributed arrays it is the length of the boxed array.
sizeD :: Dist a -> Int
sizeD (DUnit size)       = size
sizeD (DPrim arr)        = lengthBU arr
sizeD (DUArr _ arrs)     = lengthBB arrs
sizeD (DProd xs _)       = sizeD xs
sizeD (DMaybe flags _)   = sizeD flags
sizeD (DUSegd pairs)     = sizeD pairs
sizeD (DSUArr segds _)   = sizeD segds
-- | Check that a distributed value has as many elements as the gang has
-- members, then yield the final argument. The comparison is delegated to
-- 'checkEq' with the message \"Wrong gang\".
checkGangD :: DT a => String -> Gang -> Dist a -> b -> b
checkGangD loc gang dist result =
  checkEq loc "Wrong gang" expected actual result
  where
    expected = gangSize gang
    actual   = sizeD dist
-- | Number of elements in a mutable distributed value.
--
-- Mirrors 'sizeD': composite representations take their size from the
-- first component, distributed arrays from the boxed array.
sizeMD :: MDist a s -> Int
sizeMD (MDUnit n)        = n
sizeMD (MDPrim a)        = lengthMBU a
sizeMD (MDProd x _)      = sizeMD x
sizeMD (MDUArr _ ma)     = lengthMBB ma
sizeMD (MDMaybe b _)     = sizeMD b
-- The two cases below were missing, so 'sizeMD' (and hence 'checkGangMD')
-- failed with a pattern-match error on segment descriptors and segmented
-- arrays even though 'sizeD' handles their immutable counterparts.
sizeMD (MDUSegd ps)      = sizeMD ps
sizeMD (MDSUArr msegd _) = sizeMD msegd
-- | Check that a mutable distributed value has as many elements as the gang
-- has members, then yield the final argument. The comparison is delegated
-- to 'checkEq' with the message \"Wrong gang\".
checkGangMD :: DT a => String -> Gang -> MDist a s -> b -> b
checkGangMD loc gang mdist result =
  checkEq loc "Wrong gang" expected actual result
  where
    expected = gangSize gang
    actual   = sizeMD mdist
-- | A distributed value is shown as the list of its elements, one for each
-- index from @0@ to @sizeD d - 1@.
instance (Show a, DT a) => Show (Dist a) where
  -- The upper bound is @sizeD d - 1@; the subtraction operator was missing,
  -- which applied @sizeD d@ to @1@ and did not type-check.
  show d = show (map (indexD d) [0 .. sizeD d - 1])
-- | Distributed unit values carry only their size; every access is passed
-- through 'check' with the stored size and the index.
instance DT () where
  indexD (DUnit size) i = check (here "indexD[()]") size i ()

  newMD gang = return (MDUnit (gangSize gang))

  readMD (MDUnit size) i =
    check (here "readMD[()]") size i (return ())

  writeMD (MDUnit size) i () =
    check (here "writeMD[()]") size i (return ())

  unsafeFreezeMD (MDUnit size) = return (DUnit size)
-- | 'indexD' for primitive element types: index the underlying 'BUArr'.
primIndexD :: UAE a => Dist a -> Int -> a
primIndexD dist = indexBU (unDPrim dist)
-- | 'newMD' for primitive element types: allocate an 'MBUArr' with one
-- slot per gang member and wrap it in 'MDPrim'.
primNewMD :: UAE a => Gang -> ST s (MDist a s)
primNewMD gang = do
  marr <- newMBU (gangSize gang)
  return (MDPrim marr)
-- | 'readMD' for primitive element types: read from the underlying
-- 'MBUArr'.
primReadMD :: UAE a => MDist a s -> Int -> ST s a
primReadMD mdist = readMBU (unMDPrim mdist)
-- | 'writeMD' for primitive element types: write to the underlying
-- 'MBUArr'.
primWriteMD :: UAE a => MDist a s -> Int -> a -> ST s ()
primWriteMD mdist = writeMBU (unMDPrim mdist)
-- | 'unsafeFreezeMD' for primitive element types: freeze the whole
-- underlying 'MBUArr' and wrap the result in 'DPrim'.
primUnsafeFreezeMD :: UAE a => MDist a s -> ST s (Dist a)
primUnsafeFreezeMD mdist = do
  arr <- unsafeFreezeAllMBU (unMDPrim mdist)
  return (DPrim arr)
-- | 'Bool' is distributed via the primitive array representation.
instance DT Bool where
  newMD          = primNewMD
  writeMD        = primWriteMD
  readMD         = primReadMD
  unsafeFreezeMD = primUnsafeFreezeMD
  indexD         = primIndexD
-- | 'Char' is distributed via the primitive array representation.
instance DT Char where
  newMD          = primNewMD
  writeMD        = primWriteMD
  readMD         = primReadMD
  unsafeFreezeMD = primUnsafeFreezeMD
  indexD         = primIndexD
-- | 'Int' is distributed via the primitive array representation.
instance DT Int where
  newMD          = primNewMD
  writeMD        = primWriteMD
  readMD         = primReadMD
  unsafeFreezeMD = primUnsafeFreezeMD
  indexD         = primIndexD
-- | 'Word8' is distributed via the primitive array representation.
instance DT Word8 where
  newMD          = primNewMD
  writeMD        = primWriteMD
  readMD         = primReadMD
  unsafeFreezeMD = primUnsafeFreezeMD
  indexD         = primIndexD
-- | 'Float' is distributed via the primitive array representation.
instance DT Float where
  newMD          = primNewMD
  writeMD        = primWriteMD
  readMD         = primReadMD
  unsafeFreezeMD = primUnsafeFreezeMD
  indexD         = primIndexD
-- | 'Double' is distributed via the primitive array representation.
instance DT Double where
  newMD          = primNewMD
  writeMD        = primWriteMD
  readMD         = primReadMD
  unsafeFreezeMD = primUnsafeFreezeMD
  indexD         = primIndexD
-- | Strict pairs are distributed component-wise: every operation is applied
-- to the first component and then to the second.
instance (DT a, DT b) => DT (a :*: b) where
  indexD d i = indexD (fstD d) i :*: indexD (sndD d) i

  newMD gang = do
    mxs <- newMD gang
    mys <- newMD gang
    return (MDProd mxs mys)

  readMD (MDProd mxs mys) i = do
    x <- readMD mxs i
    y <- readMD mys i
    return (x :*: y)

  writeMD (MDProd mxs mys) i (x :*: y) = do
    writeMD mxs i x
    writeMD mys i y

  unsafeFreezeMD (MDProd mxs mys) = do
    xs <- unsafeFreezeMD mxs
    ys <- unsafeFreezeMD mys
    return (DProd xs ys)
-- | 'MaybeS' values are distributed as a flag plus a payload. The payload
-- slot is only read or written when the flag is 'True'.
instance DT a => DT (MaybeS a) where
  indexD (DMaybe flags vals) i =
    if indexD flags i
      then JustS (indexD vals i)
      else NothingS

  newMD gang = do
    mflags <- newMD gang
    mvals  <- newMD gang
    return (MDMaybe mflags mvals)

  readMD (MDMaybe mflags mvals) i = do
    present <- readMD mflags i
    case present of
      True  -> do
        v <- readMD mvals i
        return (JustS v)
      False -> return NothingS

  writeMD (MDMaybe mflags mvals) i entry =
    case entry of
      -- Nothing to store: only clear the flag, leave the payload untouched.
      NothingS -> writeMD mflags i False
      JustS v  -> do
        writeMD mflags i True
        writeMD mvals i v

  unsafeFreezeMD (MDMaybe mflags mvals) = do
    flags <- unsafeFreezeMD mflags
    vals  <- unsafeFreezeMD mvals
    return (DMaybe flags vals)
-- | Arrays are distributed as a boxed array of arrays, with the length of
-- each array recorded alongside it in a distributed 'Int'.
instance UA a => DT (UArr a) where
  indexD (DUArr _ arrs) i = indexBB arrs i

  newMD gang = do
    mlens <- newMD gang
    marrs <- newMBB (gangSize gang)
    return (MDUArr mlens marrs)

  readMD (MDUArr _ marrs) i = readMBB marrs i

  -- Record the length first, then store the array itself.
  writeMD (MDUArr mlens marrs) i arr = do
    writeMD mlens i (lengthU arr)
    writeMBB marrs i arr

  unsafeFreezeMD (MDUArr mlens marrs) = do
    lens <- unsafeFreezeMD mlens
    arrs <- unsafeFreezeAllMBB marrs
    return (DUArr lens arrs)
-- | Segment descriptors are distributed as arrays of 'Int' pairs and
-- converted with 'toUSegd' / 'fromUSegd' on every read and write.
instance DT USegd where
  indexD (DUSegd pairs) i = toUSegd (indexD pairs i)

  newMD gang = do
    mpairs <- newMD gang
    return (MDUSegd mpairs)

  readMD (MDUSegd mpairs) i = do
    pairs <- readMD mpairs i
    return (toUSegd pairs)

  writeMD (MDUSegd mpairs) i segd = writeMD mpairs i (fromUSegd segd)

  unsafeFreezeMD (MDUSegd mpairs) = do
    pairs <- unsafeFreezeMD mpairs
    return (DUSegd pairs)
-- | Segmented arrays are distributed as a segment descriptor plus a flat
-- data array; the two are recombined with '>:' when an element is read.
instance UA a => DT (SUArr a) where
  indexD (DSUArr segds arrs) i = (>:) (indexD segds i) (indexD arrs i)

  newMD gang = do
    msegds <- newMD gang
    marrs  <- newMD gang
    return (MDSUArr msegds marrs)

  readMD (MDSUArr msegds marrs) i = do
    segd <- readMD msegds i
    arr  <- readMD marrs i
    return (segd >: arr)

  -- Split the segmented array into its descriptor and its flat data.
  writeMD (MDSUArr msegds marrs) i sarr = do
    writeMD msegds i (segdSU sarr)
    writeMD marrs i (concatSU sarr)

  unsafeFreezeMD (MDSUArr msegds marrs) = do
    segds <- unsafeFreezeMD msegds
    arrs  <- unsafeFreezeMD marrs
    return (DSUArr segds arrs)
-- | Build an immutable distributed value: allocate a mutable one for the
-- gang, run the supplied initialiser on it, then freeze it.
newD :: DT a => Gang -> (forall s . MDist a s -> ST s ()) -> Dist a
newD g init =
  -- The action is written directly as the argument of 'runST' so that it
  -- stays polymorphic in the state thread.
  runST (newMD g >>= \mdist -> init mdist >> unsafeFreezeMD mdist)
-- | The distributed unit value for a gang; its size is the gang size.
unitD :: Gang -> Dist ()
unitD gang = DUnit (gangSize gang)
-- | Pair up two distributed values element-wise.
--
-- Both arguments are forced, and their sizes are compared through 'checkEq'
-- (message \"Size mismatch\") before the pair is built.
zipD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a :*: b)
-- The location label previously read "zipDT", a name that does not exist in
-- this module; it now names this function so failures point to the right
-- place.
zipD !x !y = checkEq (here "zipD") "Size mismatch" (sizeD x) (sizeD y) $
             DProd x y
-- | Split a distributed pair into a strict pair of its two distributed
-- components.
unzipD :: (DT a, DT b) => Dist (a :*: b) -> Dist a :*: Dist b
unzipD (DProd firsts seconds) = firsts :*: seconds
-- | First component of a distributed pair.
fstD :: (DT a, DT b) => Dist (a :*: b) -> Dist a
fstD pairs = fstS (unzipD pairs)
-- | Second component of a distributed pair.
sndD :: (DT a, DT b) => Dist (a :*: b) -> Dist b
sndD pairs = sndS (unzipD pairs)
-- | The distributed lengths stored alongside a distributed array.
lengthD :: UA a => Dist (UArr a) -> Dist Int
lengthD (DUArr lens _) = lens
-- | The distributed segment descriptor of a distributed segmented array.
segdSD :: UA a => Dist (SUArr a) -> Dist USegd
segdSD (DSUArr segds _) = segds
-- | The distributed flat data array of a distributed segmented array.
concatSD :: UA a => Dist (SUArr a) -> Dist (UArr a)
concatSD (DSUArr _ flat) = flat