module Vectorise.Exp
(
vectPolyExpr
, vectDictExpr
, vectScalarFun
, vectScalarDFun
)
where
#include "HsVersions.h"
import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Convert
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils
import CoreUtils
import MkCore
import CoreSyn
import CoreFVs
import Class
import DataCon
import TyCon
import TcType
import Type
import PrelNames
import Var
import VarEnv
import VarSet
import Id
import BasicTypes( isStrongLoopBreaker )
import Literal
import TysWiredIn
import TysPrim
import Outputable
import FastString
import Control.Monad
import Control.Applicative
import Data.Maybe
import Data.List
import TcRnMonad (doptM)
import DynFlags
import Util
-- | Vectorise a (possibly polymorphic) expression.
--
-- Without an avoidance tree ('Nothing'), one is computed with 'vectAvoidInfo'.  When the
-- 'Opt_AvoidVect' dynamic flag is set, scalar subexpressions are first encapsulated with
-- 'encapsulateScalars'.  The function then recurses with the resulting tree.
--
-- With a tree, a top-level tick is peeled off and re-applied to the result.  Otherwise the
-- outer type binders are stripped (dropping one tree level per binder), the monomorphic body
-- is vectorised with 'vectFnExpr', and the type abstractions are re-created around it.  The
-- result is the inlining decision (with the polymorphic arity added), whether the expression
-- was vectorised as a scalar function, and the vectorised expression.
vectPolyExpr :: Bool           -- ^ Is it a loop breaker?  (passed on to 'vectFnExpr')
             -> [Var]          -- ^ Passed on to 'vectFnExpr'
             -> CoreExprWithFVs
             -> Maybe VITree   -- ^ Avoidance tree of the expression, if already computed
             -> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns expr Nothing
  = do { vectAvoidance <- liftDs $ doptM Opt_AvoidVect
       ; vi <- vectAvoidInfo expr
       ; (expr', vi') <-
           if vectAvoidance
           then do { (expr', vi') <- encapsulateScalars vi expr
                   ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr')
                   ; return (expr', vi')
                   }
           else return (expr, vi)
       ; vectPolyExpr loop_breaker recFns expr' (Just vi')
       }
vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit]))
  = do { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit)
       ; return (inline, isScalarFn, vTick tickish expr')
       }
vectPolyExpr loop_breaker recFns expr (Just vit)
  = do { let (tvs, mono) = collectAnnTypeBinders expr
             vit'        = stripLevels (length tvs) vit
       ; arity <- polyArity tvs
       ; polyAbstract tvs $ \args ->
           do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit'
              ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
              }
       }
  where
    -- Every type lambda removed by 'collectAnnTypeBinders' corresponds to one single-child
    -- node of the avoidance tree; drop that many levels to reach the tree of the body.
    stripLevels :: Int -> VITree -> VITree
    stripLevels 0 vt               = vt
    stripLevels n (VITNode _ [vt]) = stripLevels (n - 1) vt
    stripLevels _ vt               = pprPanic "vectPolyExpr: stripLevels:" (text (show vt))
-- | Encapsulate the maximal scalar subexpressions of an expression.
--
-- A subexpression is encapsulated with 'liftSimple' when its avoidance-tree node is 'VISimple'
-- and all its free variables have simple types ('varsSimple').  'liftSimple' abstracts the
-- subexpression over its free variables and applies the result to them.  Otherwise the
-- function recurses into the immediate subexpressions, rebuilding the tree in step with the
-- expression.  Types, variables, literals and coercions are returned unchanged.
--
-- The tree must have the shape 'vectAvoidInfo' produces for the expression; any mismatch is a
-- panic.
encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
encapsulateScalars vit ce@(_, AnnType _ty)
  = return (ce, vit)
encapsulateScalars vit ce@(_, AnnVar _v)
  = return (ce, vit)
encapsulateScalars vit ce@(_, AnnLit _)
  = return (ce, vit)
-- Coercions (e.g. coercion arguments of applications) contain nothing to encapsulate;
-- 'vectAvoidInfo' gives them a leaf node.
encapsulateScalars vit ce@(_, AnnCoercion _)
  = return (ce, vit)
encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr)
  = do { (extExpr, vit') <- encapsulateScalars vit expr
       ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
       }
encapsulateScalars _ (_fvs, AnnTick _tck _expr)
  = panic "encapsulateScalars AnnTick doesn't match up"
encapsulateScalars vt@(VITNode vi [vit]) ce@(fvs, AnnLam bndr expr)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           -- 'liftSimple' needs the tree of the whole lambda (as in the other cases), not
           -- the tree of its body.
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extExpr, vit') <- encapsulateScalars vit expr
                                  ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
                                  }
       }
encapsulateScalars _ (_fvs, AnnLam _bndr _expr)
  = panic "encapsulateScalars AnnLam doesn't match up"
encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1
                                  ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2
                                  ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
                                  }
       }
encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2)
  = panic "encapsulateScalars AnnApp doesn't match up"
encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut
                                  ; extAltsVits <- zipWithM expAlt altVits alts
                                  ; let (extAlts, altVits') = unzip extAltsVits
                                  ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit' : altVits'))
                                  }
       }
  where
    expAlt altVt (con, bndrs, expr)
      = do { (extExpr, altVt') <- encapsulateScalars altVt expr
           ; return ((con, bndrs, extExpr), altVt')
           }
encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts)
  = panic "encapsulateScalars AnnCase doesn't match up"
encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1
                                  ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2
                                  ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2'])
                                  }
       }
encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
  = panic "encapsulateScalars AnnLet nonrec doesn't match up"
-- For a recursive let the tree lists the body first, then one subtree per binding.
encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
                                  ; let (extBnds, vtBnds') = unzip extBndsVts
                                  ; (extExpr, vtB') <- encapsulateScalars vtB expr
                                  ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), VITNode vi (vtB' : vtBnds'))
                                  }
       }
  where
    expBndg bndVt (bndr, rhs)
      = do { (extRhs, bndVt') <- encapsulateScalars bndVt rhs
           ; return ((bndr, extRhs), bndVt')
           }
encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2)
  = panic "encapsulateScalars AnnLet rec doesn't match up"
encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion)
  = do { (extExpr, vit') <- encapsulateScalars vit expr
       ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
       }
encapsulateScalars _ (_fvs, AnnCast _expr _coercion)
  = panic "encapsulateScalars AnnCast doesn't match up"
-- | Encapsulate a scalar expression and compute the matching avoidance tree.
--
-- An expression @e@ with free variables @v1 .. vn@ (in 'varSetElems' order) is turned into
-- @(\vn .. v1 -> e) vn .. v1@.  In the new tree, the lambdas and @e@'s own node are marked
-- 'VIEncaps'; the applications and variable arguments are marked 'VISimple'.
--
-- A case expression is not encapsulated as a whole when its scrutinee's type constructor is
-- something other than 'Bool', 'Int', 'Double' or 'Float'.  In that case each alternative is
-- encapsulated instead.
liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts)
  | Just (c, _) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr)
  , not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]   -- NOTE: hard-coded list
  = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits'))
  where
    (alts', altVits') = unzip $ zipWith liftAlt alts altVits
    liftAlt (ac, bndrs, aex) altVi
      = let (aex', altVi') = liftSimple altVi aex
        in ((ac, bndrs, aex'), altVi')
liftSimple viTree ae@(fvs, _annEx)
  = (mkAnnApps (mkAnnLams ae vars) vars, viTree')
  where
    vars    = varSetElems fvs
    viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars

    -- one 'VIEncaps' node per new lambda; the encapsulated expression's own node also
    -- becomes 'VIEncaps' (its children are kept)
    mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits
    mkViTreeLams vi (_:vs)           = VITNode VIEncaps [mkViTreeLams vi vs]

    -- one 'VISimple' application node (with a 'VISimple' leaf for the argument) per variable
    mkViTreeApps vi []     = vi
    mkViTreeApps vi (_:vs) = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []]

    -- Abstract over the variables, innermost first.  The lambda's annotation is the body's
    -- free variables minus the binder; the body keeps its own annotation.
    mkAnnLams :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
    mkAnnLams ae' []             = ae'
    mkAnnLams ae'@(fv, _) (v:vs) = mkAnnLams (delVarSet fv v, AnnLam v ae') vs

    mkAnnApp :: AnnExpr bndr VarSet -> Var -> AnnExpr' bndr VarSet
    mkAnnApp aex v = AnnApp aex (unitVarSet v, AnnVar v)

    -- Apply to the variables so that the outermost lambda receives the last variable.
    mkAnnApps :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
    mkAnnApps ae' [] = ae'
    mkAnnApps ae' (v:vs)
      = let (fv, aex') = mkAnnApps ae' vs
        in (extendVarSet fv v, mkAnnApp (fv, aex') v)
-- | Vectorise an expression, given its vectorisation-avoidance tree, into a pair of a
-- vectorised and a lifted expression ('VExpr').
--
-- The equations are tried in order, and several only match particular tree shapes.  The last
-- equation reports anything left unmatched (including tree-shape mismatches) via
-- 'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr
-- Variables and literals.
vectExpr (_, AnnVar v) _
  = vectVar v
vectExpr (_, AnnLit lit) _
  = vectConst $ Lit lit
-- Value lambdas go through 'vectFnExpr' (only the vectorised expression is kept); type
-- lambdas are not expected here.
vectExpr e@(_, AnnLam bndr _) vt
  | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt
  | otherwise
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e))
-- 'pAT_ERROR_ID' applied to a type and a message: re-apply it to the vectorised type for the
-- vectorised version and to the lifted type for the lifted version; the message is unchanged.
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _
  | v == pAT_ERROR_ID
  = do { (vty, lty) <- vectAndLiftType ty
       ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
       }
  where
    err' = deAnnotate err
-- Application to a type argument: a polymorphic application.
vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
  | isAnnTypeArg arg
  = vectPolyApp e
-- A data constructor of 'Int', 'Float' or 'Double' applied to a literal.  The vectorised
-- version is the expression itself, and the lifted version comes from 'liftPD'.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
-- Other applications.  A dictionary-typed argument makes this a polymorphic application.
-- Otherwise the function and argument are vectorised separately and combined with a
-- closure application.
vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2])
  | isPredTy arg_ty
  = vectPolyApp e
  | otherwise
  = do {
       ; varg_ty <- vectType arg_ty
       ; vres_ty <- vectType res_ty
       ; vfn <- vectExpr fn vit1
       ; varg <- vectExpr arg vit2
       ; mkClosureApp varg_ty vres_ty vfn varg
       }
  where
    -- argument and result type, read off the type of the function
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
-- Case expressions are only vectorised when the scrutinee has an algebraic type.
vectExpr (_, AnnCase scrut bndr ty alts) vt
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts vt
  | otherwise
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Can't vectorise expression" (ppr scrut_ty)
  where
    scrut_ty = exprType (deAnnotate scrut)
