module Vectorise.Utils.PADict (
paDictArgType,
paDictOfType,
paMethod,
prDictOfReprType,
prDictOfPReprInstTyCon
) where
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils.Base
import CoreSyn
import CoreUtils
import FamInstEnv
import Coercion
import Type
import TyCoRep
import TyCon
import CoAxiom
import Var
import Outputable
import DynFlags
import FastString
import Control.Monad
-- | Compute the type of the PA dictionary argument required for a type
-- variable, if one is required at all.
--
-- * A variable whose kind is lifted gets @PA tv@ (built with 'paClass').
-- * A variable of arrow kind @k1 -> k2@ is handled by walking the kind. A
--   fresh binder @a :: k1@ is introduced (a coercion variable if @k1@ is a
--   coercion type, otherwise a type variable).
--   - If @a@ itself needs a dictionary type @d_a@, the result is
--     @forall a. d_a -> d@, where @d@ is computed recursively for @tv a@ at
--     kind @k2@.
--   - Otherwise the walk continues on @k2@.
-- * Any other kind yields 'Nothing'.
--
-- NOTE(review): when the argument needs no dictionary, the walk continues on
-- @k2@ with the *unapplied* type rather than @tv a@; this matches the
-- original logic — confirm it is intended.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (mkTyVarTy tv) (tyVarKind tv)
  where
    dictTy ty kind = case kind of
      ForAllTy (Anon argKind) resKind -> arrowDict ty argKind resKind
      _ | isLiftedTypeKind kind       -> liftM (\cls -> Just (mkClassPred cls [ty]))
                                               (builtin paClass)
        | otherwise                   -> return Nothing

    -- Arrow kind: abstract over a fresh binder for the argument kind.
    arrowDict ty argKind resKind = do
      argTv   <- freshVar argKind
      argDict <- dictTy (mkTyVarTy argTv) argKind
      case argDict of
        Nothing        -> dictTy ty resKind
        Just argDictTy ->
          liftM (fmap (mkNamedForAllTy argTv Invisible . mkFunTy argDictTy))
                (dictTy (mkAppTy ty (mkTyVarTy argTv)) resKind)

    -- Coercion-typed kinds need a coercion variable rather than a type variable.
    freshVar k
      | isCoercionType k = newCoVar (fsLit "c") k
      | otherwise        = newTyVar (fsLit "a") k
