-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Array.Parallel.Unlifted.Distributed.Combinators
-- Copyright   :  (c) 2006 Roman Leshchinskiy
-- License     :  see libraries/ndp/LICENSE
-- 
-- Maintainer  :  Roman Leshchinskiy <rl@cse.unsw.edu.au>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- Standard combinators for distributed types.
--

module Data.Array.Parallel.Unlifted.Distributed.Combinators (
  mapD, zipD, unzipD, fstD, sndD, zipWithD,
  foldD, scanD,

  -- * Monadic combinators
  mapDST_, mapDST, zipWithDST_, zipWithDST
) where

import Data.Array.Parallel.Base (
  (:*:)(..), uncurryS, unsafe_pairS, unsafe_unpairS, ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang (
  Gang, gangSize)
import Data.Array.Parallel.Unlifted.Distributed.Types (
  DT, Dist, indexD, zipD, unzipD, fstD, sndD,
  newMD, writeMD, unsafeFreezeMD,
  checkGangD)
import Data.Array.Parallel.Unlifted.Distributed.DistST (
  DistST, distST_, distST, runDistST, myD)

-- | Qualify a local name with this module's name. The result is the location
-- tag that the combinators in this module pass to 'checkGangD' (presumably
-- used to identify the caller in error messages).
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s

-- | Apply a function to every per-thread component of a distributed value,
-- producing a new distributed value over the same gang.
--
-- The gang of @d@ is validated against @g@ via 'checkGangD', and @d@ is
-- forced (bang pattern) before the parallel computation is started. The
-- @INLINE [1]@ phase control keeps 'mapD' un-inlined in the early simplifier
-- phases so the fusion rules below can match on it.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# INLINE [1] mapD #-}
mapD g f !d =
  checkGangD (here "mapD") g d
    (runDistST g (myD d >>= \x -> return (f x)))

-- Rewrite rules for 'mapD':
--
--  * "mapD/mapD" merges two consecutive maps over the same gang into a
--    single map of the composed function.
--
--  * "zipD/mapD[1]" and "zipD/mapD[2]" float a 'mapD' out of the first or
--    second argument of 'zipD'. The mapped function is then applied to one
--    half of each zipped element, which exposes it to "mapD/mapD" fusion with
--    surrounding maps. The lambdas pattern-match on an ordinary tuple, so
--    'unsafe_unpairS' / 'unsafe_pairS' convert between the zipped element
--    (presumably a strict pair @a :*: b@) and a lazy tuple. Inside the
--    lambdas, @xs@ and @ys@ shadow the rule variables and denote single
--    components, not whole distributed values.
--
-- These rules rely on 'mapD' being phase-controlled with @INLINE [1]@.
{-# RULES

"mapD/mapD" forall gang f g d.
  mapD gang f (mapD gang g d) = mapD gang (\x -> f (g x)) d

"zipD/mapD[1]" forall gang f xs ys.
  zipD (mapD gang f xs) ys
    = mapD gang (unsafe_pairS . (\(xs, ys) -> (f xs, ys)) . unsafe_unpairS)
                (zipD xs ys)

"zipD/mapD[2]" forall gang f xs ys.
  zipD xs (mapD gang f ys)
    = mapD gang (unsafe_pairS . (\(xs, ys) -> (xs, f ys)) . unsafe_unpairS)
                (zipD xs ys)

  #-}

-- zipD, unzipD, fstD, sndD reexported from Types

-- | Combine two distributed values component-wise with a binary function.
--
-- Implemented as a 'mapD' over the zipped value, so the 'mapD' rewrite rules
-- apply to it once it is inlined. Each zipped element is unpaired with
-- 'unsafe_unpairS' and fed to @f@ as two arguments.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE zipWithD #-}
zipWithD g f dx dy =
  mapD g (\pair -> uncurry f (unsafe_unpairS pair)) (zipD dx dy)

-- | Fold a distributed value: combine its per-thread components from left to
-- right with @f@, starting from component 0 and ending at component
-- @gangSize g - 1@.
--
-- The gang of @d@ is validated against @g@ via 'checkGangD'.
--
-- NOTE(review): component 0 is read unconditionally, so this assumes the
-- gang has at least one thread. Confirm that 'gangSize' is always >= 1.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
foldD g f d = checkGangD (here "foldD") g d $
              fold 1 (d `indexD` 0)
  where
    n = gangSize g
    -- Accumulate components 1 .. n-1 into x, left to right.
    fold i x | i == n    = x
             | otherwise = fold (i+1) (f x $ d `indexD` i)

-- | Exclusive prefix scan of a distributed value.
--
-- Component @i@ of the resulting 'Dist' holds @z@ combined (left to right,
-- via @f@) with components @0 .. i-1@ of @d@, so component 0 is just @z@. The
-- second half of the result is @z@ combined with all @gangSize g@ components.
-- The gang of @d@ is validated against @g@ via 'checkGangD'.
scanD :: DT a => Gang -> (a -> a -> a) -> a -> Dist a -> Dist a :*: a
scanD g f z d = checkGangD (here "scanD") g d $
                runST (do
                  out    <- newMD g
                  total  <- go out 0 z
                  frozen <- unsafeFreezeMD out
                  return (frozen :*: total))
  where
    n = gangSize g
    -- Write the running accumulator at slot i, then fold in component i.
    go out i acc
      | i == n    = return acc
      | otherwise = writeMD out i acc >> go out (i+1) (f acc (d `indexD` i))

-- NOTE: The following combinators must be strict in their Dist arguments.
-- Otherwise a Dist might only be evaluated (in parallel) when it is demanded
-- inside the current computation, which is itself parallel. This would break
-- our model and lead to a deadlock. Hence the bangs.

-- | Run a 'DistST' action on every per-thread component of a distributed
-- value, discarding the results.
--
-- The Dist is forced (bang pattern) before the parallel computation starts,
-- as the strictness NOTE requires. Its gang is validated against @g@ via
-- 'checkGangD'.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_ g p !d =
  checkGangD (here "mapDST_") g d (distST_ g (p =<< myD d))

-- | Run a 'DistST' action on every per-thread component of a distributed
-- value and collect the per-thread results into a new distributed value.
--
-- The Dist is forced (bang pattern) before the parallel computation starts,
-- as the strictness NOTE requires. Its gang is validated against @g@ via
-- 'checkGangD'. The location tag now names this function rather than
-- 'mapDST_'.
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST g p !d = checkGangD (here "mapDST") g d $
                distST g (myD d >>= p)

-- | Run a binary 'DistST' action on corresponding components of two
-- distributed values, discarding the results.
--
-- Both Dists are forced (bang patterns) before they are zipped, as the
-- strictness NOTE requires. The work is delegated to 'mapDST_' over the
-- zipped value.
zipWithDST_ :: (DT a, DT b)
            => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
zipWithDST_ g p !dx !dy = mapDST_ g step zipped
  where
    step   = uncurryS p
    zipped = zipD dx dy

-- | Run a binary 'DistST' action on corresponding components of two
-- distributed values and collect the results into a new distributed value.
--
-- Both Dists are forced (bang patterns) before they are zipped, as the
-- strictness NOTE requires. The work is delegated to 'mapDST' over the zipped
-- value.
zipWithDST :: (DT a, DT b, DT c)
           => Gang
           -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
zipWithDST g p !dx !dy = mapDST g step zipped
  where
    step   = uncurryS p
    zipped = zipD dx dy