-- Non-recursive let.  The right-hand side is vectorised with 'vectPolyExpr', without the
-- loop-breaker flag.  The body is vectorised with the binder in scope.
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2])
  = do
      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1)
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
      return $ vLet (vNonRec vbndr vrhs) vbody
-- Recursive let.  All binders are in scope in every right-hand side and in the body.  The tree
-- lists the body first, then one subtree per binding.
vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                  $ liftM2 (,)
                                      (zipWith3M vect_rhs bndrs rhss vtBnds)
                                      (vectExpr body vtB)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs
    -- the loop-breaker flag for each right-hand side is 'isStrongLoopBreaker' of its
    -- binder's occurrence info
    vect_rhs bndr rhs vt = localV
                         . inBind bndr
                         . liftM (\(_,_,z)->z)
                         $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt)
    zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs)
-- Ticks are kept around the vectorised expression; types are vectorised.
vectExpr (_, AnnTick tickish expr) (VITNode _ [vit])
  = liftM (vTick tickish) (vectExpr expr vit)
vectExpr (_, AnnType ty) _
  = liftM vType (vectType ty)
-- Anything else cannot be vectorised; the message includes the avoidance tree.
vectExpr e vit
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit))
-- | Vectorise an expression that may be a function.
--
-- * A lambda over a dictionary binder: the binder is vectorised, the body is processed
--   recursively, and the vectorised binder is abstracted again.
-- * A lambda over any other value binder: first try a scalar function
--   ('vectScalarFunMaybe', marked 'DontInline').  Failing that, build a closure with
--   'vectLam' (marked 'inlineMe').
-- * Anything else (including type lambdas and tree-shape mismatches) is passed to 'vectExpr'
--   and marked 'DontInline'.
--
-- The 'Bool' in the result says whether a scalar function was produced.
vectFnExpr :: Bool               -- ^ Passed to 'vectLam': may the hoisted function be inlined?
           -> Bool               -- ^ Is it a loop breaker?  (passed to 'vectLam')
           -> [Var]              -- ^ Only threaded through the recursive call here
           -> CoreExprWithFVs
           -> VITree
           -> VM (Inline, Bool, VExpr)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt'])
  | isId bndr
  = if isPredTy (idType bndr) then vectDictLam else vectValueLam
  where
    vectDictLam
      = do { vBndr <- vectBndr bndr
           ; (bodyInline, bodyIsScalar, vBody) <- vectFnExpr inline loop_breaker recFns body vt'
           ; return (bodyInline, bodyIsScalar, mapVect (mkLams [vectorised vBndr]) vBody)
           }
    vectValueLam
      = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt)
        `orElseV`
        mark inlineMe False (vectLam inline loop_breaker expr vt)
