{-# OPTIONS -fno-warn-incomplete-patterns #-}
-----------------------------------------------------------------------------
-- |
-- Module      : Data.Array.Parallel.Unlifted.Distributed.Types
-- Copyright   : (c) 2006 Roman Leshchinskiy
-- License     : see libraries/ndp/LICENSE
--
-- Maintainer  : Roman Leshchinskiy <rl@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC Extensions)
--
-- Distributed types.
--

module Data.Array.Parallel.Unlifted.Distributed.Types (
  -- * Distributed types
  DT, Dist, MDist,

  -- * Operations on immutable distributed types
  indexD, unitD, zipD, unzipD, fstD, sndD, lengthD,
  newD,
  segdSD, concatSD,

  -- * Operations on mutable distributed types
  newMD, readMD, writeMD, unsafeFreezeMD,

  -- * Assertions
  checkGangD, checkGangMD,

  -- * Debugging functions
  sizeD, sizeMD
) where

import Data.Array.Parallel.Unlifted.Distributed.Gang (
  Gang, gangSize )
import Data.Array.Parallel.Arr
import Data.Array.Parallel.Unlifted.Sequential
import Data.Array.Parallel.Base

import Data.Word     (Word8)
import Control.Monad (liftM, liftM2)

infixl 9 `indexD`

-- | Prefix a local name with this module's fully qualified name. The
-- result is used as the location label in the diagnostics raised by
-- the instances and operations of this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Types." ++ s


-- |Distributed types
-- ----------------------------

-- | Class of distributable types. Instances of 'DT' can be
-- distributed across all workers of a 'Gang'. All such types
-- must be hyperstrict as we do not want to pass thunks into distributed
-- computations.
class DT a where
  -- data Dist a
  -- data MDist a s

  -- | Extract a single element of an immutable distributed value.
  indexD         :: Dist a -> Int -> a

  -- | Create an uninitialised distributed value for the given 'Gang'.
  -- The gang is used (only) to know how many elements are needed
  -- in the distributed value.
  newMD          :: Gang                  -> ST s (MDist a s)

  -- | Extract an element from a mutable distributed value.
  readMD         :: MDist a s -> Int      -> ST s a

  -- | Write an element of a mutable distributed value.
  writeMD        :: MDist a s -> Int -> a -> ST s ()

  -- | Unsafely freeze a mutable distributed value.
  unsafeFreezeMD :: MDist a s             -> ST s (Dist a)


-- GADTs TO REPLACE ATs FOR THE MOMENT
--
-- Immutable distributed values. Each constructor is the representation
-- used by one 'DT' instance; compound types are stored component-wise.
data Dist a where
  DUnit  :: !Int -> Dist ()
  DPrim  :: !(BUArr a) -> Dist a
  DProd  :: !(Dist a) -> !(Dist b) -> Dist (a :*: b)
  DMaybe :: !(Dist Bool) -> !(Dist a) -> Dist (MaybeS a)

  -- The Dist Int redundantly records the size of the UArrs
  -- (redundantly because the UArrs also contain their sizes)
  DUArr  :: !(Dist Int) -> !(BBArr (UArr a)) -> Dist (UArr a)

  -- NOTE: comments here were de-haddockized by Waern, because
  -- GADTS currently can't be documented
  DUSegd :: -- Local segment descriptors
            !(Dist (UArr (Int :*: Int)))

            -- Indicates whether the first local segment is
            -- split across two processors.
            -- (This field is currently commented out.)
            --
            -- -> !(Dist Bool)

            -> Dist USegd

  DSUArr :: !(Dist USegd) -> !(Dist (UArr a)) -> Dist (SUArr a)

-- Mutable counterparts of the 'Dist' constructors, one per representation.
data MDist a s where
  MDUnit  :: !Int -> MDist () s
  MDPrim  :: !(MBUArr s a) -> MDist a s
  MDProd  :: !(MDist a s) -> !(MDist b s) -> MDist (a :*: b) s
  MDMaybe :: !(MDist Bool s) -> !(MDist a s) -> MDist (MaybeS a) s
  MDUArr  :: !(MDist Int s) -> !(MBBArr s (UArr a)) -> MDist (UArr a) s
  MDUSegd :: !(MDist (UArr (Int :*: Int)) s)
             -- -> !(MDist Bool s)    (currently commented out)
             -> MDist USegd s
  MDSUArr :: !(MDist USegd s) -> !(MDist (UArr a) s) -> MDist (SUArr a) s

-- Unwrap the primitive array of an immutable distributed value.
-- Partial: only 'DPrim' is matched.
unDPrim :: UAE a => Dist a -> BUArr a
unDPrim (DPrim a) = a

-- Unwrap the primitive array of a mutable distributed value.
-- Partial: only 'MDPrim' is matched.
unMDPrim :: UAE a => MDist a s -> MBUArr s a
unMDPrim (MDPrim ma) = ma

