module Vectorise.Utils (
module Vectorise.Utils.Base,
module Vectorise.Utils.Closure,
module Vectorise.Utils.Hoisting,
module Vectorise.Utils.PADict,
module Vectorise.Utils.Poly,
collectAnnTypeArgs,
collectAnnDictArgs,
collectAnnTypeBinders,
collectAnnValBinders,
isAnnTypeArg,
replicatePD, emptyPD, packByTagPD,
combinePD, liftPD,
zipScalars, scalarClosure,
newLocalVar
) where
import Vectorise.Utils.Base
import Vectorise.Utils.Closure
import Vectorise.Utils.Hoisting
import Vectorise.Utils.PADict
import Vectorise.Utils.Poly
import Vectorise.Monad
import Vectorise.Builtins
import CoreSyn
import CoreUtils
import Id
import Type
import Control.Monad
-- | Peel type arguments off the outermost applications of an annotated
-- expression.
--
-- Returns the remaining head expression together with the stripped type
-- arguments, listed in the order they were applied (innermost first).
-- Peeling stops at the first application whose argument is not an
-- 'AnnType', or at any non-application node.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs expr = peel [] expr
  where
    -- Walking outside-in while consing onto the front yields the
    -- arguments in application order without a final reverse.
    peel acc (_, AnnApp fun (_, AnnType ty)) = peel (ty : acc) fun
    peel acc other                           = (other, acc)
-- | Peel dictionary arguments off the outermost applications of an
-- annotated expression.
--
-- An argument counts as a dictionary when its type satisfies 'isPredTy'.
-- Returns the remaining head expression and the stripped dictionary
-- arguments in application order (innermost first). Peeling stops at the
-- first application whose argument is not a dictionary, or at any
-- non-application node.
collectAnnDictArgs :: AnnExpr Var ann -> (AnnExpr Var ann, [AnnExpr Var ann])
collectAnnDictArgs expr = go expr []
  where
    go e@(_, AnnApp f arg) dicts
      | isDictArg arg = go f (arg : dicts)
      | otherwise     = (e, dicts)
    go e dicts = (e, dicts)

    -- A type argument is never a dictionary. Decide that case
    -- syntactically, so 'exprType' is never asked for the type of a 'Type'
    -- expression. That case is reachable: in @f \@a d@, once @d@ has been
    -- stripped the next argument inspected is @\@a@.
    isDictArg (_, AnnType _) = False
    isDictArg arg            = isPredTy . exprType . deAnnotate $ arg
-- | Split off the leading type-variable lambdas of an annotated expression.
--
-- Returns the binders (outermost first) and the body that remains once the
-- first non-type-variable lambda, or any non-lambda node, is reached.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders (_, AnnLam b body)
  | isTyVar b
  = let (rest, inner) = collectAnnTypeBinders body
    in  (b : rest, inner)
collectAnnTypeBinders e = ([], e)
-- | Split off the leading value lambdas of an annotated expression.
--
-- A binder is taken while it is an 'Id' whose type is not a predicate
-- type, so dictionary binders stop the scan. Returns the binders
-- (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders (_, AnnLam b body)
  -- 'isId' is checked first; '&&' short-circuits, so 'idType' is only
  -- consulted for term binders.
  | isId b && not (isPredTy (idType b))
  = let (rest, inner) = collectAnnValBinders body
    in  (b : rest, inner)
collectAnnValBinders e = ([], e)
-- | Does this annotated expression consist of a bare type ('AnnType')?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node) =
  case node of
    AnnType _ -> True
    _         -> False