vectFnExpr _ _ _ e vt
  = mark DontInline False $ vectExpr e vt
-- | Run a vectorisation computation and tag its result with an inlining decision and a flag
-- saying whether it was vectorised as a scalar function.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = (\x -> (b, isScalarFn, x)) <$> p
-- | Vectorise a polymorphic application: a variable applied to type and dictionary arguments.
--
-- The application is decomposed from the outside in, so that @e0@ has the shape
--
-- > e4 tysInner dictsInner tysOuter dictsOuter
--
-- The head @e4@ must be a variable; anything else is reported with 'pprSorry' as a
-- higher-rank program.  All type and dictionary arguments are vectorised ('vectType',
-- 'vectDictExpr').
--
-- * A local variable must have no inner dictionaries (asserted).  Its vectorised and lifted
--   versions are each applied to the outer arguments ('polyApply', then the dictionaries).
-- * A global class-operation selector or dictionary function ('isDictComp'):
--     * with no inner dictionaries, it is applied directly to the outer arguments;
--     * otherwise it is applied to the inner arguments, and that result to the outer ones.
--   The result is treated as a constant ('vectConst').
-- * Any other global must have no inner dictionaries (asserted).  It is applied to the outer
--   arguments and treated as a constant.
vectPolyApp :: CoreExprWithFVs -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do {
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter <- mapM vectType tysOuter
              ; vTysInner <- mapM vectType tysInner
                -- apply an expression to the vectorised outer types and dictionaries
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter
              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )
                          ; traceVt " LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var
                    -> do {
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt " GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt " GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- peel off trailing dictionary and type arguments, outermost group first
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter) = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner) = collectAnnTypeArgs e3
    -- class-operation selectors and dictionary functions
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
-- | Vectorise an expression that computes a dictionary.
--
-- A variable is replaced by its vectorised counterpart when it has one (local or global);
-- otherwise it is kept as is.  Types, including the result type of a case, are vectorised.
-- Binders of lambdas, cases and lets are kept unchanged, and case alternatives use the
-- vectorised data constructor.  Literals cause a panic; casts and coercions are reported as
-- unsupported ('pprSorry').
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr (Var var)
  = Var . vectorisedOrSelf <$> lookupVar_maybe var
  where
    vectorisedOrSelf Nothing                = var
    vectorisedOrSelf (Just (Local (vv, _))) = vv
    vectorisedOrSelf (Just (Global vv))     = vv
vectDictExpr (Lit lit)
  = pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
vectDictExpr (Lam bndr e)
  = fmap (Lam bndr) (vectDictExpr e)
vectDictExpr (App fn arg)
  = do { fn'  <- vectDictExpr fn
       ; arg' <- vectDictExpr arg
       ; return (App fn' arg')
       }
vectDictExpr (Case e bndr ty alts)
  = do { e'    <- vectDictExpr e
       ; ty'   <- vectType ty
       ; alts' <- mapM vectDictAlt alts
       ; return (Case e' bndr ty' alts')
       }
  where
    vectDictAlt (con, bs, rhs)
      = do { con' <- vectDictAltCon con
           ; rhs' <- vectDictExpr rhs
           ; return (con', bs, rhs')
           }
    vectDictAltCon (DataAlt datacon)
      = DataAlt <$> maybeV (ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon)
                           (lookupDataCon datacon)
    vectDictAltCon (LitAlt lit) = return (LitAlt lit)
    vectDictAltCon DEFAULT      = return DEFAULT
vectDictExpr (Let bnd body)
  = do { bnd'  <- vectDictBind bnd
       ; body' <- vectDictExpr body
       ; return (Let bnd' body')
       }
  where
    vectDictBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectDictBind (Rec pairs)       = Rec <$> mapM vectDictPair pairs
    vectDictPair (bndr, rhs)       = (,) bndr <$> vectDictExpr rhs
vectDictExpr e@(Cast _e _coe)
  = pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr e)
