-- | Mapping of prelude functions to vectorised versions.
--     Functions like filterP currently have a working but naive version in GHC.PArr.
--     During vectorisation we replace calls to these by calls to functions like
--     filterPA, which are defined in dph-common Data.Array.Parallel.Lifted.Combinators
--
--     As the renamer only sees the GHC.PArr functions, if you want to add a new
--     function to the vectoriser there has to be a definition for it in GHC.PArr,
--     even though that definition will never be used at runtime.
--
module Vectorise.Builtins.Prelude
	( preludeVars
	, preludeScalars)
where
import Vectorise.Builtins.Modules
import PrelNames
import Module
import FastString


-- | Rewrite table for prelude functions.
--   Every entry names an original variable (module and name) together with
--   the variable in the DPH packages (module and name) that replaces it.
preludeVars
        :: Modules                      -- ^ Modules containing the DPH backends
        -> [( Module, FastString        --   Original variable: defining module and name,
            , Module, FastString)]      --   followed by the DPH variable it is rewritten to.

preludeVars (Modules { dph_Combinators    = combinators
                     , dph_PArray         = parray
                     , dph_Prelude_Int    = preludeInt
                     , dph_Prelude_Word8  = preludeWord8
                     , dph_Prelude_Double = preludeDouble
                     , dph_Prelude_Bool   = preludeBool
                     , dph_Prelude_PArr   = preludePArr
                     })
  = concat
    [ -- Functions that work on whole PArrays, defined in GHC.PArr
      rewrite gHC_PARR combinators
        [ ("mapP",       "mapPA")
        , ("zipWithP",   "zipWithPA")
        , ("zipP",       "zipPA")
        , ("unzipP",     "unzipPA")
        , ("filterP",    "filterPA")
        , ("lengthP",    "lengthPA")
        , ("replicateP", "replicatePA")
        , ("!:",         "indexPA")
        , ("sliceP",     "slicePA")
        , ("crossMapP",  "crossMapPA")
        , ("singletonP", "singletonPA")
        , ("concatP",    "concatPA")
        , ("+:+",        "appPA")
        ]
    , rewrite gHC_PARR parray
        [ ("emptyP",     "emptyPA")
        ]

      -- Map scalar functions to versions using closures.
    , inPlace preludeInt
        [ ("div",         "divV")
        , ("mod",         "modV")
        , ("sqrt",        "sqrtV")
        , ("enumFromToP", "enumFromToPA")
        -- , ("upToP",    "upToPA")
        ]
    , ordVars preludeInt
    , numVars preludeInt

    , ordVars preludeWord8
    , numVars preludeWord8
    , inPlace preludeWord8
        [ ("div",     "divV")
        , ("mod",     "modV")
        , ("fromInt", "fromIntV")
        , ("toInt",   "toIntV")
        ]

    , ordVars        preludeDouble
    , numVars        preludeDouble
    , fractionalVars preludeDouble
    , floatingVars   preludeDouble
    , realFracVars   preludeDouble

    , inPlace preludeBool
        [ ("andP", "andPA")
        , ("orP",  "orPA")
        ]
    , rewrite gHC_CLASSES preludeBool
        [ ("not", "notV")
        , ("&&",  "andV")
        , ("||",  "orV")
        ]

      -- FIXME: temporary
    , inPlace preludePArr
        [ ("fromPArrayP",       "fromPArrayPA")
        , ("toPArrayP",         "toPArrayPA")
        , ("fromNestedPArrayP", "fromNestedPArrayPA")
        ]
    , rewrite preludePArr combinators
        [ ("combineP",  "combine2PA")
        , ("updateP",   "updatePA")
        , ("bpermuteP", "bpermutePA")
        , ("indexedP",  "indexedPA")
        ]
    ]
  where
    -- Attach the source and target modules to each (original, replacement)
    -- pair of names, converting the names to FastStrings.
    rewrite src dst names
     = [ (src, fsLit old, dst, fsLit new) | (old, new) <- names ]

    -- As 'rewrite', for replacements that live in the same module as the
    -- original.
    inPlace m = rewrite m m

    -- Comparisons, min/max and the min/max array functions.
    ordVars m
     = inPlace m
        [ ("==",        "eqV")
        , ("/=",        "neqV")
        , ("<=",        "leV")
        , ("<",         "ltV")
        , (">=",        "geV")
        , (">",         "gtV")
        , ("min",       "minV")
        , ("max",       "maxV")
        , ("minimumP",  "minimumPA")
        , ("maximumP",  "maximumPA")
        , ("minIndexP", "minIndexPA")
        , ("maxIndexP", "maxIndexPA")
        ]

    -- Arithmetic plus the sum/product array functions.
    numVars m
     = inPlace m
        [ ("+",        "plusV")
        , ("-",        "minusV")
        , ("*",        "multV")
        , ("negate",   "negateV")
        , ("abs",      "absV")
        , ("sumP",     "sumPA")
        , ("productP", "productPA")
        ]

    fractionalVars m
     = inPlace m
        [ ("/",     "divideV")
        , ("recip", "recipV")
        ]

    -- Note that "pi" keeps its name; every other entry gains a V suffix
    -- or a dedicated name.
    floatingVars m
     = inPlace m
        [ ("pi",      "pi")
        , ("exp",     "expV")
        , ("sqrt",    "sqrtV")
        , ("log",     "logV")
        , ("sin",     "sinV")
        , ("tan",     "tanV")
        , ("cos",     "cosV")
        , ("asin",    "asinV")
        , ("atan",    "atanV")
        , ("acos",    "acosV")
        , ("sinh",    "sinhV")
        , ("tanh",    "tanhV")
        , ("cosh",    "coshV")
        , ("asinh",   "asinhV")
        , ("atanh",   "atanhV")
        , ("acosh",   "acoshV")
        , ("**",      "powV")
        , ("logBase", "logBaseV")
        ]

    realFracVars m
     = inPlace m
        [ ("fromInt",  "fromIntV")
        , ("truncate", "truncateV")
        , ("round",    "roundV")
        , ("ceiling",  "ceilingV")
        , ("floor",    "floorV")
        ]


