module Vectorise.Utils.PADict (
paDictArgType,
paDictOfType,
paMethod,
prDictOfReprType,
prDictOfPReprInstTyCon
) where
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils.Base
import CoreSyn
import CoreUtils
import FamInstEnv
import Coercion
import Type
import TypeRep
import TyCon
import CoAxiom
import Var
import Outputable
import DynFlags
import FastString
import Control.Monad
-- | Build the type of the PA-dictionary argument for a type variable.
--
-- For @v :: *@ the result is @PA v@.  For a variable of higher kind, the
-- argument abstracts over the dictionaries for the kind's argument positions;
-- for example, for @v :: (* -> *) -> *@ it is
--
-- > forall (a :: * -> *). (forall (b :: *). PA b -> PA (a b)) -> PA (v a)
--
-- Returns 'Nothing' when no dictionary argument is needed, i.e. when the
-- result kind left after peeling off the arrows is not lifted.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = build (TyVarTy tv) (tyVarKind tv)
  where
    -- 'build ty k' yields the dictionary argument type for 'ty :: k'.
    build ty (FunTy arg_kind res_kind) = do
      -- A fresh variable is allocated even if it ends up unused (when the
      -- argument position needs no dictionary).
      arg_tv   <- newTyVar (fsLit "a") arg_kind
      arg_dict <- build (TyVarTy arg_tv) arg_kind
      let applied           = AppTy ty (TyVarTy arg_tv)
          abstractOver dict = ForAllTy arg_tv . FunTy dict
      maybe (build ty res_kind)
            (\dict -> liftM (fmap (abstractOver dict)) (build applied res_kind))
            arg_dict
    build ty k
      | isLiftedTypeKind k = do
          pa_cls <- builtin paClass
          return (Just (mkClassPred pa_cls [ty]))
    build _ _ = return Nothing
-- | Build the PA dictionary for a type.
--
-- The type is split into its head and argument types.  The dictionary
-- function for the head is looked up and then applied to the argument types
-- and to the (recursively computed) PA dictionaries of those arguments.  A
-- type-constructor head that has arguments additionally receives the PR
-- dictionary for its representation.  Any other head shape is rejected
-- through 'cantVectorise'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = dictFor fn_ty arg_tys
  where
    (fn_ty, arg_tys) = splitAppTys ty

    noTyConDict
      = "No PA dictionary for type constructor (did you import 'Data.Array.Parallel'?)"

    cannotBuild dflags
      = cantVectorise dflags "Can't construct PA dictionary for type" (ppr ty)

    dictFor :: Type -> [Type] -> VM CoreExpr
    -- Look through a synonym in the head position.
    dictFor hd args
      | Just hd' <- coreView hd
      = dictFor hd' args
    -- Type-variable head: the dfun comes from 'lookupTyVarPA'.
    dictFor (TyVarTy tv) args = do
      dfun <- maybeCantVectoriseM "No PA dictionary for type variable"
                                  (ppr tv <+> text "in" <+> ppr ty)
                                  (lookupTyVarPA tv)
      arg_dicts <- mapM paDictOfType args
      return (dfun `mkTyApps` args `mkApps` arg_dicts)
    -- Bare type-constructor head: the dfun comes from 'lookupTyConPA'; when
    -- the tycon is applied to arguments, the PR dictionary of its PRepr
    -- instance is passed before the argument dictionaries.
    dictFor (TyConApp tc []) args = do
      dfun <- maybeCantVectoriseM noTyConDict (ppr tc <+> text "in" <+> ppr ty)
                                  (lookupTyConPA tc)
      pr_dicts <- case args of
                    [] -> return []
                    _  -> do pr <- prDictOfPReprInst (TyConApp tc args)
                             return [pr]
      arg_dicts <- mapM paDictOfType args
      return (Var dfun `mkTyApps` args `mkApps` pr_dicts `mkApps` arg_dicts)
    dictFor _ _ = getDynFlags >>= cannotBuild
-- | Produce code that refers to a method of the 'PA' class at type @ty@.
--
-- If 'splitPrimTyCon' yields a tycon for @ty@, the tycon-specific builtin
-- chosen by @query@ is returned directly.  Otherwise the generic builtin
-- chosen by @method@ is applied to @ty@ and to @ty@'s PA dictionary.
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod method query ty =
  case splitPrimTyCon ty of
    Just tycon -> liftM Var (builtin (query tycon))
    Nothing    -> do
      fn   <- builtin method
      dict <- paDictOfType ty
      return (mkApps (Var fn) [Type ty, dict])
-- | Build the PR dictionary for the 'PRepr' instance of @ty@.
--
-- The matching family instance is found with 'preprSynTyCon'; its axiom and
-- instantiating types are then passed on to 'prDictOfPReprInstTyCon'.
prDictOfPReprInst :: Type -> VM CoreExpr
prDictOfPReprInst ty =
  preprSynTyCon ty >>= \(FamInstMatch { fim_instance = prepr_fam, fim_tys = prepr_args }) ->
    prDictOfPReprInstTyCon ty (famInstAxiom prepr_fam) prepr_args