vectDictExpr (Tick tickish e)
  = Tick tickish <$> vectDictExpr e
vectDictExpr (Type ty)
  = Type <$> vectType ty
vectDictExpr (Coercion coe)
  = pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
-- | Vectorise an expression as a scalar function, but only if its avoidance-tree node is
-- 'VIEncaps' (an expression encapsulated by 'liftSimple').  Otherwise fail with 'noV'.
vectScalarFunMaybe :: CoreExpr -> VITree -> VM VExpr
vectScalarFunMaybe expr (VITNode vi _)
  | vi == VIEncaps = vectScalarFun expr
  | otherwise      = noV $ ptext (sLit "not a scalar function")
-- | Vectorise an expression as a scalar function.  The argument and result types are read off
-- the expression's type, and the work is delegated to 'mkScalarFun'.
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
  = traceVt "vectScalarFun" (ppr expr) >> mkScalarFun argTys resTy expr
  where
    (argTys, resTy) = splitFunTys (exprType expr)
-- | Build the vectorised and lifted versions of a scalar function with the given argument and
-- result types.
--
-- If the result type is a dictionary type, the expression is vectorised with 'vectDictExpr'
-- and has no lifted version: forcing that component is an error.  Otherwise the function is
-- hoisted ("fn") and wrapped in a scalar closure together with its zip-wrapper
-- ('zipScalars').  The closure is hoisted ("clo") and its lifted version comes from 'liftPD'.
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = do { vDict <- vectDictExpr expr
       ; return (vDict, unused)
       }
  | otherwise
  = do { traceVt "mkScalarFun: " (ppr expr $$ ptext (sLit " ::") <+> ppr funTy)
       ; fnVar  <- hoistExpr (fsLit "fn") expr DontInline
       ; zipf   <- zipScalars arg_tys res_ty
       ; clo    <- scalarClosure arg_tys res_ty (Var fnVar) (zipf `App` Var fnVar)
       ; cloVar <- hoistExpr (fsLit "clo") clo DontInline
       ; lclo   <- liftPD (Var cloVar)
       ; return (Var cloVar, lclo)
       }
  where
    funTy  = mkFunTys arg_tys res_ty
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
-- | Vectorise the dictionary function of a class instance, treating its methods and
-- superclasses as scalar functions.
--
-- For a dictionary function @var :: forall tvs. theta => C tys@, the result has the shape
--
-- > \tvs vd_1 .. vd_n -> let d_i = unVectDict theta_i vd_i in
-- >                      VDC vtys (vectScalarFun (sel_j tys (var tvs d_1 .. d_n)) ..)
--
-- Here:
--
-- * @vd_i@ are fresh binders for the vectorised instance context;
-- * @d_i@ are fresh binders for the unvectorised context, rebuilt with 'unVectDict';
-- * @sel_j@ range over all selectors of @C@ ('classAllSelIds');
-- * @VDC@ is the vectorised counterpart of @C@'s dictionary data constructor.
--
-- Fails via 'maybeV' (rather than an irrefutable pattern) if that constructor has no
-- vectorised counterpart, as the other data-constructor lookups in this module do.
vectScalarDFun :: Var
               -> VM CoreExpr
vectScalarDFun var
  = do { -- bring the instance's type variables into scope
       ; mapM_ defLocalTyVar tvs
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps    <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps
       ; vDataCon   <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys       <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)
       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty      -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = text "vectScalarDFun: data constructor not vectorised" <+> ppr dataCon