-- Distributing hyperstrict types may not change their strictness.
instance (HS a, DT a) => HS (Dist a)

-- | Number of elements in the distributed value. This is for debugging only
-- and not a method of 'DT'.
--
-- Compound representations report the size of their first (or flag)
-- component; array representations report the length of their backing array.
sizeD :: Dist a -> Int
sizeD (DUnit  n)      = n
sizeD (DPrim  a)      = lengthBU a
sizeD (DProd  x _)    = sizeD x
sizeD (DUArr  _ a)    = lengthBB a
sizeD (DMaybe b _)    = sizeD b
sizeD (DUSegd ps)     = sizeD ps
sizeD (DSUArr segd _) = sizeD segd

-- | Check that the sizes of the 'Gang' and of the distributed value match.
checkGangD :: DT a => String -> Gang -> Dist a -> b -> b
checkGangD loc g d v = checkEq loc "Wrong gang" (gangSize g) (sizeD d) v

-- | Number of elements in the mutable distributed value. This is for debugging
-- only and is thus not a method of 'DT'.
--
-- Every 'MDist' constructor is handled, mirroring 'sizeD'. Without the
-- 'MDUSegd' and 'MDSUArr' equations, 'checkGangMD' would fail with a
-- pattern-match error on mutable segment descriptors and segmented arrays.
-- The module-wide -fno-warn-incomplete-patterns hides such gaps.
sizeMD :: MDist a s -> Int
sizeMD (MDUnit  n)      = n
sizeMD (MDPrim  a)      = lengthMBU a
sizeMD (MDProd  x _)    = sizeMD x
sizeMD (MDUArr  _ ma)   = lengthMBB ma
sizeMD (MDMaybe b _)    = sizeMD b
sizeMD (MDUSegd ps)     = sizeMD ps
sizeMD (MDSUArr segd _) = sizeMD segd

-- | Check that the sizes of the 'Gang' and of the mutable distributed value
-- match.
checkGangMD :: DT a => String -> Gang -> MDist a s -> b -> b
checkGangMD loc g d v = checkEq loc "Wrong gang" (gangSize g) (sizeMD d) v

-- Show instance (for debugging only)
instance (Show a, DT a) => Show (Dist a) where
  show d = show (map (indexD d) [0 .. sizeD d - 1])

-- | 'DT' instances
-- ----------------

-- A distributed unit only records the gang size. Each accessor passes that
-- size and the requested index to 'check' (presumably an index assertion).
instance DT () where
  indexD  (DUnit n) i       = check (here "indexD[()]") n i $ ()
  newMD                     = return . MDUnit . gangSize
  readMD   (MDUnit n) i     = check (here "readMD[()]")  n i $
                               return ()
  writeMD  (MDUnit n) i ()  = check (here "writeMD[()]") n i $
                               return ()
  unsafeFreezeMD (MDUnit n) = return $ DUnit n

-- Shared method implementations for primitive element types. These are
-- stored in a single primitive array with one slot per gang member.
primIndexD :: UAE a => Dist a -> Int -> a
primIndexD = indexBU . unDPrim

primNewMD :: UAE a => Gang -> ST s (MDist a s)
primNewMD = liftM MDPrim . newMBU . gangSize

primReadMD :: UAE a => MDist a s -> Int -> ST s a
primReadMD = readMBU . unMDPrim

primWriteMD :: UAE a => MDist a s -> Int -> a -> ST s ()
primWriteMD = writeMBU . unMDPrim

primUnsafeFreezeMD :: UAE a => MDist a s -> ST s (Dist a)
primUnsafeFreezeMD = liftM DPrim . unsafeFreezeAllMBU . unMDPrim

instance DT Bool where
  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD

instance DT Char where
  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD

instance DT Int where
  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD

instance DT Word8 where
  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD

instance DT Float where
  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD

instance DT Double where
  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD

-- Strict pairs are distributed component-wise. Every operation is applied
-- to both halves.
instance (DT a, DT b) => DT (a :*: b) where
  indexD d i = (fstD d `indexD` i) :*: (sndD d `indexD` i)
  newMD g = liftM2 MDProd (newMD g) (newMD g)
  readMD  (MDProd xs ys) i = liftM2 (:*:) (readMD xs i) (readMD ys i)
  writeMD (MDProd xs ys) i (x :*: y)
    = writeMD xs i x >> writeMD ys i y
  unsafeFreezeMD (MDProd xs ys)
    = liftM2 DProd (unsafeFreezeMD xs)
                   (unsafeFreezeMD ys)