-- | Look up the PA method selected by 'emptyPDVar' / 'emptyPD_PrimVar' at
-- the given element type.
--
-- NOTE(review): presumably this yields an empty parallel array of that
-- type; the selection logic lives in 'paMethod', not here.
emptyPD :: Type -> VM CoreExpr
emptyPD ty = paMethod emptyPDVar emptyPD_PrimVar ty
-- | Build an application of the PA method selected by 'replicatePDVar' /
-- 'replicatePD_PrimVar', at the type of @x@, to the arguments @[len, x]@.
--
-- NOTE(review): by its name this replicates @x@ @len@ times; confirm
-- against the builtin's definition.
replicatePD :: CoreExpr     -- ^ length argument
            -> CoreExpr     -- ^ element to replicate
            -> VM CoreExpr
replicatePD len x
  = do
      method <- paMethod replicatePDVar replicatePD_PrimVar (exprType x)
      return (mkApps method [len, x])
-- | Build an application of the PA method selected by 'packByTagPDVar' /
-- 'packByTagPD_PrimVar' at type @ty@ to @[xs, len, tags, t]@, in that order.
packByTagPD :: Type         -- ^ element type used to select the method
            -> CoreExpr     -- ^ @xs@
            -> CoreExpr     -- ^ @len@
            -> CoreExpr     -- ^ @tags@
            -> CoreExpr     -- ^ @t@
            -> VM CoreExpr
packByTagPD ty xs len tags t
  = do
      method <- paMethod packByTagPDVar packByTagPD_PrimVar ty
      let args = [xs, len, tags, t]
      return (mkApps method args)
-- | Build an application of the arity-specific combine PA method at type
-- @ty@ to @len@, @sel@ and then each of @xs@.
--
-- The method is chosen by the number of arrays in @xs@, via
-- 'combinePDVar' / 'combinePD_PrimVar'.
combinePD :: Type           -- ^ element type used to select the method
          -> CoreExpr       -- ^ @len@
          -> CoreExpr       -- ^ @sel@
          -> [CoreExpr]     -- ^ arrays to combine
          -> VM CoreExpr
combinePD ty len sel xs
  = do
      let arity = length xs
      method <- paMethod (combinePDVar arity) (combinePD_PrimVar arity) ty
      return (mkApps method (len : sel : xs))
-- | Apply 'replicatePD' to @x@, using the 'liftingContext' builtin variable
-- as the length argument.
--
-- NOTE(review): presumably the lifting context holds the length of the
-- current lifted computation; confirm where it is bound.
liftPD :: CoreExpr -> VM CoreExpr
liftPD x =
  builtin liftingContext >>= \ctxVar ->
    replicatePD (Var ctxVar) x
-- | Build an application of the arity-specific 'scalarZip' builtin.
--
-- The builtin is applied first to the type arguments @arg_tys ++ [res_ty]@,
-- then to the 'scalarClass' instance dictionary found for each of those
-- types. The arity passed to 'scalarZip' is the number of argument types.
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty
  = do
      let allTys = arg_tys ++ [res_ty]
      scalarCls     <- builtin scalarClass
      -- One instance lookup per type; only the dfun ids are needed.
      (instIds, _)  <- mapAndUnzipM (\ty -> lookupInst scalarCls [ty]) allTys
      zipper        <- builtin (scalarZip (length arg_tys))
      let withTys = mkTyApps (Var zipper) allTys
      return (mkApps withTys (map Var instIds))
-- | Build a closure from a scalar function and its array counterpart, using
-- the arity-specific 'closureCtrFun' builtin.
--
-- The constructor is applied to the type arguments @arg_tys ++ [res_ty]@,
-- then to a PA dictionary for every argument type except the last, and
-- finally to @scalar_fun@ and @array_fun@.
--
-- NOTE: 'init' is partial, so @arg_tys@ must be non-empty.
scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun
  = do
      closureCtr <- builtin (closureCtrFun (length arg_tys))
      paDicts    <- mapM paDictOfType (init arg_tys)
      let tyArgs  = arg_tys ++ [res_ty]
          valArgs = paDicts ++ [scalar_fun, array_fun]
      return (mkApps (mkTyApps (Var closureCtr) tyArgs) valArgs)