module Vectorise.Exp (
vectPolyExpr,
vectScalarFun
) where
#include "HsVersions.h"
import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils
import CoreSyn
import CoreUtils
import MkCore
import CoreFVs
import DataCon
import TyCon
import Type
import Var
import VarEnv
import VarSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
import TysWiredIn
import TysPrim
import Outputable
import FastString
import Control.Monad
import Data.List
vectPolyExpr :: Bool
-> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
= do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
return (inline, isScalarFn, vNote note expr')
vectPolyExpr loop_breaker recFns expr
= do
arity <- polyArity tvs
polyAbstract tvs $ \args ->
do
(inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
return (addInlineArity inline arity, isScalarFn,
mapVect (mkLams $ tvs ++ args) mono')
where
(tvs, mono) = collectAnnTypeBinders expr
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
= liftM vType (vectType ty)
vectExpr (_, AnnVar v)
= vectVar v
vectExpr (_, AnnLit lit)
= vectLiteral lit
vectExpr (_, AnnNote note expr)
= liftM (vNote note) (vectExpr expr)
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
| v == pAT_ERROR_ID
= do { (vty, lty) <- vectAndLiftType ty
; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
}
where
err' = deAnnotate err
vectExpr e@(_, AnnApp _ arg)
| isAnnTypeArg arg
= vectTyAppExpr fn tys
where
(fn, tys) = collectAnnTypeArgs e
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
| Just con <- isDataConId_maybe v
, is_special_con con
= do
let vexpr = App (Var v) (Lit lit)
lexpr <- liftPD vexpr
return (vexpr, lexpr)
where
is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
vectExpr (_, AnnApp fn arg)
= do
arg_ty' <- vectType arg_ty
res_ty' <- vectType res_ty
fn' <- vectExpr fn
arg' <- vectExpr arg
mkClosureApp arg_ty' res_ty' fn' arg'
where
(arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
vectExpr (_, AnnCase scrut bndr ty alts)
| Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
, isAlgTyCon tycon
= vectAlgCase tycon ty_args scrut bndr ty alts
| otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
where
scrut_ty = exprType (deAnnotate scrut)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
vectExpr (_, AnnLet (AnnRec bs) body)
= do
(vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
$ liftM2 (,)
(zipWithM vect_rhs bndrs rhss)
(vectExpr body)
return $ vLet (vRec vbndrs vrhss) vbody
where
(bndrs, rhss) = unzip bs
vect_rhs bndr rhs = localV
. inBind bndr
. liftM (\(_,_,z)->z)
$ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs
vectExpr e@(_, AnnLam bndr _)
| isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
vectFnExpr :: Bool
-> Bool
-> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
| isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
`orElseV`
mark inlineMe False (vectLam inline loop_breaker expr)
vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
vectScalarFun :: Bool
-> [Var]
-> CoreExpr
-> VM VExpr
vectScalarFun forceScalar recFns expr
= do { gscalars <- globalScalars
; let scalars = gscalars `extendVarSetList` recFns
(arg_tys, res_ty) = splitFunTys (exprType expr)
; MASSERT( not $ null arg_tys )
; onlyIfV (forceScalar
||
all is_prim_ty arg_tys
&& is_prim_ty res_ty
&& is_scalar scalars expr
&& uses scalars expr)
$ mkScalarFun arg_tys res_ty expr
}
where
is_prim_ty ty
| Just (tycon, []) <- splitTyConApp_maybe ty
= tycon == intTyCon
|| tycon == floatTyCon
|| tycon == doubleTyCon
| otherwise = False
is_scalar :: VarSet -> CoreExpr -> Bool
is_scalar scalars (Var v) = v `elemVarSet` scalars
is_scalar _scalars (Lit _) = True
is_scalar scalars e@(App e1 e2)
| maybe_parr_ty (exprType e) = False
| otherwise = is_scalar scalars e1 && is_scalar scalars e2
is_scalar scalars (Lam var body)
| maybe_parr_ty (varType var) = False
| otherwise = is_scalar (scalars `extendVarSet` var) body
is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body
where
(bindsAreScalar, scalars') = is_scalar_bind scalars bind
is_scalar scalars (Case e var ty alts)
| is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts
| otherwise = False
where
scalars' = scalars `extendVarSet` var
is_scalar scalars (Cast e _coe) = is_scalar scalars e
is_scalar scalars (Note _ e ) = is_scalar scalars e
is_scalar _scalars (Type {}) = True
is_scalar _scalars (Coercion {}) = True
is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars')
where
(vars, es) = unzip bnds
scalars' = scalars `extendVarSetList` vars
is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
maybe_parr_ty :: Type -> Bool
maybe_parr_ty ty
| Just ty' <- coreView ty = maybe_parr_ty ty'
| Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
maybe_parr_ty _ = False
uses funs (Var v) = v `elemVarSet` funs
uses funs (App e1 e2) = uses funs e1 || uses funs e2
uses funs (Lam b body) = uses (funs `extendVarSet` b) body
uses funs (Let (NonRec _b letExpr) body)
= uses funs letExpr || uses funs body
uses funs (Case e _eId _ty alts)
= uses funs e || any (uses_alt funs) alts
uses _ _ = False
uses_alt funs (_, _bs, e) = uses funs e
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
= do { fn_var <- hoistExpr (fsLit "fn") expr DontInline
; zipf <- zipScalars arg_tys res_ty
; clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
; clo_var <- hoistExpr (fsLit "clo") clo DontInline
; lclo <- liftPD (Var clo_var)
; return (Var clo_var, lclo)
}
vectLam :: Bool
-> Bool
-> CoreExprWithFVs
-> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
= do let (bs, body) = collectAnnValBinders expr
tyvars <- localTyVars
(vs, vvs) <- readLEnv $ \env ->
unzip [(var, vv) | var <- varSetElems fvs
, Just vv <- [lookupVarEnv (local_vars env) var]]
arg_tys <- mapM (vectType . idType) bs
res_ty <- vectType (exprType $ deAnnotate body)
buildClosures tyvars vvs arg_tys res_ty
. hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
$ do
lc <- builtin liftingContext
(vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
vbody' <- break_loop lc res_ty vbody
return $ vLams lc vbndrs vbody'
where
maybe_inline n | inline = Inline n
| otherwise = DontInline
break_loop lc ty (ve, le)
| loop_breaker
= do
empty <- emptyPD ty
lty <- mkPDataType ty
return (ve, mkWildCase (Var lc) intPrimTy lty
[(DEFAULT, [], le),
(LitAlt (mkMachInt 0), [], empty)])
| otherwise = return (ve, le)
vectLam _ _ _ = panic "vectLam"
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
(ppr $ deAnnotate e `mkTyApps` tys)
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
-> [(AltCon, [Var], CoreExprWithFVs)]
-> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
= do
vscrut <- vectExpr scrut
(vty, lty) <- vectAndLiftType ty
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
= do
vscrut <- vectExpr scrut
(vty, lty) <- vectAndLiftType ty
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vCaseDEFAULT vscrut vbndr vty lty vbody
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
= do
(vty, lty) <- vectAndLiftType ty
vexpr <- vectExpr scrut
(vbndr, (vbndrs, (vect_body, lift_body)))
<- vect_scrut_bndr
. vectBndrsIn bndrs
$ vectExpr body
let (vect_bndrs, lift_bndrs) = unzip vbndrs
(vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
vect_dc <- maybeV (lookupDataCon dc)
let [pdata_dc] = tyConDataCons pdata_tc
let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
where
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
mk_wild_case expr ty dc bndrs body
= mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
vectAlgCase tycon _ty_args scrut bndr ty alts
= do
vect_tc <- maybeV (lookupTyCon tycon)
(vty, lty) <- vectAndLiftType ty
let arity = length (tyConDataCons vect_tc)
sel_ty <- builtin (selTy arity)
sel_bndr <- newLocalVar (fsLit "sel") sel_ty
let sel = Var sel_bndr
(vbndr, valts) <- vect_scrut_bndr
$ mapM (proc_alt arity sel vty lty) alts'
let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
vexpr <- vectExpr scrut
(vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
let [pdata_dc] = tyConDataCons pdata_tc
let (vect_bodies, lift_bodies) = unzip vbodies
vdummy <- newDummyVar (exprType vect_scrut)
ldummy <- newDummyVar (exprType lift_scrut)
let vect_case = Case vect_scrut vdummy vty
(zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
lc <- builtin liftingContext
lbody <- combinePD vty (Var lc) sel lift_bodies
let lift_case = Case lift_scrut ldummy lty
[(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
lbody)]
return . vLet (vNonRec vbndr vexpr)
$ (vect_case, lift_case)
where
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
cmp DEFAULT DEFAULT = EQ
cmp DEFAULT _ = LT
cmp _ DEFAULT = GT
cmp _ _ = panic "vectAlgCase/cmp"
proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
= do
vect_dc <- maybeV (lookupDataCon dc)
let ntag = dataConTagZ vect_dc
tag = mkDataConTag vect_dc
fvs = freeVarsOf body `delVarSetList` bndrs
sel_tags <- liftM (`App` sel) (builtin (selTags arity))
lc <- builtin liftingContext
elems <- builtin (selElements arity ntag)
(vbndrs, vbody)
<- vectBndrsIn bndrs
. localV
$ do
binds <- mapM (pack_var (Var lc) sel_tags tag)
. filter isLocalId
$ varSetElems fvs
(ve, le) <- vectExpr body
return (ve, Case (elems `App` sel) lc lty
[(DEFAULT, [], (mkLets (concat binds) le))])
let (vect_bndrs, lift_bndrs) = unzip vbndrs
return (vect_dc, vect_bndrs, lift_bndrs, vbody)
proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
pack_var len tags t v
= do
r <- lookupVar v
case r of
Local (vv, lv) ->
do
lv' <- cloneVar lv
expr <- packByTagPD (idType vv) (Var lv) len tags t
updLEnv (\env -> env { local_vars = extendVarEnv
(local_vars env) v (vv, lv') })
return [(NonRec lv' expr)]
_ -> return []