-- | Rebuild an unvectorised dictionary of type @ty@ from an expression computing the
-- corresponding vectorised dictionary.
--
-- Each class selector is applied to the vectorised type arguments and the vectorised
-- dictionary, and converted back with 'fromVect' at the original method type.  The results
-- are then packed with the class's dictionary data constructor.  Panics if @ty@'s type
-- constructor is not a class.
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys  <- mapM vectType tys
       ; scOps <- zipWithM fromVect methTys [ Var sel `mkTyApps` vTys `mkApps` [e] | sel <- selIds ]
       ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
       }
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
    cls    = fromMaybe (panic "Vectorise.Exp.unVectDict: no class") (tyConClass_maybe tycon)
    selIds = classAllSelIds cls
-- | Vectorise a lambda abstraction (a chain of value binders) as a closure.
--
-- Free variables bound in the local environment are captured.  Dictionary-typed ones are
-- passed to 'hoistPolyVExpr' and 'buildClosures' separately.  The remaining ones become
-- leading binders of the hoisted function, whose arity is their number plus the number of
-- lambda binders.  When the lambda is a loop breaker, the lifted body yields an empty
-- 'PData' whenever the lifting context is 0.
vectLam :: Bool             -- ^ May the hoisted function be inlined?  (see 'maybe_inline')
        -> Bool             -- ^ Is it a loop breaker?
        -> CoreExprWithFVs  -- ^ The lambda abstraction
        -> VITree           -- ^ Its avoidance tree
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
  = do { let (bndrs, body) = collectAnnValBinders expr
       ; tyvars <- localTyVars
         -- free variables that have a local vectorised counterpart
       ; vfvs <- readLEnv $ \env ->
                   [ (var, vv)
                   | var     <- varSetElems fvs
                   , Just vv <- [lookupVarEnv (local_vars env) var]
                   ]
       ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
             vfvs_dict                   = map snd vvs_dict
             (fvs_nondict, vfvs_nondict) = unzip vvs_nondict
       ; arg_tys <- mapM (vectType . idType) bndrs
       ; res_ty  <- vectType (exprType $ deAnnotate body)
       ; let arity      = length fvs_nondict + length bndrs
             vfvs_dict' = map vectorised vfvs_dict
       ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
         . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
         $ do { lc <- builtin liftingContext
              ; let viBody = stripLams expr vi
              ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)
              ; vbody' <- break_loop lc res_ty vbody
              ; return $ vLams lc vbndrs vbody'
              }
       }
  where
    -- descend the tree past every lambda of the expression
    stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt
    stripLams _ vt                             = vt

    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- for loop breakers, return an empty PData when the lifting context is 0
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { empty <- emptyPD ty
           ; lty   <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt 0), [], empty)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ _ = panic "vectLam"
-- | Vectorise a case expression whose scrutinee has an algebraic type.
--
-- Arguments, in order:
--
-- * the scrutinee's type constructor and type arguments;
-- * the scrutinee and the case binder;
-- * the result type and the alternatives;
-- * the avoidance tree of the whole case (scrutinee subtree first, then one subtree per
--   alternative).
--
-- Returns the vectorised and lifted versions of the case.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
            -> VM VExpr
-- A single DEFAULT alternative: rebuild the case with 'vCaseDEFAULT'.
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut <- vectExpr scrut scrutVit
      (vty, lty) <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-- A single constructor alternative that binds no fields is handled exactly like a DEFAULT
-- alternative.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut <- vectExpr scrut scrutVit
      (vty, lty) <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody
-- A single constructor alternative with field binders:
--
-- * the vectorised scrutinee is let-bound to the vectorised case binder;
-- * the binder's representation is unwrapped with 'pdataUnwrapScrut';
-- * both the vectorised and the lifted results are single-alternative wildcard cases.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr <- vectExpr scrut scrutVit
      (vbndr, (vbndrs, (vect_body, lift_body)))
        <- vect_scrut_bndr
         . vectBndrsIn bndrs
         $ vectExpr body altVit
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      vect_dc <- maybeV dataConErr (lookupDataCon dc)
      let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- The case binder is always used (it is let-bound to the scrutinee).  When it is dead in
    -- the source, 'vectBndrNewIn' is used with the name "scrut".
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr
    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
