module Vectorise.Exp
(
vectPolyExpr
, vectDictExpr
, vectScalarFun
, vectScalarDFun
)
where
#include "HsVersions.h"
import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Convert
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils
import CoreUtils
import MkCore
import CoreSyn
import CoreFVs
import Class
import DataCon
import TyCon
import TcType
import Type
import PrelNames
import NameSet
import Var
import VarEnv
import VarSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
import TysWiredIn
import TysPrim
import Outputable
import FastString
import Control.Monad
import Control.Applicative
import Data.Maybe
import Data.List
-- | Vectorise a polymorphic expression.
--
-- An outer tick is stripped, the inner expression is vectorised, and the tick is re-attached to
-- the vectorised result.  Otherwise the leading type binders are split off and the remaining
-- monomorphic expression is vectorised with 'vectFnExpr' inside 'polyAbstract'.  The type binders,
-- followed by the arguments supplied by 'polyAbstract', are then re-abstracted over the result.
-- The inlining decision of the body gets the arity computed by 'polyArity' added to it.
--
-- Returns the inlining decision, whether the body was vectorised as a scalar function, and the
-- vectorised expression.
vectPolyExpr :: Bool              -- ^ Is the binding a loop breaker?
             -> [Var]             -- ^ Names of functions in the same recursive binding group
             -> CoreExprWithFVs   -- ^ Expression to vectorise
             -> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnTick tickish inner)
  = liftM retick (vectPolyExpr loop_breaker recFns inner)
  where
    retick (inline, isScalarFn, vinner) = (inline, isScalarFn, vTick tickish vinner)
vectPolyExpr loop_breaker recFns expr
  = do { polyAr <- polyArity tyBndrs
       ; polyAbstract tyBndrs $ \absArgs ->
           do { (inline, isScalarFn, vMono) <- vectFnExpr False loop_breaker recFns monoExpr
              ; let reabstract = mkLams (tyBndrs ++ absArgs)
              ; return (addInlineArity inline polyAr, isScalarFn, mapVect reabstract vMono)
              }
       }
  where
    -- leading type binders and the expression underneath them
    (tyBndrs, monoExpr) = collectAnnTypeBinders expr
-- | Vectorise an expression, producing both its vectorised and its lifted version.
--
-- The equations are tried in order and several of them overlap.  The special 'pAT_ERROR_ID'
-- application, type applications, data constructors applied to literals and general value
-- applications all match 'AnnApp', so the order of the equations is significant.  Anything not
-- handled by an earlier equation ends in 'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VM VExpr

  -- variable
vectExpr (_, AnnVar v)
  = vectVar v

  -- literal
vectExpr (_, AnnLit lit)
  = vectConst $ Lit lit

  -- value lambda: delegated to 'vectFnExpr'; a type lambda matches no later equation except the
  -- final catch-all
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e

  -- 'pAT_ERROR_ID' applied to a type and an error argument: only the type is vectorised/lifted;
  -- the (deannotated) error argument is reused unchanged in both results
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do { (vty, lty) <- vectAndLiftType ty
       ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
       }
  where
    err' = deAnnotate err

  -- type application: handed to 'vectPolyApp', which deals with consecutive type and dictionary
  -- applications together
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e

  -- 'Int', 'Float' or 'Double' data constructor applied to a literal: the vectorised version is
  -- the original application; the lifted version is produced by 'liftPD'
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]

  -- value application: a dictionary argument goes to 'vectPolyApp'; otherwise function and
  -- argument are vectorised separately and combined with 'mkClosureApp'
vectExpr e@(_, AnnApp fn arg)
  | isPredTy arg_ty
  = vectPolyApp e
  | otherwise
  = do {   -- vectorise the argument and result type of the function
       ; varg_ty <- vectType arg_ty
       ; vres_ty <- vectType res_ty
           -- vectorise function and argument
       ; vfn <- vectExpr fn
       ; varg <- vectExpr arg
           -- combine them
       ; mkClosureApp varg_ty vres_ty vfn varg
       }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

  -- case on a scrutinee of algebraic type; any other scrutinee type is rejected
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
  where
    scrut_ty = exprType (deAnnotate scrut)

  -- non-recursive let: the right-hand side is vectorised as a (non-loop-breaker) polymorphic
  -- expression, only its vectorised expression is kept
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

  -- recursive let: all binders are brought into scope for the right-hand sides and the body;
  -- strong loop breakers are flagged as such to 'vectPolyExpr'
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs
    vect_rhs bndr rhs = localV
                      . inBind bndr
                      . liftM (\(_,_,z)->z)
                      $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs

  -- tick: vectorise the inner expression and re-attach the tick
