-- Export list: the component modules imported below are re-exported wholesale,
-- followed by the functions defined in this module.
module Vectorise.Monad (
-- Re-exported component modules.
module Vectorise.Monad.Base,
module Vectorise.Monad.Naming,
module Vectorise.Monad.Local,
module Vectorise.Monad.Global,
module Vectorise.Monad.InstEnv,
-- Entry point for running a 'VM' computation.
initV,
-- Access to the 'Builtins' record.
liftBuiltinDs,
builtin,
builtins,
-- Variable lookup in the local and global environments.
lookupVar,
lookupVar_maybe,
-- Updates to the global environment.
addGlobalParallelVar,
addGlobalParallelTyCon,
) where
import Vectorise.Monad.Base
import Vectorise.Monad.Naming
import Vectorise.Monad.Local
import Vectorise.Monad.Global
import Vectorise.Monad.InstEnv
import Vectorise.Builtins
import Vectorise.Env
import CoreSyn
import DsMonad
import HscTypes hiding ( MonadThings(..) )
import DynFlags
import MonadUtils (liftIO)
import InstEnv
import Class
import TyCon
import NameSet
import VarSet
import VarEnv
import Var
import Id
import Name
import ErrUtils
import Outputable
import Module
-- | Run a vectorisation computation inside the desugarer monad.
--
-- The incoming 'VectInfo' is dumped (under @Opt_D_dump_vt_trace@), the 'VM'
-- action is run against a freshly initialised global environment, and either
-- the updated 'VectInfo' together with the action's result is returned, or
-- 'Nothing' if the 'VM' action reported a failure (in which case a warning is
-- printed for the user).
initV :: HscEnv
      -> ModGuts
      -> VectInfo
      -> VM a
      -> IO (Maybe (VectInfo, a))
initV hsc_env guts info thing_inside = do
    traceDump "Incoming VectInfo" (ppr info)
    -- NB: refutable pattern; if 'initDs' does not deliver a 'Just', this bind
    -- fails in 'IO' rather than returning 'Nothing'.
    (_, Just outcome) <- initDs hsc_env (mg_module guts)
                                (mg_rdr_env guts) typeEnv
                                (mg_fam_inst_env guts) runInDs
    case outcome of
      Nothing           -> traceDump "Vectorisation FAILED!" empty
      Just (outInfo, _) -> traceDump "Outgoing VectInfo" (ppr outInfo)
    return outcome
  where
    dflags    = hsc_dflags hsc_env
    traceDump = dumpIfSet_dyn dflags Opt_D_dump_vt_trace

    -- All binders introduced by the module's top-level bindings.
    bindersOfBind (NonRec b _) = [b]
    bindersOfBind (Rec pairs)  = [b | (b, _) <- pairs]
    ids = concatMap bindersOfBind (mg_binds guts)

    typeEnv = typeEnvFromEntities ids (mg_tcs guts) (mg_fam_insts guts)

    -- The desugarer-monad computation that sets up the environments and
    -- runs the supplied 'VM' action.
    runInDs = do
      bi          <- initBuiltins
      builtinVars <- initBuiltinVars bi
      eps         <- liftIO $ hscEPS hsc_env
      let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
          instEnvs    = InstEnvs (eps_inst_env eps)
                                 (mg_inst_env guts)
                                 (mkModuleSet (dep_orphs (mg_deps guts)))
          paDicts     = initClassDicts instEnvs (paClass bi)
          prDicts     = initClassDicts instEnvs (prClass bi)
          baseEnv     = initGlobalEnv (gopt Opt_VectorisationAvoidance dflags)
                                      info (mg_vect_decls guts) instEnvs famInstEnvs
          startEnv    = extendImportedVarsEnv builtinVars
                          (setPAFunsEnv paDicts (setPRFunsEnv prDicts baseEnv))
      result <- runVM thing_inside bi startEnv emptyLocalEnv
      case result of
        Yes finalEnv _ x -> return $ Just (new_info finalEnv, x)
        No reason        -> do
          unqual <- mkPrintUnqualifiedDs
          liftIO $
            printOutputForUser dflags unqual $
              mkDumpDoc "Warning: vectorisation failure:" reason
          return Nothing

    -- Fold the final global environment back into the incoming 'VectInfo'.
    new_info finalEnv = modVectInfo finalEnv ids (mg_tcs guts) (mg_vect_decls guts) info
