module Vectorise.Generic.PADict
( buildPADict
) where
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Generic.Description
import Vectorise.Generic.PAMethods ( buildPAScAndMethods )
import Vectorise.Utils
import BasicTypes
import CoreSyn
import CoreUtils
import CoreUnfold
import DsMonad
import TyCon
import Type
import Id
import Var
import Name
-- | Build the @PA@ dictionary function (dfun) for the vectorised type
-- constructor @vect_tc@, hoist it as a top-level binding, and return its
-- 'Var'.
--
-- The dfun is abstracted over the tycon's type variables and over the extra
-- value arguments handed to the continuation by 'polyAbstract'
-- (NOTE(review): presumably the @PA@ dictionaries for those type variables;
-- confirm against 'polyAbstract').  Its body applies the @PA@ data
-- constructor to the instance type, followed by one call per method
-- returned from 'buildPAScAndMethods'.  Each method is hoisted as its own
-- exported binding named @<dfun occ>$<method name>@.
buildPADict
  :: TyCon    -- ^ @vect_tc@: vectorised tycon; the instance head is @vect_tc tvs@
  -> TyCon    -- ^ @prepr_tc@: forwarded unchanged to every method builder
  -> TyCon    -- ^ @pdata_tc@: forwarded unchanged to every method builder
  -> TyCon    -- ^ @pdatas_tc@: forwarded unchanged to every method builder
  -> SumRepr  -- ^ @repr@: forwarded unchanged to every method builder
  -> VM Var   -- ^ the hoisted dfun
buildPADict vect_tc prepr_tc pdata_tc pdatas_tc repr =
  polyAbstract tyVars $ \dictArgs -> do
    thisMod <- liftDs getModuleDs
    let dfunOcc = mkLocalisedOccName thisMod mkPADFunOcc (getName vect_tc)

    -- Build and hoist one binding per method, in builder order; the
    -- resulting Vars supply the dictionary's method fields.
    builders   <- buildPAScAndMethods
    methodVars <- mapM (hoistMethod dictArgs dfunOcc) builders

    -- Dictionary expression:  \tvs args -> PA instTy (m_i @tvs args) ...
    paCon <- builtin paDataCon
    let fields   = Type instTy : map (applyMethod dictArgs) methodVars
        dictExpr = mkLams (tyVars ++ dictArgs) (mkConApp paCon fields)

    -- Dfun type:  forall tvs. <types of args> -> PA instTy
    paCls <- builtin paClass
    let resultPred = mkClassPred paCls [instTy]
        dfunTy     = mkForAllTys tyVars (mkFunTys (map varType dictArgs) resultPred)

    -- Attach a DFun unfolding built from the method Vars, plus the dfun
    -- inline pragma, then hoist the dictionary expression under that Var.
    rawDFun <- newExportedVar dfunOcc dfunTy
    let dfunUnf = mkDFunUnfolding dfunTy (map Var methodVars)
        dfunVar = rawDFun
                    `setIdUnfolding`  dfunUnf
                    `setInlinePragma` dfunInlinePragma
    hoistBinding dfunVar dictExpr
    return dfunVar
  where
    tyVars = tyConTyVars vect_tc
    argTys = mkTyVarTys tyVars
    instTy = mkTyConApp vect_tc argTys

    -- Run one method builder under 'localV' and wrap its result in lambdas
    -- over the type variables and dictionary arguments.  Hoist it as an
    -- exported binding with an inline unfolding (arity = number of
    -- dictionary arguments) and the always-inline pragma.
    hoistMethod dictArgs dfunOcc (mthName, mthBuild) = localV $ do
      mthExpr <- mthBuild vect_tc prepr_tc pdata_tc pdatas_tc repr
      let mthBody = mkLams (tyVars ++ dictArgs) mthExpr
          mthOcc  = mkVarOcc (occNameString dfunOcc ++ ('$' : mthName))
      rawVar <- newExportedVar mthOcc (exprType mthBody)
      let mthUnf = mkInlineUnfolding (Just (length dictArgs)) mthBody
          mthVar = rawVar
                     `setIdUnfolding`  mthUnf
                     `setInlinePragma` alwaysInlinePragma
      hoistBinding mthVar mthBody
      return mthVar

    -- Apply a hoisted method to the type variables, then the dictionary args.
    applyMethod dictArgs mthVar =
      mkApps (Var mthVar) (map Type argTys ++ map Var dictArgs)