vectExpr (_, AnnTick tickish expr)
  = liftM (vTick tickish) (vectExpr expr)

  -- type argument
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

  -- everything else is unsupported
vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
-- | Vectorise an expression that may be a function.
--
-- For a leading dictionary lambda, the binder is vectorised and re-abstracted over the vectorised
-- body.  For any other value lambda, vectorisation as a scalar function ('vectScalarFun') is tried
-- first.  If that fails, the closure translation of 'vectLam' is used instead.  Every other
-- expression is handed to 'vectExpr' and marked as not to be inlined.
--
-- Returns the inlining decision, whether the result is a scalar function, and the vectorised
-- expression.
vectFnExpr :: Bool              -- ^ Should the binding be inlined (passed on to 'vectLam')?
           -> Bool              -- ^ Is the binding a loop breaker?
           -> [Var]             -- ^ Names of functions in the same recursive binding group
           -> CoreExprWithFVs   -- ^ Expression to vectorise
           -> VM (Inline, Bool, VExpr)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body)
  | isId bndr && isPredTy (idType bndr)
  = do { vBndr <- vectBndr bndr
       ; (inline', isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body
       ; return (inline', isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody)
       }
  | isId bndr
  = asScalar `orElseV` asClosure
  where
    -- the whole lambda as a scalar function (never inlined)
    asScalar  = mark DontInline True  (vectScalarFun False recFns (deAnnotate expr))
    -- fallback: closure-based vectorisation of the lambda
    asClosure = mark inlineMe   False (vectLam inline loop_breaker expr)
vectFnExpr _ _ _ e
  = mark DontInline False $ vectExpr e
-- | Run a vectorisation computation and pair its result with an inlining decision and a flag
-- saying whether it was vectorised as a scalar function.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = liftM ((,,) b isScalarFn) p
-- | Vectorise an application of a variable to type and dictionary arguments.
--
-- The @where@ clause peels the arguments off the expression from the outside in, viewing it as
--
-- > var tysInner dictsInner tysOuter dictsOuter
--
-- where each group may be empty.  A head that is not a variable is refused with 'pprSorry'.
-- Dictionary arguments are vectorised with 'vectDictExpr' and type arguments with 'vectType'.
vectPolyApp :: CoreExprWithFVs -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do {   -- look up the vectorised form of the head variable
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)
                  -- vectorise all dictionary and type arguments
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter <- mapM vectType tysOuter
              ; vTysInner <- mapM vectType tysInner
                  -- apply an expression to the outer types (via 'polyApply') and then to the
                  -- outer dictionaries
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter
              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )
                          ; traceVt " LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var
                    -> do {   -- class-op selector or dfun: with no inner dictionaries, the
                              -- outer types and dictionaries are applied directly (no
                              -- 'polyApply'); otherwise the inner ones are applied directly
                              -- and the outer ones via 'reconstructOuter'
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt " GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt " GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- peel the argument groups from the outside in; a single group of dictionaries or types
    -- therefore ends up in the outer group
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter) = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner) = collectAnnTypeArgs e3
    -- class-op selectors and dfuns
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
-- | Vectorise a dictionary computation.
--
-- A variable known to the vectorisation environment (local or global) is replaced by its
-- vectorised version; an unknown variable is kept.  Types, including the result type of a case,
-- are vectorised, and data constructors in case alternatives are replaced by their vectorised
-- versions.  Binders are kept unchanged.  Literals cause a panic; casts and coercions are not
-- supported ('pprSorry').
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr (Var var)
  = do { scope <- lookupVar_maybe var
       ; return . Var $ case scope of
                          Nothing              -> var
                          Just (Local (vv, _)) -> vv
                          Just (Global vv)     -> vv
       }