-- | For every instance of the given class, pair the name found by the
-- instance's rough match with the instance's dictionary function.
--
-- Panics if an instance's rough match is anything other than exactly one
-- known name.
initClassDicts :: InstEnvs -> Class -> [(Name, Var)]
initClassDicts insts cls = [ dictEntry inst | inst <- classInstances insts cls ]
  where
    dictEntry inst =
      case instanceRoughTcs inst of
        [Just tc] -> (tc, instanceDFunId inst)
        _         -> pprPanic "Invalid DPH instance (overlapping in head constructor)" (ppr inst)
-- | Lift a desugarer computation that needs the 'Builtins' into 'VM'.
-- The global and local environments are passed through unchanged.
liftBuiltinDs :: (Builtins -> DsM a) -> VM a
liftBuiltinDs p = VM run
  where
    run bi genv lenv = p bi >>= \res -> return (Yes genv lenv res)
-- | Project a value out of the 'Builtins' record; the environments are
-- left untouched.
builtin :: (Builtins -> a) -> VM a
builtin f = VM run
  where
    run bi genv lenv = return $ Yes genv lenv (f bi)
-- | Like 'builtin', but for a selector that takes an extra leading argument:
-- yields the selector with the 'Builtins' already supplied as its second
-- argument.
builtins :: (a -> Builtins -> b) -> VM (a -> b)
builtins f = VM run
  where
    run bi genv lenv = return $ Yes genv lenv (\arg -> f arg bi)
-- | Look a variable up in the local and then the global environment (see
-- 'lookupVar_maybe'). If it is in neither, abort via 'dumpVar'.
lookupVar :: Var -> VM (Scope Var (Var, Var))
lookupVar v = lookupVar_maybe v >>= maybe notFound return
  where
    -- No mapping in either environment: report the variable and bail out.
    notFound = getDynFlags >>= \dflags -> dumpVar dflags v
-- | Look a variable up, consulting the local environment's @local_vars@
-- first and falling back to the global environment's @global_vars@.
-- Hits are tagged 'Local' or 'Global' respectively; 'Nothing' means the
-- variable is in neither.
lookupVar_maybe :: Var -> VM (Maybe (Scope Var (Var, Var)))
lookupVar_maybe v = do
    mbLocal <- readLEnv (\lenv -> lookupVarEnv (local_vars lenv) v)
    case mbLocal of
      Nothing  -> fmap (fmap Global) (readGEnv (\genv -> lookupVarEnv (global_vars genv) v))
      Just hit -> return (Just (Local hit))
-- | Abort vectorisation via 'cantVectorise', reporting the given variable.
-- The message distinguishes class-operation identifiers from other variables.
dumpVar :: DynFlags -> Var -> a
dumpVar dflags var = cantVectorise dflags message (ppr var)
  where
    message =
      case isClassOpId_maybe var of
        Just _  -> "ClassOpId not vectorised:"
        Nothing -> "Variable not vectorised:"
-- | Add a variable to the global environment's @global_parallel_vars@ set,
-- emitting a vectoriser trace message first.
addGlobalParallelVar :: Var -> VM ()
addGlobalParallelVar var = do
    traceVt "addGlobalParallelVar" (ppr var)
    updGEnv recordVar
  where
    recordVar genv =
      genv { global_parallel_vars = extendDVarSet (global_parallel_vars genv) var }
-- | Add a type constructor's name to the global environment's
-- @global_parallel_tycons@ set, emitting a vectoriser trace message first.
addGlobalParallelTyCon :: TyCon -> VM ()
addGlobalParallelTyCon tycon = do
    traceVt "addGlobalParallelTyCon" (ppr tycon)
    updGEnv recordTyCon
  where
    recordTyCon genv =
      genv { global_parallel_tycons =
               extendNameSet (global_parallel_tycons genv) (tyConName tycon) }