#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Sequential.Flat.Permute (
permuteU, permuteMU, bpermuteU, bpermuteDftU, reverseU, updateU,
atomicUpdateMU
) where
import Data.Array.Parallel.Base (
ST, runST, (:*:)(..), Rebox(..))
import Data.Array.Parallel.Stream (
Step(..), Stream(..))
import Data.Array.Parallel.Unlifted.Sequential.Flat.UArr (
UA, UArr, MUArr,
lengthU, newU, newDynU, newMU, unsafeFreezeAllMU, writeMU,
sliceU)
import Data.Array.Parallel.Unlifted.Sequential.Flat.Stream (
streamU, unstreamMU)
import Data.Array.Parallel.Unlifted.Sequential.Flat.Basics (
(!:))
import Data.Array.Parallel.Unlifted.Sequential.Flat.Enum (
enumFromToU)
import Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators (
mapU)
-- | Permute an array into a mutable array.
--
-- For each index @i@ from @0@ up to (but excluding) @lengthU arr@, the
-- element @arr !: i@ is written into @mpa@ at position @is !: i@.  This
-- function itself performs no validation of the target indices.
permuteMU :: UA e => MUArr e s -> UArr e -> UArr Int -> ST s ()
permuteMU mpa arr is = go 0
  where
    len = lengthU arr

    -- Walk the source indices in ascending order, one write per element.
    go ix
      | ix == len = return ()
      | otherwise = do
          writeMU mpa (is !: ix) (arr !: ix)
          go (ix + 1)
-- | Permute an array according to an index array.
--
-- A new array is created with 'newU' using the length of @arr@ and filled by
-- 'permuteMU', i.e. @arr !: i@ ends up at position @is !: i@.
permuteU :: UA e => UArr e -> UArr Int -> UArr e
permuteU arr is = newU len (\target -> permuteMU target arr is)
  where
    len = lengthU arr
-- | Back permutation: map every index of the index array to the element of
-- the (strictly evaluated) source array found at that index.
bpermuteU :: UA e => UArr e -> UArr Int -> UArr e
bpermuteU !src = mapU pick
  where
    -- Look up a single index in the source array.
    pick ix = src !: ix
-- | Default back permute.
--
-- The base array is built by mapping @init@ over @enumFromToU 0 (n - 1)@
-- (presumably the indices @0 .. n - 1@, i.e. @n@ elements -- confirm against
-- 'enumFromToU').  The index-value pairs are then applied on top of it with
-- 'updateU', so positions not mentioned by any pair keep their @init@ value.
bpermuteDftU :: UA e
             => Int                 -- ^ @n@: upper bound is @n - 1@
             -> (Int -> e)          -- ^ initialiser applied to each index
             -> UArr (Int :*: e)    -- ^ index-value pairs passed to 'updateU'
             -> UArr e
-- The upper bound must be the expression @n - 1@; a bare @n1@ is not bound
-- anywhere and does not compile.
bpermuteDftU n init = updateU (mapU init . enumFromToU 0 $ n - 1)
-- | Apply an array of index-value pairs to a mutable array.
--
-- The pairs are streamed with 'streamU' and each one is written into @marr@
-- using 'writeMU' (via 'updateM').
atomicUpdateMU :: UA e => MUArr e s -> UArr (Int :*: e) -> ST s ()
atomicUpdateMU marr upd = updateM writeMU marr pairs
  where
    pairs = streamU upd
-- | Consume a stream of index-value pairs, invoking the supplied write
-- action on the mutable array for every yielded pair, in stream order.
--
-- The third field of the stream is ignored.
updateM :: UA e => (MUArr e s -> Int -> e -> ST s ())
        -> MUArr e s -> Stream (Int :*: e) -> ST s ()
updateM write marr (Stream step start _) = loop start
  where
    -- Drive the stream's step function until it reports 'Done'.
    loop st = case step st of
      Yield (ix :*: val) st' -> write marr ix val >> loop st'
      Skip st'               -> loop st'
      Done                   -> return ()
-- | Produce a copy of an array with the given index-value pairs applied.
--
-- Both arguments are converted to streams and handed to 'update'.
updateU :: UA e => UArr e -> UArr (Int :*: e) -> UArr e
updateU arr upd = update base changes
  where
    base    = streamU arr
    changes = streamU upd
-- | Stream-level worker for 'updateU'.
--
-- Inside 'newDynU' (sized by the third field of the first stream) the first
-- stream is unstreamed into the fresh mutable array, after which the pairs of
-- the second stream are written over it with 'writeMU'.  The value returned
-- by 'unstreamMU' (presumably the number of elements written -- confirm) is
-- passed back to 'newDynU'.
update :: UA e => Stream e -> Stream (Int :*: e) -> UArr e
update base@(Stream _ _ bound) !changes =
  newDynU bound $ \marr -> do
    len <- unstreamMU marr base
    updateM writeMU marr changes
    return len
-- | Reverse the order of elements in an array by streaming it and rebuilding
-- it back-to-front with 'rev'.
reverseU :: UA e => UArr e -> UArr e
reverseU arr = rev (streamU arr)
-- | Materialise a stream as an array with its elements in reverse order.
--
-- A mutable array is allocated with @newMU n@, where @n@ is the third field
-- of the stream (presumably an upper bound on the number of elements --
-- confirm against 'Stream').  The array is filled from the back: the write
-- position starts at @n@ and is decremented before every write, so the first
-- yielded element is stored at index @n - 1@.  After freezing, the result is
-- @sliceU a i (n - i)@ with @i@ the last write position, which drops any
-- leading slots that were never written.
rev :: UA e => Stream e -> UArr e
rev (Stream next s n) =
  runST (do
    marr <- newMU n
    i    <- fill marr
    a    <- unsafeFreezeAllMU marr
    -- Slice from the final write position; the extent is @n - i@.
    return $ sliceU a i (n - i)
  )
  where
    -- Run the stream to completion, returning the lowest index written
    -- (or @n@ if nothing was yielded).
    fill marr = fill0 s n
      where
        fill0 s !i = case next s of
          Done       -> return i
          Skip s'    -> s' `dseq` fill0 s' i
          Yield x s' -> s' `dseq`
            -- Step the write position down by one before storing.
            let i' = i - 1
            in do
              writeMU marr i' x
              fill0 s' i'