-- The general case.
--
-- * Alternatives are sorted by constructor tag (DEFAULT first) and vectorised one by one
--   ('proc_alt').
-- * The vectorised result is an ordinary case over the vectorised constructors.
-- * The lifted result unwraps the scrutinee's representation (binding the selector 'sel_bndr'
--   and the lifted field binders) and merges the alternatives' lifted bodies with
--   'combinePD'.
vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
  = do
      vect_tc <- maybeV tyConErr (lookupTyCon tycon)
      (vty, lty) <- vectAndLiftType ty
      let arity = length (tyConDataCons vect_tc)
      sel_ty <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr
      (vbndr, valts) <- vect_scrut_bndr
                        $ mapM (proc_alt arity sel vty lty) (zip alts' altVits)
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
      vexpr <- vectExpr scrut scrutVit
      (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      let (vect_bodies, lift_bodies) = unzip vbodies
      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                        (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                        [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                          lbody)]
      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr
    -- alternatives ordered by constructor tag, DEFAULT first; literal alternatives panic
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT DEFAULT = EQ
    cmp DEFAULT _ = LT
    cmp _ DEFAULT = GT
    cmp _ _ = panic "vectAlgCase/cmp"
    -- Vectorise one constructor alternative.
    --
    -- The lifted body is evaluated with the lifting context 'lc' rebound to
    -- 'selElements arity ntag' applied to the selector.  Within that context, the local free
    -- variables of the body (other than its field binders) have their lifted versions
    -- repacked for this tag ('pack_var').
    proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
      = do
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag = mkDataConTag vect_dc
              fvs = freeVarsOf body `delVarSetList` bndrs
          sel_tags <- liftM (`App` sel) (builtin (selTags arity))
          lc <- builtin liftingContext
          elems <- builtin (selElements arity ntag)
          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds <- mapM (pack_var (Var lc) sel_tags tag)
                        . filter isLocalId
                        $ varSetElems fvs
                 (ve, le) <- vectExpr body vi
                 return (ve, Case (elems `App` sel) lc lty
                                 [(DEFAULT, [], (mkLets (concat binds) le))])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
    -- For a local variable: bind a clone of its lifted version to the result of
    -- 'packByTagPD' and record the clone in the local environment.  Non-local variables
    -- produce no binding.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv' <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                      (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]
            _ -> return []
-- Any other tree shape does not match the case expression.
vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _)
  = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)
-- | Classification of a subexpression, used to decide which parts to vectorise.
data VectAvoidInfo
  = VIParr     -- ^ its type may involve parallel arrays ('maybeParrTy'), or a child is 'VIParr'
  | VISimple   -- ^ it has one of the simple scalar types ('isSimpleType')
  | VIComplex  -- ^ neither of the above
  | VIEncaps   -- ^ introduced by 'liftSimple' for encapsulated scalar code
  deriving (Eq, Show)
-- | A tree of 'VectAvoidInfo' mirroring an expression: one node per subexpression, with the
-- children in the order built by 'vectAvoidInfo'.
data VITree
  = VITNode VectAvoidInfo  -- classification of this subexpression
            [VITree]       -- trees of its immediate subexpressions
  deriving Show
-- | Is the root of any of the given trees classified as 'VIParr'?
anyVIPArr :: [VITree] -> Bool
anyVIPArr = any isParrNode
  where
    isParrNode (VITNode vi _) = vi == VIParr
-- | Compute the vectorisation-avoidance tree of an expression.
--
-- Classification rules:
--
-- * A node is 'VIParr' if any of its children is.  Otherwise it is classified by the
--   expression's type ('vectAvoidInfoType'), which may itself yield 'VIParr'.
-- * Lambdas are never 'VISimple': they are 'VIComplex' unless their body is 'VIParr'.
-- * Casts and ticks inherit their child's classification.
-- * Types and coercions are 'VISimple' leaves.
--
-- Child order:
--
-- * application: [function, argument]
-- * non-recursive let: [right-hand side, body]
-- * recursive let: [body, rhs1 .. rhsn]
-- * case: [scrutinee, alt1 .. altn]
-- * lambda, cast and tick: [body]
vectAvoidInfo :: CoreExprWithFVs -> VM VITree
vectAvoidInfo ce@(_, AnnVar v)
  = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
       ; return $ VITNode vi []
       }
vectAvoidInfo ce@(_, AnnLit _)
  = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
       ; return $ VITNode vi []
       }
vectAvoidInfo ce@(_, AnnApp e1 e2)
  = do { vt1 <- vectAvoidInfo e1
       ; vt2 <- vectAvoidInfo e2
       ; vi <- if anyVIPArr [vt1, vt2]
               then return VIParr
               else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vt1, vt2]
       ; return $ VITNode vi [vt1, vt2]
       }