-- | Given a type, the axiom of its 'PRepr' instance and the axiom's type
-- arguments, build the PR dictionary for the instantiated family application.
--
-- A PR dictionary is first built for the axiom's instantiated right-hand side
-- (via 'prDictOfReprType'', which reports an error if none exists).  It is
-- then cast with the symmetric axiom coercion lifted through PR.  The first
-- argument is not used by this body.
prDictOfPReprInstTyCon :: Type -> CoAxiom Unbranched -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon _ty prepr_ax prepr_args = do
    repr_dict <- prDictOfReprType' repr_ty
    pr_co     <- mkBuiltinCo prTyCon
    return (mkCast repr_dict (mkAppCo pr_co (mkSymCo ax_co)))
  where
    -- Right-hand side of the instantiated axiom.
    repr_ty = mkUnbranchedAxInstRHS prepr_ax prepr_args
    -- Nominal coercion given by the instantiated axiom; it is flipped with
    -- 'mkSymCo' before being lifted through PR.
    ax_co   = mkUnbranchedAxInstCo Nominal prepr_ax prepr_args
-- | Build the PR dictionary for a representation type.
--
-- * @PRepr t@: the dictionary is selected from @t@'s PA dictionary with the
--   'paPRSel' builtin.
-- * Any other type-constructor application: the tycon's PR dfun is looked up
--   (failure reported through 'maybeV') and applied with 'prDFunApply'.
-- * Anything else (e.g. a type variable or an application of one): the
--   dictionary is again selected from the type's own PA dictionary.
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  | Just (tycon, tyargs) <- splitTyConApp_maybe ty
  = do
      prepr <- builtin preprTyCon
      if tycon == prepr
        then case tyargs of
          [ty'] -> prFromPA ty'
          -- 'PRepr' takes exactly one argument; anything else is an internal
          -- invariant violation, so report it instead of failing on a
          -- partial pattern.
          _     -> pprPanic "prDictOfReprType"
                     (text "PRepr applied to wrong number of arguments:" <+> ppr ty)
        else do
          dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $
                    lookupTyConPR tycon
          prDFunApply dfun tyargs
  | otherwise
  = prFromPA ty
  where
    -- PR dictionary for 't' obtained from its PA dictionary via 'paPRSel'.
    prFromPA t = do
      pa  <- paDictOfType t
      sel <- builtin paPRSel
      return (Var sel `mkApps` [Type t, pa])
-- | Like 'prDictOfReprType', but when that computation fails (the failure
-- handled by 'orElseV') this reports a "No PR dictionary for representation
-- type" error through 'cantVectorise'.
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty = orElseV (prDictOfReprType ty) noPRDict
  where
    noPRDict = getDynFlags >>= \dflags ->
      cantVectorise dflags "No PR dictionary for representation type" (ppr ty)
-- | Apply a tycon's PR dfun to the given type arguments and to the dictionary
-- arguments required by its context.
--
-- The context is read off the dfun's type.  After stripping the outer
-- foralls, every argument type of the remaining function type must be a
-- tycon application, and its head tycon selects the dictionary to pass:
-- 'paTyCon' gives a PA dictionary ('paDictOfType') and 'prTyCon' gives a PR
-- dictionary ('prDictOfReprType').  For example, a dfun of type
-- @forall a b. PA a -> PR b -> PR (C a b)@ has context @[PA, PR]@.
--
-- A dfun with an empty context is applied to the types only.  Any other
-- mismatch (non-tycon argument, context length different from the number of
-- types, unexpected tycon) is reported as "Invalid PR dfun type" through
-- 'cantVectorise'.
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  | Just [] <- ctxs
  = return $ Var dfun `mkTyApps` tys
  | Just tycons <- ctxs
  , length tycons == length tys
  = do
      pa     <- builtin paTyCon
      pr     <- builtin prTyCon
      dflags <- getDynFlags
      args   <- zipWithM (dictionary dflags pa pr) tys tycons
      return $ Var dfun `mkTyApps` tys `mkApps` args
  | otherwise
  = getDynFlags >>= invalid
  where
    -- Head tycons of the dfun's argument (context) types, or 'Nothing' if
    -- any of them is not a tycon application.
    ctxs = fmap (map fst)
         . mapM splitTyConApp_maybe
         . fst . splitFunTys
         . snd . splitForAllTys
         $ varType dfun

    -- Dictionary argument for one context entry.
    dictionary dflags pa pr ty tycon
      | tycon == pa = paDictOfType ty
      | tycon == pr = prDictOfReprType ty
      | otherwise   = invalid dflags

    invalid dflags = cantVectorise dflags "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)