vectDictExpr (Lit lit)
  = pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
vectDictExpr (Lam bndr body)
  = liftM (Lam bndr) (vectDictExpr body)
vectDictExpr (App fun arg)
  = liftM2 App (vectDictExpr fun) (vectDictExpr arg)
vectDictExpr (Case scrut bndr ty alts)
  = do { vScrut <- vectDictExpr scrut
       ; vTy    <- vectType ty
       ; vAlts  <- mapM dictAlt alts
       ; return $ Case vScrut bndr vTy vAlts
       }
  where
    -- alternative: vectorise the constructor first, then the right-hand side
    dictAlt (altCon, bs, rhs)
      = do { vAltCon <- dictAltCon altCon
           ; vRhs    <- vectDictExpr rhs
           ; return (vAltCon, bs, vRhs)
           }
    dictAltCon (DataAlt dc) = liftM DataAlt (maybeV (noDataCon dc) (lookupDataCon dc))
    dictAltCon (LitAlt lit) = return (LitAlt lit)
    dictAltCon DEFAULT      = return DEFAULT
    noDataCon dc = ptext (sLit "Cannot vectorise data constructor:") <+> ppr dc
vectDictExpr (Let bind body)
  = liftM2 Let (dictBind bind) (vectDictExpr body)
  where
    dictBind (NonRec bndr rhs) = liftM (NonRec bndr) (vectDictExpr rhs)
    dictBind (Rec pairs)       = liftM Rec (mapM dictPair pairs)
    dictPair (bndr, rhs)       = do { vRhs <- vectDictExpr rhs; return (bndr, vRhs) }
vectDictExpr e@(Cast _e _coe)
  = pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr e)
vectDictExpr (Tick tickish body)
  = liftM (Tick tickish) (vectDictExpr body)
vectDictExpr (Type ty)
  = liftM Type (vectType ty)
vectDictExpr (Coercion coe)
  = pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
-- | Vectorise an expression as a scalar function, if possible.
--
-- The expression is accepted when the caller forces it (@forceScalar@).  Otherwise all of the
-- following must hold:
--
-- * every argument type and the result type pass 'is_primitive_ty';
-- * the body passes 'is_scalar', starting from the global scalar variables extended with
--   @recFns@;
-- * the body passes 'uses';
-- * there are at most 'mAX_DPH_SCALAR_ARGS' arguments.
--
-- If the checks fail, vectorisation fails through 'onlyIfV' with the reason
-- "not a scalar function".  The expression must have a function type (asserted in debug builds).
vectScalarFun :: Bool       -- ^ Treat as scalar unconditionally (skip all checks)?
              -> [Var]      -- ^ Function names in the same recursive binding group
              -> CoreExpr   -- ^ Expression to be vectorised
              -> VM VExpr
