-- | Reductions over flat unlifted arrays ('UArr'): boolean folds, membership
-- tests, sums and products, extrema selected by value or by index, and an
-- element count implemented as a fold.
module Data.Array.Parallel.Unlifted.Sequential.Flat.Sums (
andU, orU, anyU, allU,
elemU, notElemU,
sumU, productU,
maximumU, minimumU, maximumByU, minimumByU,
maximumIndexU, minimumIndexU,
maximumIndexByU, minimumIndexByU,
lengthU'
) where
import Data.Array.Parallel.Base (
(:*:)(..), fstS)
import Data.Array.Parallel.Unlifted.Sequential.Flat.UArr (
UA, UArr)
import Data.Array.Parallel.Unlifted.Sequential.Flat.Basics (
indexedU)
import Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators (
mapU, foldU, fold1U, foldlU)
-- Fixity of the membership tests when written in backticks:
-- non-associative, precedence level 4.
infix 4 `elemU`, `notElemU`
-- | Conjunction of every element of a 'Bool' array, computed by folding
-- '&&' with 'True' as the seed value.
andU :: UArr Bool -> Bool
andU flags = foldU (&&) True flags
-- | Disjunction of every element of a 'Bool' array, computed by folding
-- '||' with 'False' as the seed value.
orU :: UArr Bool -> Bool
orU flags = foldU (||) False flags
-- | Does every element satisfy the predicate?  The predicate is mapped over
-- the array and the resulting flags are combined with 'andU'.
allU :: UA e => (e -> Bool) -> UArr e -> Bool
allU holds arr = andU (mapU holds arr)
-- | Does at least one element satisfy the predicate?  The predicate is
-- mapped over the array and the resulting flags are combined with 'orU'.
anyU :: UA e => (e -> Bool) -> UArr e -> Bool
anyU holds arr = orU (mapU holds arr)
-- | Sum of all elements, computed by folding '+' starting from @0@.
sumU :: (Num e, UA e) => UArr e -> e
sumU arr = foldU (+) 0 arr
-- | Product of all elements, computed by folding '*' starting from @1@.
productU :: (Num e, UA e) => UArr e -> e
productU arr = foldU (*) 1 arr
-- | Largest element according to 'Ord', obtained by reducing the array with
-- 'max' via 'fold1U'.
--
-- NOTE(review): the reduction has no seed value, so this presumably requires
-- a non-empty array -- confirm against 'fold1U'.
maximumU :: (Ord e, UA e) => UArr e -> e
maximumU arr = fold1U max arr
-- | Largest element according to a caller-supplied ordering, obtained by
-- reducing the array with 'fold1U'.
--
-- The combining step keeps its left operand unless the ordering reports it
-- as 'LT' relative to the right operand, so on 'EQ' the left operand wins.
-- Which array element that corresponds to depends on the order in which
-- 'fold1U' feeds its arguments.
--
-- NOTE(review): the reduction has no seed value, so this presumably requires
-- a non-empty array -- confirm against 'fold1U'.
maximumByU :: UA e => (e -> e -> Ordering) -> UArr e -> e
maximumByU cmp = fold1U larger
  where
    -- Replace the left operand only when it is strictly smaller.
    larger l r
      | cmp l r == LT = r
      | otherwise     = l
-- | 'maximumIndexByU' specialised to the element type's own 'compare'.
maximumIndexU :: (Ord e, UA e) => UArr e -> Int
maximumIndexU arr = maximumIndexByU compare arr
-- | Position of the largest element according to a caller-supplied ordering.
--
-- The array is first passed through 'indexedU', the winning pair is chosen
-- with 'maximumByU' by comparing only the second components, and the first
-- component of that pair is returned (presumably the element's index, as
-- attached by 'indexedU' -- confirm against its definition).
maximumIndexByU :: UA e => (e -> e -> Ordering) -> UArr e -> Int
maximumIndexByU cmp arr = fstS (maximumByU onValues (indexedU arr))
  where
    -- Ignore the first components; order pairs by their payload alone.
    onValues (_ :*: a) (_ :*: b) = cmp a b
-- | Smallest element according to 'Ord', obtained by reducing the array with
-- 'min' via 'fold1U'.
--
-- NOTE(review): the reduction has no seed value, so this presumably requires
-- a non-empty array -- confirm against 'fold1U'.
minimumU :: (Ord e, UA e) => UArr e -> e
minimumU arr = fold1U min arr
-- | Smallest element according to a caller-supplied ordering, obtained by
-- reducing the array with 'fold1U'.
--
-- The combining step keeps its left operand unless the ordering reports it
-- as 'GT' relative to the right operand, so on 'EQ' the left operand wins.
-- Which array element that corresponds to depends on the order in which
-- 'fold1U' feeds its arguments.
--
-- NOTE(review): the reduction has no seed value, so this presumably requires
-- a non-empty array -- confirm against 'fold1U'.
minimumByU :: UA e => (e -> e -> Ordering) -> UArr e -> e
minimumByU cmp = fold1U smaller
  where
    -- Replace the left operand only when it is strictly greater.
    smaller l r
      | cmp l r == GT = r
      | otherwise     = l
-- | 'minimumIndexByU' specialised to the element type's own 'compare'.
minimumIndexU :: (Ord e, UA e) => UArr e -> Int
minimumIndexU arr = minimumIndexByU compare arr
-- | Position of the smallest element according to a caller-supplied ordering.
--
-- The array is first passed through 'indexedU', the winning pair is chosen
-- with 'minimumByU' by comparing only the second components, and the first
-- component of that pair is returned (presumably the element's index, as
-- attached by 'indexedU' -- confirm against its definition).
minimumIndexByU :: UA e => (e -> e -> Ordering) -> UArr e -> Int
minimumIndexByU cmp arr = fstS (minimumByU onValues (indexedU arr))
  where
    -- Ignore the first components; order pairs by their payload alone.
    onValues (_ :*: a) (_ :*: b) = cmp a b
-- | Membership test: is some element of the array '==' to the given value?
elemU :: (Eq e, UA e) => e -> UArr e -> Bool
elemU target arr = anyU (\x -> x == target) arr
-- | Non-membership test: is every element of the array '/=' to the given
-- value?  Expressed directly with '/=' rather than by negating 'elemU'.
notElemU :: (Eq e, UA e) => e -> UArr e -> Bool
notElemU target arr = allU (\x -> x /= target) arr
-- | Element count computed by a left fold: the counter starts at @0@ and is
-- incremented once per element, without inspecting the element itself.
lengthU' :: UA e => UArr e -> Int
lengthU' arr = foldlU bump 0 arr
  where
    -- Advance the running count, ignoring the element being visited.
    bump count _ = count + 1