vectAvoidInfo ce@(_, AnnLam _var body)
  = do { vt@(VITNode vi _) <- vectAvoidInfo body
       ; let resultVI | vi == VIParr = VIParr
                      | otherwise    = VIComplex
         -- trace this node's own classification, as the other equations do
       ; viTrace ce resultVI [vt]
       ; return $ VITNode resultVI [vt]
       }
vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
  = do { vtE <- vectAvoidInfo expr
       ; vtB <- vectAvoidInfo body
       ; vi <- if anyVIPArr [vtE, vtB]
               then return VIParr
               else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vtE, vtB]
       ; return $ VITNode vi [vtE, vtB]
       }
-- Each right-hand side is analysed once.  Re-analysing them with unchanged inputs, as
-- earlier code did, yields the same trees.
vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body)
  = do { vtBnds <- mapM (vectAvoidInfo . snd) bnds
       ; vtB    <- vectAvoidInfo body
       ; ni <- if anyVIPArr (vtB : vtBnds)
               then return VIParr
               else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtB : vtBnds)
       ; return $ VITNode ni (vtB : vtBnds)
       }
vectAvoidInfo ce@(_, AnnCase expr _var _ty alts)
  = do { vtExpr <- vectAvoidInfo expr
       ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts
       ; ni <- if anyVIPArr (vtExpr : vtAlts)
               then return VIParr
               else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtExpr : vtAlts)
       ; return $ VITNode ni (vtExpr : vtAlts)
       }
vectAvoidInfo (_, AnnCast expr _)
  = do { vt@(VITNode vi _) <- vectAvoidInfo expr
       ; return $ VITNode vi [vt]
       }
vectAvoidInfo (_, AnnTick _ expr)
  = do { vt@(VITNode vi _) <- vectAvoidInfo expr
       ; return $ VITNode vi [vt]
       }
vectAvoidInfo (_, AnnType {})
  = return $ VITNode VISimple []
vectAvoidInfo (_, AnnCoercion {})
  = return $ VITNode VISimple []
-- | Classify a type: 'VIParr' if it may involve parallel arrays ('maybeParrTy'), 'VISimple'
-- if it is a simple scalar type ('isSimpleType'), and 'VIComplex' otherwise.
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | maybeParrTy ty = return VIParr
  | otherwise      = classify <$> isSimpleType ty
  where
    classify True  = VISimple
    classify False = VIComplex
-- | Might values of this type contain parallel arrays?
--
-- Type synonyms are expanded first ('coreView').  The answer is conservatively 'True' when
-- the type's constructor is the parallel-array constructor or a synonym family, or when any
-- of its arguments may contain parallel arrays.  Types that are not constructor
-- applications answer 'False'.
maybeParrTy :: Type -> Bool
maybeParrTy ty
  = case coreView ty of
      Just expanded -> maybeParrTy expanded
      Nothing       ->
        case splitTyConApp_maybe ty of
          Just (tc, args) -> isPArrTyCon tc || isSynFamilyTyCon tc || any maybeParrTy args
          Nothing         -> False
-- | Is this type one of the simple scalar types ('Bool', 'Int', 'Word8', 'Double', 'Float')?
--
-- Types that are not type-constructor applications are not simple.  The two cases are
-- exhaustive, so the earlier panic equation for "not handled" types was unreachable and has
-- been dropped.
isSimpleType :: Type -> VM Bool
isSimpleType ty
  = return $ case splitTyConApp_maybe ty of
      Just (c, _) -> tyConName c `elem` simpleTyConNames
      Nothing     -> False
  where
    simpleTyConNames = [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
-- | Do all variables of the set have simple types ('isSimpleType')?  Trivially 'True' for an
-- empty set.
varsSimple :: VarSet -> VM Bool
varsSimple vs
  = liftM and (mapM (isSimpleType . varType) (varSetElems vs))
-- | Trace a classification decision.  The header lists the node's 'VectAvoidInfo' followed by
-- its children's (each followed by a space) in brackets; the expression is printed below it.
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM ()
viTrace ce vi vTs
  = traceVt header (ppr $ deAnnotate ce)
  where
    header = "vitrace " ++ show vi ++ "[" ++ concatMap showChild vTs ++ "]"
    showChild (VITNode childVi _) = show childVi ++ " "