module Vectorise.Utils.PADict (
paDictArgType,
paDictOfType,
paMethod,
prDictOfReprType,
prDictOfPReprInstTyCon
) where
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils.Base
import CoreSyn
import CoreUtils
import Coercion
import Type
import TypeRep
import TyCon
import Var
import Outputable
import FastString
import Control.Monad
-- | Compute the type of the PA-dictionary argument required for the type
-- variable @tv@, by recursion on the kind of @tv@.
--
--  * Lifted type kind: the dictionary type is the builtin 'paClass'
--    predicate applied to the type.
--
--  * Arrow kind @k1 -> k2@: a fresh variable @a :: k1@ is introduced.
--    If @a@ itself needs a dictionary of type @d@, the result is
--    @forall a. d -> r@, where @r@ is the dictionary type computed for
--    @(ty a)@ at kind @k2@.  If @a@ needs none, the result is whatever
--    the unapplied type gets at kind @k2@.
--
--  * Any other kind: 'Nothing'.
--
-- 'Nothing' is also returned whenever a recursive step yields 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTyAt (TyVarTy tv) (tyVarKind tv)
  where
    -- Arrow kind: abstract over a fresh argument variable of kind k1.
    dictTyAt ty (FunTy k1 k2) = do
      a    <- newTyVar (fsLit "a") k1
      mArg <- dictTyAt (TyVarTy a) k1
      case mArg of
        Nothing        -> dictTyAt ty k2
        Just argDictTy -> do
          mRes <- dictTyAt (AppTy ty (TyVarTy a)) k2
          return (fmap (ForAllTy a . FunTy argDictTy) mRes)

    -- Lifted kind: the dictionary type is the PA class predicate on ty.
    dictTyAt ty k
      | isLiftedTypeKind k = do
          paCls <- builtin paClass
          return (Just (mkClassPred paCls [ty]))

    -- Anything else needs no dictionary.
    dictTyAt _ _ = return Nothing
-- | Build an expression for the PA dictionary of the given type.
--
-- The type is split into a head and its arguments with 'splitAppTys'.
-- The head is looked through with 'coreView' and then must be either:
--
--  * a type variable, whose dictionary expression comes from
--    'lookupTyVarPA' and is used as-is; or
--
--  * a nullary type-constructor application, whose dictionary function
--    variable comes from 'lookupTyConPA'.
--
-- The dictionary is then applied to the argument types and to the PA
-- dictionaries of those arguments, which are built recursively.  Any other
-- head shape, or a failed lookup, is reported via 'cantVectorise' or
-- 'maybeCantVectoriseM'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry dictForApp (splitAppTys ty)
  where
    dictForApp :: Type -> [Type] -> VM CoreExpr
    -- Look through the head with coreView before dispatching on it.
    dictForApp hd args
      | Just hd' <- coreView hd
      = dictForApp hd' args

    -- Type-variable head: the looked-up dictionary is already an expression.
    dictForApp (TyVarTy tv) args = do
      dfunExpr <- maybeCantVectoriseM "No PA dictionary for type variable"
                    (ppr tv <+> text "in" <+> ppr ty)
                    (lookupTyVarPA tv)
      argDicts <- mapM paDictOfType args
      return (mkApps (mkTyApps dfunExpr args) argDicts)

    -- Nullary type-constructor head: the looked-up dfun is a variable.
    dictForApp (TyConApp tc []) args = do
      dfunVar  <- maybeCantVectoriseM tyConMissingMsg
                    (ppr tc <+> text "in" <+> ppr ty)
                    (lookupTyConPA tc)
      argDicts <- mapM paDictOfType args
      return (mkApps (mkTyApps (Var dfunVar) args) argDicts)

    -- Any other head shape cannot be handled.
    dictForApp _ _
      = cantVectorise "Can't construct PA dictionary for type" (ppr ty)

    tyConMissingMsg
      = "No PA dictionary for type constructor (did you import 'Data.Array.Parallel'?)"
-- | Obtain a PA method specialised to the type @ty@.
--
-- When 'splitPrimTyCon' recognises @ty@, the builtin selected by @query@
-- for that type constructor is returned directly.  Otherwise the generic
-- builtin selected by @method@ is applied to @ty@ and to @ty@'s PA
-- dictionary (built by 'paDictOfType').
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod method query ty =
  case splitPrimTyCon ty of
    Just tycon -> liftM Var (builtin (query tycon))
    Nothing    -> do
      genericFn <- builtin method
      tyDict    <- paDictOfType ty
      return (Var genericFn `mkApps` [Type ty, tyDict])