vectScalarFun forceScalar recFns expr
 = do { gscalarVars  <- globalScalarVars
      ; scalarTyCons <- globalScalarTyCons
      ; let scalarVars        = gscalarVars `extendVarSetList` recFns
            (arg_tys, res_ty) = splitFunTys (exprType expr)
      ; MASSERT( not $ null arg_tys )
      ; onlyIfV (ptext (sLit "not a scalar function"))
                (forceScalar
                 ||
                 all is_primitive_ty arg_tys
                  && is_primitive_ty res_ty
                  && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
                  && uses scalarVars expr
                  && length arg_tys <= mAX_DPH_SCALAR_ARGS)
        $ mkScalarFun arg_tys res_ty expr
      }
  where
    -- Argument and result types admitted for a scalar function: dictionaries, plus the fixed
    -- list Bool, Int, Word8 and Double.
    is_primitive_ty ty
      | isPredTy ty
      = True
      | Just (tycon, _) <- splitTyConApp_maybe ty
      = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
      | otherwise
      = False

    -- Result types admitted for case expressions inside a scalar function: dictionaries, plus
    -- the type constructors in the given name set.
    is_scalar_ty scalarTyCons ty
      | isPredTy ty
      = True
      | Just (tycon, _) <- splitTyConApp_maybe ty
      = tyConName tycon `elemNameSet` scalarTyCons
      | otherwise
      = False

    -- Does the expression consist of scalar computations only?
    --
    -- Rules applied:
    --
    -- * Variables must be in the given set.  The set grows with lambda-, let-, case- and
    --   alternative-bound variables as they come into scope.
    -- * Applications and lambda binders must not have a type for which 'maybe_parr_ty' holds.
    -- * Case expressions must have a result type accepted by the type predicate.
    -- * In a recursive let, all binders are assumed scalar while their right-hand sides are
    --   checked.
    is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool
    is_scalar scalars _isScalarTC (Var v) = v `elemVarSet` scalars
    is_scalar _scalars _isScalarTC (Lit _) = True
    is_scalar scalars isScalarTC e@(App e1 e2)
      | maybe_parr_ty (exprType e) = False
      | otherwise = is_scalar scalars isScalarTC e1 &&
                    is_scalar scalars isScalarTC e2
    is_scalar scalars isScalarTC (Lam var body)
      | maybe_parr_ty (varType var) = False
      | otherwise = is_scalar (scalars `extendVarSet` var) isScalarTC body
    is_scalar scalars isScalarTC (Let bind body) = bindsAreScalar &&
                                                   is_scalar scalars' isScalarTC body
      where
        (bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind
    is_scalar scalars isScalarTC (Case e var ty alts)
      | isScalarTC ty = is_scalar scalars' isScalarTC e &&
                        all (is_scalar_alt scalars' isScalarTC) alts
      | otherwise = False
      where
        scalars' = scalars `extendVarSet` var
    is_scalar scalars isScalarTC (Cast e _coe) = is_scalar scalars isScalarTC e
    is_scalar scalars isScalarTC (Tick _ e) = is_scalar scalars isScalarTC e
    is_scalar _scalars _isScalarTC (Type {}) = True
    is_scalar _scalars _isScalarTC (Coercion {}) = True

    -- Check a binding group.  Returns whether the right-hand sides are scalar, together with
    -- the scalar set extended by the bound variables.
    is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e,
                                                         scalars `extendVarSet` var)
    is_scalar_bind scalars isScalarTCs (Rec bnds) = (all (is_scalar scalars' isScalarTCs) es,
                                                     scalars')
      where
        (vars, es) = unzip bnds
        scalars'   = scalars `extendVarSetList` vars

    is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList` vars)
                                                               isScalarTCs e

    -- May a value of this type be a parallel array?  Looks through type synonyms.  A 'PArr'
    -- type constructor counts, and so does a synonym family, conservatively.
    maybe_parr_ty :: Type -> Bool
    maybe_parr_ty ty
      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
    maybe_parr_ty _ = False

    -- Does the expression refer to any variable of the given set?  The set is extended with
    -- lambda binders.  Ticks, casts and recursive lets are looked through, just as they are by
    -- 'is_scalar'; previously they counted as "no use", so a merely ticked or letrec-containing
    -- function was never treated as scalar.  Literals, types and coercions count as no use.
    uses :: VarSet -> CoreExpr -> Bool
    uses funs (Var v)       = v `elemVarSet` funs
    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
    uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
    uses funs (Let (NonRec _b letExpr) body)
                            = uses funs letExpr || uses funs body
    uses funs (Let (Rec bnds) body)
                            = any (uses funs . snd) bnds || uses funs body
    uses funs (Case e _eId _ty alts)
                            = uses funs e || any (uses_alt funs) alts
    uses funs (Cast e _coe) = uses funs e
    uses funs (Tick _ e)    = uses funs e
    uses _ _                = False

    uses_alt funs (_, _bs, e) = uses funs e
-- | Build the vectorised and lifted versions of a function that was accepted as scalar.
--
-- For a dictionary-typed result, the expression is vectorised with 'vectDictExpr'.  There is no
-- lifted version in that case, and the second component raises an error if it is ever forced.
--
-- Otherwise the function is hoisted with 'hoistExpr' (name hint "fn", not inlined).  It is then
-- wrapped with 'scalarClosure', passing the zipping function from 'zipScalars' applied to the
-- hoisted function.  The resulting closure is hoisted as well (name hint "clo").  Its lifted
-- version is obtained with 'liftPD'.
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = liftM (\vDict -> (vDict, unused)) (vectDictExpr expr)
  | otherwise
  = do { traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr funTy
       ; fnVar     <- hoistExpr (fsLit "fn") expr DontInline
       ; zipFn     <- zipScalars arg_tys res_ty
       ; let fnRef = Var fnVar
       ; cloExpr   <- scalarClosure arg_tys res_ty fnRef (zipFn `App` fnRef)
       ; cloVar    <- hoistExpr (fsLit "clo") cloExpr DontInline
       ; liftedClo <- liftPD (Var cloVar)
       ; return (Var cloVar, liftedClo)
       }
  where
    funTy  = mkFunTys arg_tys res_ty
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
-- | Vectorise the dictionary function (dfun) of a class instance, treating every method and
-- superclass selection as a scalar function.
--
-- The result has the form
--
-- > \tvs vd1 .. vdn -> let d1 = unVectDict theta1 vd1; ..; dn = ..
-- >                    in  vectDataCon vtys (vsel1 ..) .. (vselm ..)
--
-- Here @vdi@ are vectorised dictionaries for the instance context, @di@ are the corresponding
-- unvectorised dictionaries reconstructed with 'unVectDict', and each @vseli@ is
-- 'vectScalarFun' (forced) applied to a class selector.  The selector is applied to the instance
-- types and to the original dfun, which in turn is applied to @tvs@ and @d1 .. dn@.
--
-- If the class data constructor has no vectorised version, vectorisation fails through
-- 'maybeV' with an explanatory reason.  Previously a failable pattern bind was used, which
-- relied on the monad's 'fail'.
vectScalarDFun :: Var        -- ^ Original dfun
               -> [Var]      -- ^ Functions or dfuns defined in the same group
               -> VM CoreExpr
vectScalarDFun var recFns
  = do {   -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

           -- binders for the vectorised context dictionaries, and let-bound unvectorised
           -- dictionaries reconstructed from them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds

           -- vectorise each selection as a scalar function (forced)
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun True recFns e) scsOps

           -- assemble the vectorised dictionary
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)
       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy  ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty       -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = text "Vectorise.Exp.vectScalarDFun: data constructor not vectorised:"
                        <+> ppr dataCon
-- | Reconstruct an unvectorised dictionary of type @ty@ from an expression @e@ computing the
-- corresponding vectorised dictionary.
--
-- Every selector of the class is applied to the vectorised instance types and to @e@.  Each
-- result is converted back with 'fromVect' at the matching method type.  The converted values are
-- then packed with the class data constructor at the original instance types.  Panics if the
-- type constructor of @ty@ is not a class.
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys  <- mapM vectType tys
       ; scOps <- zipWithM fromVect methTys (map (selectFrom vTys) selIds)
       ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
       }
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
    selIds  = classAllSelIds (fromMaybe noClass (tyConClass_maybe tycon))
    noClass = panic "Vectorise.Exp.unVectDict: no class"
    -- apply a selector to the vectorised types and the vectorised dictionary
    selectFrom vTys sel = Var sel `mkTyApps` vTys `mkApps` [e]
-- | Vectorise a lambda abstraction (possibly with several value binders) as a closure.
--
-- Free variables of the abstraction that have a local vectorised version are captured.
-- Dictionary-typed ones are passed, vectorised, to 'hoistPolyVExpr' and 'buildClosures'.  All
-- other captured variables are bound in the vectorised body together with the lambda binders.
--
-- For a loop breaker, the lifted body is wrapped in a case on the lifting context.  When the
-- context is the literal 0, that case yields 'emptyPD' of the result type.
vectLam :: Bool             -- ^ Should the RHS of a binding be inlined?
        -> Bool             -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs  -- ^ Lambda abstraction to vectorise (must be an 'AnnLam').
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
 = do { let (bndrs, body) = collectAnnValBinders expr

      ; tyvars <- localTyVars
          -- free variables that have a local vectorised version, paired with that version
          -- (pattern-matching generator instead of the partial 'fromJust')
      ; vfvs <- readLEnv $ \env ->
                [ (var, vv)
                | var     <- varSetElems fvs
                , Just vv <- [lookupVarEnv (local_vars env) var]
                ]

          -- dictionary free variables are only needed in their vectorised form
      ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
            vfvs_dict                   = map snd vvs_dict
            (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

      ; arg_tys <- mapM (vectType . idType) bndrs
      ; res_ty  <- vectType (exprType $ deAnnotate body)

      ; let arity      = length fvs_nondict + length bndrs
            vfvs_dict' = map vectorised vfvs_dict
      ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
        . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
        $ do {   -- vectorised body, abstracted over the lifting context, the non-dictionary
                 -- free variables and the lambda binders
             ; lc              <- builtin liftingContext
             ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body)
             ; vbody' <- break_loop lc res_ty vbody
             ; return $ vLams lc vbndrs vbody'
             }
      }
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- For a loop breaker, guard the lifted body by a case on the lifting context 'lc': the
    -- literal 0 selects 'emptyPD' of the result type, anything else the original lifted body.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { empty <- emptyPD ty
           ; lty   <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt 0), [], empty)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ = panic "vectLam"
-- | Vectorise a case expression whose scrutinee has an algebraic data type.
--
-- The equations are tried in order:
--
-- * A single 'DEFAULT' alternative, or a single data-constructor alternative without binders,
--   becomes 'vCaseDEFAULT'.
-- * A single data-constructor alternative with binders becomes a @let@ of the vectorised
--   scrutinee, followed by single-alternative cases on the unwrapped vectorised and lifted
--   scrutinee.
-- * The general case sorts the alternatives (DEFAULT first, then by constructor tag) and
--   vectorises each of them.  The lifted alternatives are combined with 'combinePD', using a
--   selector bound from the unwrapped lifted scrutinee.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
      vscrut <- vectExpr scrut
      (vty, lty) <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

  -- a single constructor alternative that binds nothing is handled like a DEFAULT alternative
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
      vscrut <- vectExpr scrut
      (vty, lty) <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

  -- single constructor alternative with binders
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr <- vectExpr scrut
      (vbndr, (vbndrs, (vect_body, lift_body)))
        <- vect_scrut_bndr
         . vectBndrsIn bndrs
         $ vectExpr body
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      vect_dc <- maybeV dataConErr (lookupDataCon dc)
      let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- the vectorised scrutinee is let-bound to the vectorised case binder; a dead case binder
    -- is vectorised with 'vectBndrNewIn' (name hint "scrut") instead
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr
    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

  -- general case: several alternatives
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc <- maybeV tyConErr (lookupTyCon tycon)
      (vty, lty) <- vectAndLiftType ty
      let arity = length (tyConDataCons vect_tc)
      sel_ty <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr
      (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
      vexpr <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      let (vect_bodies, lift_bodies) = unzip vbodies
      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      -- the lifted case binds the selector and all lifted alternative binders in one go
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]
      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr
    -- alternatives ordered with DEFAULT first, then by data constructor tag; literal
    -- alternatives are not expected here (panic)
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT DEFAULT = EQ
    cmp DEFAULT _ = LT
    cmp _ DEFAULT = GT
    cmp _ _ = panic "vectAlgCase/cmp"
    -- Vectorise one data-constructor alternative.  The lifted versions of the local free
    -- variables of its body (minus its binders) are packed with 'packByTagPD' according to the
    -- selector tags.  The lifted body is evaluated in a case that rebinds the lifting-context
    -- variable to 'selElements' for this alternative applied to the selector.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag = mkDataConTag vect_dc
              fvs = freeVarsOf body `delVarSetList` bndrs
          sel_tags <- liftM (`App` sel) (builtin (selTags arity))
          lc <- builtin liftingContext
          elems <- builtin (selElements arity ntag)
          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds <- mapM (pack_var (Var lc) sel_tags tag)
                        . filter isLocalId
                        $ varSetElems fvs
                 (ve, le) <- vectExpr body
                 return (ve, Case (elems `App` sel) lc lty
                                  [(DEFAULT, [], (mkLets (concat binds) le))])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
    -- For a local variable, bind a fresh clone of its lifted version to the packed original and
    -- record the clone in the local environment; other variables need no binding.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv' <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                      (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]
            _ -> return []