-- | Construct a PA dictionary for a type.
--
-- The type is split into a head and its arguments with 'splitAppTys'. Type
-- synonyms in the head are expanded with 'coreView'. Then:
--
-- * A type-variable head uses the dictionary found by 'lookupTyVarPA'. It is
--   applied to the argument types and then to the arguments' PA dictionaries.
-- * A bare type-constructor head uses the dfun found by 'lookupTyConPA'. It
--   is applied to:
--   - the argument types;
--   - then, only when there are arguments, the PR dictionary of the
--     constructor's PRepr instance;
--   - then the arguments' PA dictionaries.
-- * Any other head is reported with 'cantVectorise'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
  where
    paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
    paDictOfTyApp fn args
      | Just fn' <- coreView fn
      = paDictOfTyApp fn' args
      | otherwise
      = case fn of
          TyVarTy tv     -> tyVarDict tv args
          TyConApp tc [] -> tyConDict tc args
          _              -> getDynFlags >>= failure

    -- Head is a type variable.
    tyVarDict tv args = do
      dfun  <- maybeCantVectoriseM "No PA dictionary for type variable"
                 (ppr tv <+> text "in" <+> ppr ty)
                 (lookupTyVarPA tv)
      dicts <- mapM paDictOfType args
      return (mkApps (mkTyApps dfun args) dicts)

    -- Head is a bare type constructor.
    tyConDict tc args = do
      dfun  <- maybeCantVectoriseM noPADictErr (ppr tc <+> text "in" <+> ppr ty)
                 (lookupTyConPA tc)
      super <- superDict tc args
      dicts <- mapM paDictOfType args
      return (mkApps (mkApps (mkTyApps (Var dfun) args) super) dicts)

    noPADictErr = "No PA dictionary for type constructor (did you import 'Data.Array.Parallel'?)"

    -- Extra PR dictionary argument (presumably the PA dfun's superclass);
    -- only supplied for applied constructors.
    superDict _ []    = return []
    superDict tc args = liftM (: []) (prDictOfPReprInst (TyConApp tc args))

    failure dflags = cantVectorise dflags "Can't construct PA dictionary for type" (ppr ty)
-- | Build a reference to a PA method at a given type.
--
-- If 'splitPrimTyCon' recognises the type, the builtin selected by @query@
-- for that tycon is used directly. Otherwise the generic builtin @method@ is
-- applied to the type and to the type's PA dictionary.
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod method query ty =
  case splitPrimTyCon ty of
    Just tycon -> liftM Var (builtin (query tycon))
    Nothing    -> do
      fn   <- builtin method
      dict <- paDictOfType ty
      return (Var fn `mkApps` [Type ty, dict])
-- | Construct the PR dictionary for the PRepr family instance of a type.
--
-- The instance is found with 'preprFamInst'. Its axiom and instance
-- arguments are then handed to 'prDictOfPReprInstTyCon'.
prDictOfPReprInst :: Type -> VM CoreExpr
prDictOfPReprInst ty =
  preprFamInst ty >>= \FamInstMatch { fim_instance = inst, fim_tys = inst_args } ->
    prDictOfPReprInstTyCon ty (famInstAxiom inst) inst_args
-- | Construct a PR dictionary from an instantiated PRepr axiom.
--
-- 1. Build the PR dictionary for the axiom's right-hand side at @prepr_args@
--    (via 'prDictOfReprType''').
-- 2. Cast it with the builtin PR coercion applied to the symmetric axiom
--    coercion.
--
-- The first argument is ignored.
--
-- NOTE(review): the axiom coercion is requested at 'Nominal' role before
-- being used in 'mkCast'; confirm this is the role Core Lint expects here.
prDictOfPReprInstTyCon :: Type -> CoAxiom Unbranched -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon _ty prepr_ax prepr_args = do
  dict  <- prDictOfReprType' (mkUnbranchedAxInstRHS prepr_ax prepr_args [])
  pr_co <- mkBuiltinCo prTyCon
  let ax_co = mkUnbranchedAxInstCo Nominal prepr_ax prepr_args []
  return (mkCast dict (mkAppCo pr_co (mkSymCo ax_co)))
-- | Construct a PR dictionary for a representation type.
--
-- * @PRepr t@ (the builtin 'preprTyCon' applied to one argument): the PR
--   dictionary is selected from @t@'s PA dictionary via 'paPRSel'.
-- * Any other type-constructor application: the PR dfun found by
--   'lookupTyConPR' is applied to the arguments with 'prDFunApply'. A failed
--   lookup goes through 'maybeV'.
-- * Anything else: the PR dictionary is selected from the type's PA
--   dictionary via 'paPRSel'.
--
-- A 'preprTyCon' application with other than exactly one argument is
-- reported with 'cantVectorise' rather than crashing on a partial pattern.
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  | Just (tycon, tyargs) <- splitTyConApp_maybe ty
  = do
      prepr <- builtin preprTyCon
      if tycon == prepr
        then case tyargs of
               [ty'] -> prViaPA ty'
               _     -> do
                 dflags <- getDynFlags
                 cantVectorise dflags "Malformed PRepr application" (ppr ty)
        else do
          dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $
                    lookupTyConPR tycon
          prDFunApply dfun tyargs
  | otherwise
  = prViaPA ty
  where
    -- PR dictionary for @t@ selected out of its PA dictionary.
    prViaPA :: Type -> VM CoreExpr
    prViaPA t = do
      pa  <- paDictOfType t
      sel <- builtin paPRSel
      return $ Var sel `mkApps` [Type t, pa]
-- | Like 'prDictOfReprType', but a failure (as recovered by 'orElseV') is
-- turned into a 'cantVectorise' error naming the representation type.
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty = orElseV (prDictOfReprType ty) noDict
  where
    noDict = do
      dflags <- getDynFlags
      cantVectorise dflags "No PR dictionary for representation type" (ppr ty)
-- | Apply a PR dfun to type arguments, supplying any dictionary arguments
-- required by its type.
--
-- The dfun's type is inspected after stripping its foralls. Every argument
-- of the resulting function type must be a type-constructor application.
--
-- * If there are no such arguments, the dfun is just instantiated at @tys@.
-- * If there is exactly one per type argument, each must be headed by the
--   builtin 'paTyCon' or 'prTyCon'. The corresponding PA or PR dictionary is
--   then built for the matching type argument.
-- * Anything else is reported with 'cantVectorise'.
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  | Just [] <- ctxs
  = return $ Var dfun `mkTyApps` tys
  | Just tycons <- ctxs
  , length tycons == length tys
  = do
      pa     <- builtin paTyCon
      pr     <- builtin prTyCon
      dflags <- getDynFlags
      args   <- zipWithM (dictionary dflags pa pr) tys tycons
      return $ Var dfun `mkTyApps` tys `mkApps` args
  | otherwise
  = getDynFlags >>= invalid
  where
    -- Argument types of the dfun's function type, foralls stripped.
    ctx_tys = fst (splitFunTys (snd (splitForAllTys (varType dfun))))

    -- Head tycon of every context argument; Nothing if any is not a tycon app.
    ctxs = fmap (map fst) (mapM splitTyConApp_maybe ctx_tys)

    dictionary dflags pa pr ty tycon
      | tycon == pa = paDictOfType ty
      | tycon == pr = prDictOfReprType ty
      | otherwise   = invalid dflags

    invalid dflags = cantVectorise dflags "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)