-- | Table of (module, name) pairs drawn from the Int, Word8 and Double DPH
--   prelude modules. Every name listed here also appears as an original
--   variable in 'preludeVars'; the array functions (such as sumP or
--   minimumP) are left out.
--   NOTE(review): presumably these are the functions the vectoriser treats
--   as scalar operations -- confirm against the caller.
preludeScalars :: Modules -> [(Module, FastString)]
preludeScalars (Modules { dph_Prelude_Int    = preludeInt
                        , dph_Prelude_Word8  = preludeWord8
                        , dph_Prelude_Double = preludeDouble
                        })
  = concat
    [ scalarsIn preludeInt ["div", "mod", "sqrt"]
    , ordScalars preludeInt
    , numScalars preludeInt

    , ordScalars preludeWord8
    , numScalars preludeWord8
    , scalarsIn preludeWord8 ["div", "mod", "fromInt", "toInt"]

    , ordScalars        preludeDouble
    , numScalars        preludeDouble
    , fractionalScalars preludeDouble
    , floatingScalars   preludeDouble
    , realFracScalars   preludeDouble
    ]
  where
    -- Pair each name, converted to a FastString, with the given module.
    scalarsIn m names = [ (m, fsLit name) | name <- names ]

    ordScalars m
     = scalarsIn m ["==", "/=", "<=", "<", ">=", ">", "min", "max"]

    numScalars m
     = scalarsIn m ["+", "-", "*", "negate", "abs"]

    fractionalScalars m
     = scalarsIn m ["/", "recip"]

    floatingScalars m
     = scalarsIn m
        [ "pi",    "exp",   "sqrt",  "log"
        , "sin",   "tan",   "cos"
        , "asin",  "atan",  "acos"
        , "sinh",  "tanh",  "cosh"
        , "asinh", "atanh", "acosh"
        , "**",    "logBase"
        ]

    realFracScalars m
     = scalarsIn m ["fromInt", "truncate", "round", "ceiling", "floor"]