{-# LANGUAGE CPP #-}

#include "fusion-phases.h"

module Data.Array.Parallel.Lifted.Scalar
where

import Data.Array.Parallel.Lifted.PArray
import Data.Array.Parallel.Lifted.Unboxed
import Data.Array.Parallel.Lifted.Repr
import Data.Array.Parallel.Lifted.Instances
import Data.Array.Parallel.Lifted.Selector

import qualified Data.Array.Parallel.Unlifted as U

import Data.Array.Parallel.Base ((:*:)(..), fstS, pairS, unpairS,
                                 fromBool, toBool)

import GHC.Exts ( Int(..), (-#) )
import GHC.Word ( Word8 )

-- | Element types whose 'PData' representation can be converted to and from
--   a flat unlifted array ('U.Array').
class U.Elt a => Scalar a where
  -- | Build the 'PData' representation from an unlifted array.
  fromUArrPD :: U.Array a -> PData a
  -- | Recover the unlifted array from the 'PData' representation.
  toUArrPD   :: PData a -> U.Array a
  -- | The 'PA' value for this element type.
  primPA     :: PA a

-- | Wrap an unlifted array as a 'PArray'. The length stored in the result is
--   the first argument; the array itself is not measured.
fromUArrPA :: Scalar a => Int -> U.Array a -> PArray a
{-# INLINE fromUArrPA #-}
fromUArrPA len arr
  = case len of
      I# len# -> PArray len# (fromUArrPD arr)

-- | Extract the underlying unlifted array from a 'PArray', ignoring the
--   stored length.
toUArrPA :: Scalar a => PArray a -> U.Array a
{-# INLINE toUArrPA #-}
toUArrPA arr
  = case arr of
      PArray _ payload -> toUArrPD payload

-- | Length of a 'PArray' as a boxed 'Int'.
prim_lengthPA :: Scalar a => PArray a -> Int
{-# INLINE prim_lengthPA #-}
prim_lengthPA arr
  = case lengthPA# arr of
      len# -> I# len#

-- | Like 'fromUArrPA', but the length is taken from the unlifted array
--   itself via 'U.length'.
fromUArrPA' :: Scalar a => U.Array a -> PArray a
{-# INLINE fromUArrPA' #-}
fromUArrPA' arr = fromUArrPA len arr
  where
    len = U.length arr

-- | Apply a function to every element of an array of scalars. The result
--   is given the same length as the input.
scalar_map :: (Scalar a, Scalar b) => (a -> b) -> PArray a -> PArray b
{-# INLINE_PA scalar_map #-}
scalar_map f xs = fromUArrPA len mapped
  where
    len    = prim_lengthPA xs
    mapped = U.map f (toUArrPA xs)

-- | Combine two arrays of scalars element-wise. The result is tagged with
--   the length of the first array.
scalar_zipWith :: (Scalar a, Scalar b, Scalar c)
               => (a -> b -> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA scalar_zipWith #-}
scalar_zipWith f xs ys = fromUArrPA len combined
  where
    len      = prim_lengthPA xs
    combined = U.zipWith f (toUArrPA xs) (toUArrPA ys)

-- | Combine three arrays of scalars element-wise. The result is tagged with
--   the length of the first array.
scalar_zipWith3
  :: (Scalar a, Scalar b, Scalar c, Scalar d)
  => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
{-# INLINE_PA scalar_zipWith3 #-}
scalar_zipWith3 f xs ys zs = fromUArrPA len combined
  where
    len      = prim_lengthPA xs
    combined = U.zipWith3 f (toUArrPA xs) (toUArrPA ys) (toUArrPA zs)

-- | Reduce an array of scalars with the given operator and starting value,
--   delegating to 'U.fold' on the underlying unlifted array.
scalar_fold :: Scalar a => (a -> a -> a) -> a -> PArray a -> a
{-# INLINE_PA scalar_fold #-}
-- The array argument is bound by a lambda so the definition keeps two
-- arguments on its left-hand side, as before.
scalar_fold f z = \arr -> U.fold f z (toUArrPA arr)

-- | Reduce an array of scalars with the given operator and no starting
--   value, delegating to 'U.fold1' on the underlying unlifted array.
scalar_fold1 :: Scalar a => (a -> a -> a) -> PArray a -> a
{-# INLINE_PA scalar_fold1 #-}
-- The array argument is bound by a lambda so the definition keeps one
-- argument on its left-hand side, as before.
scalar_fold1 f = \arr -> U.fold1 f (toUArrPA arr)

-- | Segmented fold over a nested array of scalars: the flattened data of
--   @xss@ is reduced segment by segment with 'U.fold_s', using the segment
--   descriptor of @xss@.
--
--   The segmented fold yields one value per segment, so the result is tagged
--   with the length of the outer array @xss@ (the number of segments), not
--   with the length of the flattened data.
scalar_folds :: Scalar a => (a -> a -> a) -> a -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_folds #-}
scalar_folds f z xss = fromUArrPA (I# (lengthPA# xss))
                     . U.fold_s f z (segdPA# xss)
                     . toUArrPA
                     $ concatPA# xss

-- | Segmented fold without a starting value over a nested array of scalars:
--   the flattened data of @xss@ is reduced segment by segment with
--   'U.fold1_s', using the segment descriptor of @xss@.
--
--   The segmented fold yields one value per segment, so the result is tagged
--   with the length of the outer array @xss@ (the number of segments), not
--   with the length of the flattened data.
scalar_fold1s :: Scalar a => (a -> a -> a) -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_fold1s #-}
scalar_fold1s f xss = fromUArrPA (I# (lengthPA# xss))
                    . U.fold1_s f (segdPA# xss)
                    . toUArrPA
                    $ concatPA# xss

-- | Fold over the elements of an array paired with their indices (via
--   'U.indexed') and return the index component of the final pair.
--   The user-supplied function works on lazy tuples; it is adapted to the
--   strict pairs used by the unlifted array.
scalar_fold1Index :: Scalar a
                  => ((Int, a) -> (Int, a) -> (Int, a)) -> PArray a -> Int
{-# INLINE_PA scalar_fold1Index #-}
scalar_fold1Index f = \arr -> fstS (U.fold1 step (U.indexed (toUArrPA arr)))
  where
    -- Convert strict pairs to tuples, apply f, and convert back.
    {-# INLINE step #-}
    step l r = pairS (f (unpairS l) (unpairS r))

-- | Segmented variant of 'scalar_fold1Index': every segment of the nested
--   array is folded together with indices produced by 'U.indices_s', and
--   the index component of each segment's final pair is returned.
--
--   The segmented fold yields one pair per segment, so the result is tagged
--   with @m@, the length of the outer array (the number of segments), not
--   with @n@, the length of the flattened data.
scalar_fold1sIndex :: Scalar a
                   => ((Int, a) -> (Int, a) -> (Int, a))
                   -> PArray (PArray a) -> PArray Int
{-# INLINE_PA scalar_fold1sIndex #-}
scalar_fold1sIndex f xss = fromUArrPA m
                         . U.fsts
                         . U.fold1_s f' segd
                         . U.zip (U.indices_s m segd n)
                         . toUArrPA
                         $ concatPA# xss
  where
    -- Adapt the tuple-based user function to the strict pairs stored in
    -- the unlifted array.
    {-# INLINE f' #-}
    f' p q = pairS $ f (unpairS p) (unpairS q)

    -- m: length of the outer array; n: length of the flattened data.
    m = I# (lengthPA# xss)
    n = I# (lengthPA# (concatPA# xss))

    segd = segdPA# xss

-- | 'Int' arrays are stored directly under the 'PInt' constructor.
instance Scalar Int where
  fromUArrPD arr = PInt arr
  toUArrPD pdata = case pdata of
                     PInt arr -> arr
  primPA         = dPA_Int

-- | 'Word8' arrays are stored directly under the 'PWord8' constructor.
instance Scalar Word8 where
  fromUArrPD arr = PWord8 arr
  toUArrPD pdata = case pdata of
                     PWord8 arr -> arr
  primPA         = dPA_Word8

-- | 'Double' arrays are stored directly under the 'PDouble' constructor.
instance Scalar Double where
  fromUArrPD arr = PDouble arr
  toUArrPD pdata = case pdata of
                     PDouble arr -> arr
  primPA         = dPA_Double

-- | 'Bool' arrays are stored as a two-way selector: each element is turned
--   into a tag with 'fromBool' on the way in and back with 'toBool' on the
--   way out.
instance Scalar Bool where
  {-# INLINE fromUArrPD #-}
  fromUArrPD flags = PBool sel
    where
      sel = tagsToSel2 (U.map fromBool flags)

  {-# INLINE toUArrPD #-}
  toUArrPD pdata = case pdata of
                     PBool sel -> U.map toBool (tagsSel2 sel)

  primPA = dPA_Bool


-- | Wrap an unlifted array of strict pairs as a 'PArray' of tuples. The
--   array is unzipped and each component converted separately; the stored
--   length is the first argument.
fromUArrPA_2 :: (Scalar a, Scalar b) => Int -> U.Array (a :*: b) -> PArray (a,b)
{-# INLINE fromUArrPA_2 #-}
fromUArrPA_2 len ps
  = case len of
      I# len# -> PArray len# (P_2 (fromUArrPD lefts) (fromUArrPD rights))
  where
    -- Kept as a lazy pattern binding, as in the original definition.
    lefts :*: rights = U.unzip ps

-- | Like 'fromUArrPA_2', but the length is taken from the unlifted array
--   itself via 'U.length'.
fromUArrPA_2' :: (Scalar a, Scalar b) => U.Array (a :*: b) -> PArray (a, b)
{-# INLINE fromUArrPA_2' #-}
fromUArrPA_2' ps = fromUArrPA_2 len ps
  where
    len = U.length ps

-- | Wrap an unlifted array of strict triples as a 'PArray' of 3-tuples. The
--   array is unzipped and each component converted separately; the stored
--   length is the first argument.
fromUArrPA_3 :: (Scalar a, Scalar b, Scalar c)
             => Int -> U.Array (a :*: b :*: c) -> PArray (a,b,c)
{-# INLINE fromUArrPA_3 #-}
fromUArrPA_3 len ps
  = case len of
      I# len# -> PArray len# (P_3 (fromUArrPD firsts)
                                  (fromUArrPD seconds)
                                  (fromUArrPD thirds))
  where
    -- Kept as a lazy pattern binding, as in the original definition.
    firsts :*: seconds :*: thirds = U.unzip3 ps

-- | Like 'fromUArrPA_3', but the length is taken from the unlifted array
--   itself via 'U.length'.
fromUArrPA_3' :: (Scalar a, Scalar b, Scalar c) => U.Array (a :*: b :*: c) -> PArray (a, b, c)
{-# INLINE fromUArrPA_3' #-}
fromUArrPA_3' ps = fromUArrPA_3 len ps
  where
    len = U.length ps

-- | Build a nested array from a segment descriptor and a flat array. The
--   outer length is the first argument; the stored length of the flat
--   array is discarded.
nestUSegdPA :: Int -> U.Segd -> PArray a -> PArray (PArray a)
{-# INLINE nestUSegdPA #-}
nestUSegdPA len segd flat
  = case len of
      I# len# ->
        case flat of
          PArray _ payload -> PArray len# (PNested segd payload)

-- | Like 'nestUSegdPA', but the outer length is taken from the segment
--   descriptor via 'U.lengthSegd'.
nestUSegdPA' :: U.Segd -> PArray a -> PArray (PArray a)
{-# INLINE nestUSegdPA' #-}
nestUSegdPA' segd xs = nestUSegdPA segCount segd xs
  where
    segCount = U.lengthSegd segd