module Vectorise.Generic.PADict
( buildPADict
) where
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Generic.Description
import Vectorise.Generic.PAMethods ( buildPAScAndMethods )
import Vectorise.Utils
import BasicTypes
import CoreSyn
import CoreUtils
import CoreUnfold
import Module
import TyCon
import CoAxiom
import Type
import Id
import Var
import Name
import FastString
-- | Build the dictionary function for the class obtained from
--   @builtin paClass@, instantiated at @vect_tc@ applied to its own type
--   variables, and hoist it as a new binding.
--
--   The dictionary function abstracts over the tycon's type variables and
--   then over its value arguments: one extra dictionary for the
--   @builtin prClass@ predicate (only when the tycon has type variables),
--   followed by the arguments handed over by 'polyAbstract'.  Its body
--   applies the data constructor from @builtin paDataCon@ to the instance
--   type, that dictionary, and one separately hoisted binding per entry
--   returned by 'buildPAScAndMethods'.
buildPADict
  :: TyCon               -- ^ Type constructor the dictionary is built for.
  -> CoAxiom Unbranched  -- ^ Axiom passed to 'prDictOfPReprInstTyCon' and to
                         --   every method builder.
  -> TyCon               -- ^ Only forwarded to the method builders here
                         --   (presumably the PData instance tycon, going by
                         --   the binder name -- confirm against callers).
  -> TyCon               -- ^ Only forwarded to the method builders here
                         --   (presumably the PDatas instance tycon -- confirm
                         --   against callers).
  -> SumRepr             -- ^ Representation description, forwarded to the
                         --   method builders.
  -> VM Var              -- ^ The identifier of the hoisted dictionary function.
buildPADict vect_tc prepr_ax pdata_tc pdatas_tc repr
  = polyAbstract tvs $ \args -> do
      this_mod <- liftDs getModule
      let dfun_occ = mkLocalisedOccName this_mod mkPADFunOcc (getName vect_tc)

      -- With type variables present, the dictionary for the prClass
      -- predicate becomes an extra leading value argument ...
      super_tys  <- sequence [super_pred_ty | not no_tvs]
      super_vars <- mapM (newLocalVar (fsLit "pr")) super_tys
      let val_args = super_vars ++ args
          all_args = tvs ++ val_args

      -- ... and without type variables it is built directly from the axiom.
      super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_ax [] | no_tvs]

      -- One hoisted binding per (name, builder) pair.
      builders   <- buildPAScAndMethods
      method_ids <- mapM (mk_method val_args dfun_occ) builders

      -- The dictionary expression itself.  At most one of 'super_vars' and
      -- 'super_consts' is non-empty, since their guards are complementary.
      pa_dc <- builtin paDataCon
      let con_args = concat
                       [ [Type inst_ty]
                       , map Var super_vars
                       , super_consts
                       , map (apply_method val_args) method_ids
                       ]
          dict_rhs = mkLams all_args (mkConApp pa_dc con_args)

      -- Type of the dictionary function.
      pa_cls <- builtin paClass
      let dfun_ty = mkForAllTys tvs
                      (mkFunTys (map varType val_args)
                                (mkClassPred pa_cls [inst_ty]))

      -- Attach the DFun unfolding and the dfun inline pragma.
      fresh_dfun <- newExportedVar dfun_occ dfun_ty
      let dfun_unf = mkDFunUnfolding all_args pa_dc con_args
          dfun     = setInlinePragma (setIdUnfolding fresh_dfun dfun_unf)
                                     dfunInlinePragma

      hoistBinding dfun dict_rhs
      return dfun
  where
    tvs     = tyConTyVars vect_tc
    no_tvs  = null tvs
    arg_tys = mkTyVarTys tvs
    inst_ty = mkTyConApp vect_tc arg_tys

    -- The prClass predicate applied to the result of 'mkPReprType' on the
    -- instance type.
    super_pred_ty = do
      repr_ty <- mkPReprType inst_ty
      pr_cls  <- builtin prClass
      return (mkClassPred pr_cls [repr_ty])

    -- Run one builder inside 'localV', abstract its result over the type
    -- variables and value arguments, and hoist that as a new exported
    -- binding carrying an inline unfolding and the always-inline pragma.
    mk_method val_args dfun_occ (label, builder)
      = localV $ do
          rhs <- builder vect_tc prepr_ax pdata_tc pdatas_tc repr
          let lam     = mkLams (tvs ++ val_args) rhs
              inl_unf = mkInlineUnfolding (Just (length val_args)) lam
          fresh <- newExportedVar (mk_method_occ dfun_occ label) (exprType lam)
          let meth_var = setInlinePragma (setIdUnfolding fresh inl_unf)
                                         alwaysInlinePragma
          hoistBinding meth_var lam
          return meth_var

    -- Apply a hoisted method binding to the instance's type arguments and
    -- then to the given value arguments.
    apply_method val_args meth_id
      = mkApps (Var meth_id) (map Type arg_tys ++ map Var val_args)

    -- Occurrence name of a method binding: the dictionary function's name,
    -- a '$', then the method's own name.
    mk_method_occ dfun_occ label
      = mkVarOcc (occNameString dfun_occ ++ "$" ++ label)