-- Strict Maybe is stored as a per-element Bool flag plus a payload.
-- Writing 'NothingS' only clears the flag and leaves the payload slot as is.
instance DT a => DT (MaybeS a) where
  indexD (DMaybe bs as) i
    | bs `indexD` i = JustS $ as `indexD` i
    | otherwise     = NothingS
  newMD g = liftM2 MDMaybe (newMD g) (newMD g)
  readMD (MDMaybe bs as) i =
    do
      b <- readMD bs i
      if b then liftM JustS $ readMD as i
           else return NothingS
  writeMD (MDMaybe bs as) i NothingS  = writeMD bs i False
  writeMD (MDMaybe bs as) i (JustS x) = writeMD bs i True
                                     >> writeMD as i x
  unsafeFreezeMD (MDMaybe bs as) = liftM2 DMaybe (unsafeFreezeMD bs)
                                                 (unsafeFreezeMD as)

-- Distributed flat arrays: a boxed array of local UArrs, plus a distributed
-- Int that 'writeMD' keeps in sync with each local array's length.
instance UA a => DT (UArr a) where
  indexD (DUArr _ a) i = indexBB a i
  newMD g = liftM2 MDUArr (newMD g) (newMBB (gangSize g))
  readMD (MDUArr _ marr) = readMBB marr
  writeMD (MDUArr mlen marr) i a =
    do
      writeMD mlen i (lengthU a)
      writeMBB marr i a
  unsafeFreezeMD (MDUArr len a) = liftM2 DUArr (unsafeFreezeMD len)
                                               (unsafeFreezeAllMBB a)

-- Segment descriptors are distributed via their 'fromUSegd' representation
-- and converted back with 'toUSegd' on every read.
instance DT USegd where
  indexD (DUSegd d) = toUSegd . indexD d
  newMD g = liftM MDUSegd (newMD g)
  readMD (MDUSegd d) = liftM toUSegd . readMD d
  writeMD (MDUSegd d) i segd = writeMD d i (fromUSegd segd)
  unsafeFreezeMD (MDUSegd d) = liftM DUSegd (unsafeFreezeMD d)

-- Segmented arrays are split into their segment descriptor and flat data.
-- Each element is rebuilt with '>:' on access.
instance UA a => DT (SUArr a) where
  indexD (DSUArr dsegd da) i = indexD dsegd i >: indexD da i
  newMD g = liftM2 MDSUArr (newMD g) (newMD g)
  readMD (MDSUArr msegd ma) i = liftM2 (>:) (readMD msegd i) (readMD ma i)
  writeMD (MDSUArr msegd ma) i sarr =
    do
      writeMD msegd i (segdSU sarr)
      writeMD ma    i (concatSU sarr)
  unsafeFreezeMD (MDSUArr msegd ma) = liftM2 DSUArr (unsafeFreezeMD msegd)
                                                    (unsafeFreezeMD ma)

-- |Basic operations on immutable distributed types
-- -------------------------------------------

-- Allocate a mutable distributed value sized for the gang, run the
-- initialiser on it, then freeze and return the result.
-- The initialiser is named 'initMD' so it does not shadow Prelude 'init'.
newD :: DT a => Gang -> (forall s . MDist a s -> ST s ()) -> Dist a
newD g initMD =
  runST (do
           mdt <- newMD g
           initMD mdt
           unsafeFreezeMD mdt)

-- | Yield a distributed unit.
unitD :: Gang -> Dist ()
unitD = DUnit . gangSize

-- | Pairing of distributed values.
-- /The two values must belong to the same/ 'Gang'.
--
-- The sizes of both arguments are compared with 'checkEq'. The location
-- label names this function, like the other 'here' labels in the module.
zipD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a :*: b)
{-# INLINE [1] zipD #-}
zipD !x !y = checkEq (here "zipD") "Size mismatch" (sizeD x) (sizeD y) $
             DProd x y

-- | Unpairing of distributed values.
unzipD :: (DT a, DT b) => Dist (a :*: b) -> Dist a :*: Dist b
unzipD (DProd dx dy) = dx :*: dy

-- | Extract the first elements of a distributed pair.
fstD :: (DT a, DT b) => Dist (a :*: b) -> Dist a
fstD = fstS . unzipD

-- | Extract the second elements of a distributed pair.
sndD :: (DT a, DT b) => Dist (a :*: b) -> Dist b
sndD = sndS . unzipD

-- | Yield the distributed length of a distributed array.
lengthD :: UA a => Dist (UArr a) -> Dist Int
lengthD (DUArr l _) = l

-- | Yield the distributed segment descriptor of a distributed segmented
-- array.
segdSD :: UA a => Dist (SUArr a) -> Dist USegd
segdSD (DSUArr dsegd _) = dsegd

-- | Flatten a distributed segmented array.
--
-- The flat data component stored next to the distributed segment
-- descriptor is returned as-is; nothing is copied or rebuilt.
concatSD :: UA a => Dist (SUArr a) -> Dist (UArr a)
concatSD dsarr =
  case dsarr of
    DSUArr _ flat -> flat