-- | Build the PR dictionary for an instance of the @PRepr@ type family.
--
-- @prepr_tc@ applied to @prepr_args@ is expanded with 'coreView' to its
-- right-hand side (the representation type).  A PR dictionary is built for
-- that representation type with 'prDictOfReprType''.  It is then cast with a
-- coercion formed from 'mkBuiltinCo' 'prTyCon', 'mkAppCo' and the symmetric
-- ('mkSymCo') of the instance's family coercion ('mkAxInstCo').
--
-- @ty@ is used only in the error message.  The instance is rejected with
-- 'cantVectorise' if either of these fails:
--
--  * 'coreView' does not expand the application;
--  * 'tyConFamilyCoercion_maybe' finds no family coercion for @prepr_tc@.
--
-- The second check is done up front.  This replaces a partial
-- @let Just ... =@ binding that would otherwise crash with an uninformative
-- irrefutable-pattern failure.
prDictOfPReprInstTyCon :: Type -> TyCon -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon ty prepr_tc prepr_args
  | Just rhs    <- coreView (mkTyConApp prepr_tc prepr_args)
  , Just arg_co <- tyConFamilyCoercion_maybe prepr_tc
  = do
      dict  <- prDictOfReprType' rhs
      pr_co <- mkBuiltinCo prTyCon
      -- Lift the symmetric of the instance axiom through PR to cast the dictionary.
      let co = mkAppCo pr_co (mkSymCo (mkAxInstCo arg_co prepr_args))
      return $ mkCast dict co
  | otherwise = cantVectorise "Invalid PRepr type instance" (ppr ty)
-- | Build the PR dictionary for a representation type.
--
--  * @PRepr t@ (the builtin 'preprTyCon' applied to exactly one argument):
--    the dictionary is 'paPRSel' applied to @t@ and @t@'s PA dictionary.
--
--  * Any other type-constructor application @T args@: the PR dfun for @T@
--    (from 'lookupTyConPR') is applied to @args@ via 'prDFunApply'.
--
--  * Anything that is not a type-constructor application: 'paPRSel' applied
--    to the type and its PA dictionary.
--
-- A @PRepr@ application without exactly one argument is reported via
-- 'cantVectorise'.  This replaces a partial @let [ty'] =@ binding that would
-- crash with an irrefutable-pattern failure.
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  | Just (tycon, tyargs) <- splitTyConApp_maybe ty
  = do
      prepr <- builtin preprTyCon
      if tycon == prepr
        then case tyargs of
               [ty'] -> prDictViaPA ty'
               _     -> cantVectorise "Invalid PRepr type application" (ppr ty)
        else do
          dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $
                    lookupTyConPR tycon
          prDFunApply dfun tyargs
  | otherwise
  = prDictViaPA ty
  where
    -- Build 'paPRSel' applied to @t@ and the PA dictionary of @t@.
    prDictViaPA :: Type -> VM CoreExpr
    prDictViaPA t = do
      pa  <- paDictOfType t
      sel <- builtin paPRSel
      return $ Var sel `mkApps` [Type t, pa]
-- | Like 'prDictOfReprType', but if that computation fails (as seen by
-- 'orElseV') the failure is reported with 'cantVectorise'.
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty =
  orElseV (prDictOfReprType ty) noDict
  where
    noDict = cantVectorise "No PR dictionary for representation type" (ppr ty)
-- | Apply a PR dfun to the type arguments @tys@ and to the dictionaries its
-- context asks for.
--
-- The context is read from the dfun's own type.  Foralls are stripped, then
-- the function-argument types are taken; each must be a type-constructor
-- application, and only its head tycon is kept.
--
--  * Empty context: the dfun is applied to @tys@ only.
--
--  * Context of the same length as @tys@: each context tycon is paired with
--    the corresponding type.  A builtin 'paTyCon' entry produces a PA
--    dictionary ('paDictOfType'); a builtin 'prTyCon' entry produces a PR
--    dictionary ('prDictOfReprType').
--
-- Any other shape (a non-tycon context argument, a length mismatch, or an
-- unexpected context tycon) is reported via 'cantVectorise'.
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys =
  case contextTyCons of
    Just [] ->
      return (mkTyApps (Var dfun) tys)
    Just tcs
      | length tcs == length tys -> do
          paTc  <- builtin paTyCon
          prTc  <- builtin prTyCon
          dicts <- zipWithM (dictFor paTc prTc) tys tcs
          return (mkApps (mkTyApps (Var dfun) tys) dicts)
    _ -> invalid
  where
    dfunTy          = varType dfun
    (_, dfunRho)    = splitForAllTys dfunTy
    (contextTys, _) = splitFunTys dfunRho
    -- Nothing if any context argument is not a type-constructor application.
    contextTyCons   = fmap (map fst) (mapM splitTyConApp_maybe contextTys)

    -- Build the dictionary demanded by one context entry.
    dictFor paTc prTc argTy tc
      | tc == paTc = paDictOfType argTy
      | tc == prTc = prDictOfReprType argTy
      | otherwise  = invalid

    invalid = cantVectorise "Invalid PR dfun type" (ppr dfunTy <+> ppr tys)