module Vectorise.Monad (
module Vectorise.Monad.Base,
module Vectorise.Monad.Naming,
module Vectorise.Monad.Local,
module Vectorise.Monad.Global,
module Vectorise.Monad.InstEnv,
initV,
liftBuiltinDs,
builtin,
builtins,
lookupVar,
lookupVar_maybe,
addGlobalParallelVar,
addGlobalParallelTyCon,
) where
import Vectorise.Monad.Base
import Vectorise.Monad.Naming
import Vectorise.Monad.Local
import Vectorise.Monad.Global
import Vectorise.Monad.InstEnv
import Vectorise.Builtins
import Vectorise.Env
import CoreSyn
import TcRnMonad
import DsMonad
import HscTypes hiding ( MonadThings(..) )
import DynFlags
import MonadUtils (liftIO)
import InstEnv
import Class
import TyCon
import NameSet
import VarSet
import VarEnv
import Var
import Id
import Name
import ErrUtils
import Outputable
import Module
import Control.Monad (join)
-- | Run the vectorisation computation @thing_inside@ for the module described by @guts@.
--
-- The computation runs in the desugarer monad (via 'initDsWithModGuts').  It starts with
-- freshly initialised builtins and an initial global environment.  That environment is
-- built from @info@, the module's vectorisation declarations ('mg_vect_decls') and the
-- class/family instance environments.  Those instance environments combine the ones in the
-- external package state ('hscEPS') with the module's own.
--
-- Returns @Just (info', x)@ on success: the updated 'VectInfo' and the result of
-- @thing_inside@.  Returns 'Nothing' in two cases:
--
--   * vectorisation fails, after a "Warning: vectorisation failure:" dump carrying the
--     reason has been printed;
--   * 'initDsWithModGuts' itself yields no result (presumably on desugarer errors —
--     TODO confirm).
--
-- When the 'Opt_D_dump_vt_trace' dump flag is set, the incoming 'VectInfo' is dumped.
-- Afterwards either the outgoing 'VectInfo' or a failure notice is dumped.
initV :: HscEnv
      -> ModGuts
      -> VectInfo
      -> VM a
      -> IO (Maybe (VectInfo, a))
initV hsc_env guts info thing_inside
  = do { dumpIfVtTrace "Incoming VectInfo" (ppr info)

       ; (_, res) <- initDsWithModGuts hsc_env guts go
         -- 'res' has two layers of 'Maybe': the outer one comes from 'initDsWithModGuts',
         -- the inner one from 'go'.  'join' treats failure at either layer as 'Nothing'.
         -- NOTE(review): the first component of the pair is discarded unexamined.
       ; case join res of
           Nothing
             -> dumpIfVtTrace "Vectorisation FAILED!" empty
           Just (info', _)
             -> dumpIfVtTrace "Outgoing VectInfo" (ppr info')

       ; return $ join res
       }
  where
    dflags = hsc_dflags hsc_env

    -- Dump a document only when 'Opt_D_dump_vt_trace' is enabled in 'dflags'.
    dumpIfVtTrace = dumpIfSet_dyn dflags Opt_D_dump_vt_trace

    -- Binders introduced by a single top-level binding group.
    bindsToIds (NonRec v _)   = [v]
    bindsToIds (Rec    binds) = map fst binds

    -- All top-level binders of the module; used when building the outgoing 'VectInfo'.
    ids = concatMap bindsToIds (mg_binds guts)

    -- Desugarer action: set up the vectorisation environment, then run 'thing_inside'.
    -- Yields 'Nothing' when the 'VM' computation reports failure ('No').
    go
      = do {
             -- set up the tables of builtin entities
           ; builtins     <- initBuiltins
           ; builtin_vars <- initBuiltinVars builtins

             -- set up the class and type family instance environments
           ; eps <- liftIO $ hscEPS hsc_env
           ; let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
                 instEnvs    = InstEnvs (eps_inst_env eps)
                                        (mg_inst_env guts)
                                        (mkModuleSet (dep_orphs (mg_deps guts)))
                 -- instance dictionaries of the builtin 'PA' and 'PR' classes
                 builtin_pas = initClassDicts instEnvs (paClass builtins)
                 builtin_prs = initClassDicts instEnvs (prClass builtins)

             -- construct the initial global environment
           ; let genv = extendImportedVarsEnv builtin_vars
                        . setPAFunsEnv builtin_pas
                        . setPRFunsEnv builtin_prs
                        $ initGlobalEnv (gopt Opt_VectorisationAvoidance dflags)
                                        info (mg_vect_decls guts) instEnvs famInstEnvs

             -- perform vectorisation; on failure, print the reason as a warning
           ; r <- runVM thing_inside builtins genv emptyLocalEnv
           ; case r of
               Yes genv _ x -> return $ Just (new_info genv, x)
               No reason    -> do { unqual <- mkPrintUnqualifiedDs
                                  ; liftIO $
                                      printOutputForUser dflags unqual $
                                        mkDumpDoc "Warning: vectorisation failure:" reason
                                  ; return Nothing
                                  }
           }

    -- Build the outgoing 'VectInfo' with 'modVectInfo'.  Inputs: the final global
    -- environment, the module's binders and type constructors, its vectorisation
    -- declarations, and the incoming info.
    new_info genv = modVectInfo genv ids (mg_tcs guts) (mg_vect_decls guts) info

    -- For a given (DPH) class, map the head type constructor of each of its instances to
    -- that instance's dictionary function.
    -- Panics when an instance's rough-match list is not exactly one known type constructor.
    -- NOTE(review): the panic text blames overlap, but the guard also rejects other shapes,
    -- such as a type-variable head.
    initClassDicts :: InstEnvs -> Class -> [(Name, Var)]
    initClassDicts insts cls = map find $ classInstances insts cls
      where
        find i | [Just tc] <- instanceRoughTcs i = (tc, instanceDFunId i)
               | otherwise                       = pprPanic invalidInstance (ppr i)

    invalidInstance = "Invalid DPH instance (overlapping in head constructor)"
-- | Lift a desugarer computation that depends on the builtins table into 'VM'.
-- The computation always succeeds; the global and local environments pass through untouched.
liftBuiltinDs :: (Builtins -> DsM a) -> VM a
liftBuiltinDs p = VM run
  where
    run bi genv lenv = fmap (Yes genv lenv) (p bi)
-- | Select one component of the builtins table with a projection function.
-- The environments are left unchanged.
builtin :: (Builtins -> a) -> VM a
builtin f = VM step
  where
    step bi genv lenv = return (Yes genv lenv (f bi))
-- | Partially apply a two-argument selector to the builtins table.
-- The table is supplied as the selector's second argument, and the resulting function
-- is returned.  The environments are left unchanged.
builtins :: (a -> Builtins -> b) -> VM (a -> b)
builtins f = VM $ \bi genv lenv ->
  let partial x = f x bi
  in  return (Yes genv lenv partial)
-- | Look up the vectorised counterpart of a variable, as 'lookupVar_maybe' does.
-- If the variable is unknown, the failure is reported through 'dumpVar' (via
-- 'cantVectorise'), which does not return a value normally.
lookupVar :: Var -> VM (Scope Var (Var, Var))
lookupVar v = lookupVar_maybe v >>= maybe notFound return
  where
    notFound = getDynFlags >>= \dflags -> dumpVar dflags v
-- | Look up a variable's vectorised form.
-- The local environment ('local_vars') is consulted first; a hit there is tagged 'Local'.
-- Only on a miss is the global environment ('global_vars') consulted; a hit there is
-- tagged 'Global'.
lookupVar_maybe :: Var -> VM (Maybe (Scope Var (Var, Var)))
lookupVar_maybe v = readLEnv localEntry >>= maybe globalLookup (return . Just . Local)
  where
    localEntry  lenv = lookupVarEnv (local_vars lenv) v
    globalEntry genv = lookupVarEnv (global_vars genv) v
    globalLookup     = fmap (fmap Global) (readGEnv globalEntry)
-- | Report through 'cantVectorise' that a variable has no vectorised form.
-- Class-method selectors (those for which 'isClassOpId_maybe' succeeds) get a dedicated
-- message.
dumpVar :: DynFlags -> Var -> a
dumpVar dflags var = cantVectorise dflags reason (ppr var)
  where
    reason = case isClassOpId_maybe var of
               Just _  -> "ClassOpId not vectorised:"
               Nothing -> "Variable not vectorised:"
-- | Add a variable to the 'global_parallel_vars' set of the global environment.
-- A vectoriser trace message is emitted first.
addGlobalParallelVar :: Var -> VM ()
addGlobalParallelVar var
  = traceVt "addGlobalParallelVar" (ppr var) >> updGEnv markParallel
  where
    markParallel genv =
      genv { global_parallel_vars = extendDVarSet (global_parallel_vars genv) var }
-- | Add a type constructor's name to the 'global_parallel_tycons' set of the global
-- environment.  A vectoriser trace message is emitted first.
addGlobalParallelTyCon :: TyCon -> VM ()
addGlobalParallelTyCon tycon
  = traceVt "addGlobalParallelTyCon" (ppr tycon) >> updGEnv markParallel
  where
    markParallel genv =
      genv { global_parallel_tycons = extendNameSet (global_parallel_tycons genv) (tyConName tycon) }