module DsArrows ( dsProcExpr ) where
#include "HsVersions.h"
import Match
import DsUtils
import DsMonad
import HsSyn hiding (collectPatBinders, collectPatsBinders, collectLStmtsBinders, collectLStmtBinders, collectStmtBinders )
import TcHsSyn
import DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
import TcType
import TcEvidence
import CoreSyn
import CoreFVs
import CoreUtils
import MkCore
import DsBinds (dsHsWrapper)
import Name
import Var
import Id
import DataCon
import TysWiredIn
import BasicTypes
import PrelNames
import Outputable
import Bag
import VarSet
import SrcLoc
import ListSetOps( assocDefault )
import FastString
import Data.List
-- | Core expressions for the six arrow methods that the command
-- desugarer applies; built by 'mkCmdEnv'.
data DsCmdEnv = DsCmdEnv
    { arr_id     :: CoreExpr
    , compose_id :: CoreExpr
    , first_id   :: CoreExpr
    , app_id     :: CoreExpr
    , choice_id  :: CoreExpr
    , loop_id    :: CoreExpr
    }
-- | Desugar every entry of a command syntax table into a fresh
-- non-recursive binding, and collect the bound variables into a
-- 'DsCmdEnv'.
--
-- Returns the bindings together with the environment; the environment's
-- fields are variables bound by those bindings.  Looking up a method name
-- that is missing from the table yields a panic value.
mkCmdEnv :: CmdSyntaxTable Id -> DsM ([CoreBind], DsCmdEnv)
mkCmdEnv tc_meths = do
    (meth_binds, name_id_prs) <- mapAndUnzipM bindMethod tc_meths
    let lookupMeth std_name =
            Var (assocDefault (notFound std_name) name_id_prs std_name)
        cmd_env = DsCmdEnv
            { arr_id     = lookupMeth arrAName
            , compose_id = lookupMeth composeAName
            , first_id   = lookupMeth firstAName
            , app_id     = lookupMeth appAName
            , choice_id  = lookupMeth choiceAName
            , loop_id    = lookupMeth loopAName
            }
    return (meth_binds, cmd_env)
  where
    -- Desugar one method and bind it to a new system local of its type.
    bindMethod (std_name, expr) = do
        rhs  <- dsExpr expr
        bndr <- newSysLocalDs (exprType rhs)
        return (NonRec bndr rhs, (std_name, bndr))

    notFound std_name =
        pprPanic "mkCmdEnv" (ptext (sLit "Not found:") <+> ppr std_name)
-- | Apply the environment's @arr@ method to two type arguments and a
-- function expression.
do_arr :: DsCmdEnv -> Type -> Type -> CoreExpr -> CoreExpr
do_arr env in_ty out_ty fn =
    arr_id env `mkApps` [Type in_ty, Type out_ty, fn]
-- | Apply the environment's composition method to three type arguments
-- and the two arrow expressions being composed.
do_compose :: DsCmdEnv -> Type -> Type -> Type ->
                CoreExpr -> CoreExpr -> CoreExpr
do_compose env b_ty c_ty d_ty before after =
    compose_id env `mkApps` (type_args ++ [before, after])
  where
    type_args = [Type b_ty, Type c_ty, Type d_ty]
-- | Apply the environment's @first@ method to three type arguments and
-- an arrow expression.
do_first :: DsCmdEnv -> Type -> Type -> Type -> CoreExpr -> CoreExpr
do_first env b_ty c_ty d_ty arrow =
    first_id env `mkApps` [Type b_ty, Type c_ty, Type d_ty, arrow]
-- | Instantiate the environment's @app@ method at two type arguments.
do_app :: DsCmdEnv -> Type -> Type -> CoreExpr
do_app env in_ty out_ty =
    app_id env `mkApps` [Type in_ty, Type out_ty]
-- | Apply the environment's choice method to two arrow expressions.
--
-- Note that the type arguments are handed to the method in the order
-- @b_ty@, @d_ty@, @c_ty@: the last two are swapped relative to this
-- function's own parameter order.
do_choice :: DsCmdEnv -> Type -> Type -> Type ->
                CoreExpr -> CoreExpr -> CoreExpr
do_choice env b_ty c_ty d_ty left_arrow right_arrow =
    choice_id env `mkApps` (type_args ++ [left_arrow, right_arrow])
  where
    type_args = [Type b_ty, Type d_ty, Type c_ty]
-- | Apply the environment's @loop@ method to an arrow expression.
--
-- As in 'do_choice', the type arguments are passed in the order
-- @b_ty@, @d_ty@, @c_ty@.
do_loop :: DsCmdEnv -> Type -> Type -> Type -> CoreExpr -> CoreExpr
do_loop env b_ty c_ty d_ty body =
    loop_id env `mkApps` [Type b_ty, Type d_ty, Type c_ty, body]
-- | Pre-compose an arrow with a pure function: the function @f@ is
-- lifted with 'do_arr' and then composed in front of @g@ with
-- 'do_compose'.
do_premap :: DsCmdEnv -> Type -> Type -> Type ->
                CoreExpr -> CoreExpr -> CoreExpr
do_premap env b_ty c_ty d_ty f g =
    do_compose env b_ty c_ty d_ty lifted_f g
  where
    lifted_f = do_arr env b_ty c_ty f
-- | Build a pattern-match-failure error expression of the given type,
-- whose message is derived from the match context.
mkFailExpr :: HsMatchContext Id -> Type -> DsM CoreExpr
mkFailExpr ctxt ty = mkErrorAppDs pAT_ERROR_ID ty err_msg
  where
    err_msg = matchContextErrString ctxt
-- | Construct a Core lambda that takes a pair of the two given types and
-- returns its first component.
mkFstExpr :: Type -> Type -> DsM CoreExpr
mkFstExpr a_ty b_ty = do
    fst_id  <- newSysLocalDs a_ty
    snd_id  <- newSysLocalDs b_ty
    pair_id <- newSysLocalDs (mkCorePairTy a_ty b_ty)
    let projection = coreCasePair pair_id fst_id snd_id (Var fst_id)
    return (Lam pair_id projection)
-- | Construct a Core lambda that takes a pair of the two given types and
-- returns its second component.
mkSndExpr :: Type -> Type -> DsM CoreExpr
mkSndExpr a_ty b_ty = do
    fst_id  <- newSysLocalDs a_ty
    snd_id  <- newSysLocalDs b_ty
    pair_id <- newSysLocalDs (mkCorePairTy a_ty b_ty)
    let projection = coreCasePair pair_id fst_id snd_id (Var snd_id)
    return (Lam pair_id projection)
-- | Scrutinise the variable @scrut_var@ as a (big) tuple, binding its
-- components to @vars@ in @body@.  Delegates to 'mkTupleCase'.
coreCaseTuple :: UniqSupply -> Id -> [Id] -> CoreExpr -> CoreExpr
coreCaseTuple uniqs scrut_var vars body =
    mkTupleCase uniqs vars body scrut_var scrutinee
  where
    scrutinee = Var scrut_var
-- | Build a single-alternative Core case that takes the pair held in
-- @scrut_var@ apart, binding its components to @var1@ and @var2@ in
-- @body@.  The scrutinee variable doubles as the case binder.
coreCasePair :: Id -> Id -> Id -> CoreExpr -> CoreExpr
coreCasePair scrut_var var1 var2 body =
    Case (Var scrut_var) scrut_var (exprType body) [pair_alt]
  where
    pair_alt = (DataAlt (tupleCon BoxedTuple 2), [var1, var2], body)
-- | The boxed two-tuple type with the given component types.
mkCorePairTy :: Type -> Type -> Type
mkCorePairTy left_ty right_ty = mkBoxedTupleTy components
  where
    components = [left_ty, right_ty]
-- | The Core two-tuple expression with the given components.
mkCorePairExpr :: CoreExpr -> CoreExpr -> CoreExpr
mkCorePairExpr left right = mkCoreTup components
  where
    components = [left, right]
-- | The Core expression produced by 'mkCoreTup' for an empty component
-- list; used in this module wherever an empty stack value is needed.
mkCoreUnitExpr :: CoreExpr
mkCoreUnitExpr = mkCoreTup []
-- | The type of an (environment, stack) pair: the environment is the big
-- tuple of the given variables, paired with the given stack type.
envStackType :: [Id] -> Type -> Type
envStackType ids stack_ty = mkCorePairTy env_ty stack_ty
  where
    env_ty = mkBigCoreVarTupTy ids
-- | Peel @n@ components off a nested-pair type.
--
-- Given a type of the form @(t1, (t2, ... (tn, rest)))@, return the list
-- @[t1 .. tn]@ together with @rest@.  For @n == 0@ the type is returned
-- untouched.  Panics (printing the offending type) if, at some level, the
-- type constructor application does not have exactly two arguments.
splitTypeAt :: Int -> Type -> ([Type], Type)
splitTypeAt n ty
  | n == 0    = ([], ty)
  | otherwise = case tcTyConAppArgs ty of
      -- Recurse on the tail with one fewer component left to peel off.
      [t, ty'] -> let (ts, ty_r) = splitTypeAt (n - 1) ty' in (t : ts, ty_r)
      _        -> pprPanic "splitTypeAt" (ppr ty)
-- | Build the (environment, stack) pair expression: the big tuple of the
-- environment variables paired with the stack variable.
buildEnvStack :: [Id] -> Id -> CoreExpr
buildEnvStack env_ids stack_id = mkCorePairExpr env_tuple stack
  where
    env_tuple = mkBigCoreVarTup env_ids
    stack     = Var stack_id
-- | Build a lambda over an (environment, stack) pair.
--
-- The generated function takes the pair apart, binds the stack to
-- @stack_id@, unpacks the environment tuple into @env_ids@, and then
-- evaluates @body@.
matchEnvStack   :: [Id]         -- environment variables to bind
                -> Id           -- variable to bind to the stack
                -> CoreExpr     -- body in which they are in scope
                -> DsM CoreExpr
matchEnvStack env_ids stack_id body = do
    uniqs   <- newUniqueSupply
    tup_var <- newSysLocalDs (mkBigCoreVarTupTy env_ids)
    pair_id <- newSysLocalDs (mkCorePairTy (idType tup_var) (idType stack_id))
    let unpack_env = coreCaseTuple uniqs tup_var env_ids body
        unpack_all = coreCasePair pair_id tup_var stack_id unpack_env
    return (Lam pair_id unpack_all)
-- | Build a lambda over an environment tuple: the generated function
-- unpacks its argument into @env_ids@ and evaluates @body@.
matchEnv :: [Id]        -- environment variables to bind
         -> CoreExpr    -- body in which they are in scope
         -> DsM CoreExpr
matchEnv env_ids body = do
    uniqs  <- newUniqueSupply
    tup_id <- newSysLocalDs (mkBigCoreVarTupTy env_ids)
    let unpacked = coreCaseTuple uniqs tup_id env_ids body
    return (Lam tup_id unpacked)
-- | Wrap @body@ in nested pair matches that pop the given parameters off
-- a stack, leaving the remainder bound to @stack_id@.
--
-- Returns the variable standing for the whole (unpopped) stack together
-- with the code that takes it apart.  With no parameters, the stack
-- variable and the body are returned unchanged.
matchVarStack :: [Id] -> Id -> CoreExpr -> DsM (Id, CoreExpr)
matchVarStack params stack_id body = case params of
    [] -> return (stack_id, body)
    param_id : rest -> do
        (tail_id, tail_code) <- matchVarStack rest stack_id body
        let pair_ty = mkCorePairTy (idType param_id) (idType tail_id)
        pair_id <- newSysLocalDs pair_ty
        return (pair_id, coreCasePair pair_id param_id tail_id tail_code)
-- | Source-level (HsSyn) counterpart of 'buildEnvStack': a tuple
-- expression pairing the tuple of environment variables with the stack
-- variable.
mkHsEnvStackExpr :: [Id] -> Id -> LHsExpr Id
mkHsEnvStackExpr env_ids stack_id = mkLHsTupleExpr [env_tuple, stack]
  where
    env_tuple = mkLHsVarTuple env_ids
    stack     = nlHsVar stack_id
-- | Desugar a @proc@ expression.
--
-- The command is desugared with the pattern's binders as its local
-- variables and a unit stack.  A function that matches the pattern and
-- produces the (environment, unit) pair is then pre-composed onto the
-- command's arrow, and the arrow-method bindings are wrapped around the
-- result.
dsProcExpr
        :: LPat Id
        -> LHsCmdTop Id
        -> DsM CoreExpr
dsProcExpr pat (L _ (HsCmdTop cmd _unitTy cmd_ty ids)) = do
    (meth_binds, meth_ids) <- mkCmdEnv ids
    let pat_binders = mkVarSet (collectPatBinders pat)
    (core_cmd, _free_vars, env_ids)
        <- dsfixCmd meth_ids pat_binders unitTy cmd_ty cmd
    let env_stk_ty   = mkCorePairTy (mkBigCoreVarTupTy env_ids) unitTy
        env_stk_expr = mkCorePairExpr (mkBigCoreVarTup env_ids) mkCoreUnitExpr
    fail_expr  <- mkFailExpr ProcExpr env_stk_ty
    scrut_var  <- selectSimpleMatchVarL pat
    match_code <- matchSimply (Var scrut_var) ProcExpr pat env_stk_expr fail_expr
    let select_env = Lam scrut_var match_code
        proc_code  = do_premap meth_ids (hsLPatType pat) env_stk_ty cmd_ty
                         select_env
                         core_cmd
    return (mkLets meth_binds proc_code)
-- | Located-command wrapper around 'dsCmd': strips the source location
-- and passes every other argument through unchanged.
dsLCmd :: DsCmdEnv -> IdSet -> Type -> Type -> LHsCmd Id -> [Id]
       -> DsM (CoreExpr, IdSet)
dsLCmd ids local_vars stk_ty res_ty located_cmd env_ids =
    dsCmd ids local_vars stk_ty res_ty bare_cmd env_ids
  where
    bare_cmd = unLoc located_cmd
-- | Desugar a command into an arrow expression whose input is an
-- (environment, stack) pair.  Returns the arrow together with a set of
-- variables (in most equations: free variables of the generated code,
-- restricted to the local variables).
--
-- The @env_ids@ argument is tied lazily through 'trimInput', so it is
-- only ever used inside lazily-constructed expressions.
dsCmd   :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this command
        -> Type                 -- type of the stack
        -> Type                 -- return type of the command
        -> HsCmd Id             -- command to desugar
        -> [Id]                 -- list of vars in the input to this command
        -> DsM (CoreExpr,       -- desugared expression
                IdSet)          -- subset of local vars that occur free

-- First-order arrow application: the argument is computed from the
-- environment and fed into the (environment-independent) arrow.
dsCmd ids local_vars stack_ty res_ty
        (HsCmdArrApp arrow arg arrow_ty HsFirstOrderApp _)
        env_ids = do
    core_arrow    <- dsLExpr arrow
    core_arg      <- dsLExpr arg
    stack_id      <- newSysLocalDs stack_ty
    core_make_arg <- matchEnvStack env_ids stack_id core_arg
    let in_ty    = envStackType env_ids stack_ty
        applied  = do_premap ids in_ty arg_ty res_ty core_make_arg core_arrow
        arg_fvs  = exprFreeIds core_arg `intersectVarSet` local_vars
    return (applied, arg_fvs)
  where
    -- arrow_ty has the shape (a arg_ty) res_ty; pick out arg_ty.
    (a_arg_ty, _) = tcSplitAppTy arrow_ty
    (_, arg_ty)   = tcSplitAppTy a_arg_ty
-- Higher-order arrow application: both the arrow and its argument are
-- computed from the environment, paired up, and passed to @app@.
dsCmd ids local_vars stack_ty res_ty
        (HsCmdArrApp arrow arg arrow_ty HsHigherOrderApp _)
        env_ids = do
    core_arrow <- dsLExpr arrow
    core_arg   <- dsLExpr arg
    stack_id   <- newSysLocalDs stack_ty
    core_make_pair <- matchEnvStack env_ids stack_id
                          (mkCorePairExpr core_arrow core_arg)
    let in_ty    = envStackType env_ids stack_ty
        pair_ty  = mkCorePairTy arrow_ty arg_ty
        core_fvs = exprFreeIds core_arrow `unionVarSet` exprFreeIds core_arg
    return ( do_premap ids in_ty pair_ty res_ty
                 core_make_pair
                 (do_app ids arg_ty res_ty)
           , core_fvs `intersectVarSet` local_vars )
  where
    -- arrow_ty has the shape (a arg_ty) res_ty; pick out arg_ty.
    (a_arg_ty, _) = tcSplitAppTy arrow_ty
    (_, arg_ty)   = tcSplitAppTy a_arg_ty
-- Command application: the argument value is pushed onto the stack
-- before running the command.
dsCmd ids local_vars stack_ty res_ty (HsCmdApp cmd arg) env_ids = do
    core_arg <- dsLExpr arg
    let arg_ty    = exprType core_arg
        pushed_ty = mkCorePairTy arg_ty stack_ty
    (core_cmd, free_vars, env_ids')
        <- dsfixCmd ids local_vars pushed_ty res_ty cmd
    stack_id <- newSysLocalDs stack_ty
    arg_id   <- newSysLocalDs arg_ty
    -- Bind the argument, then rebuild the input with it on top of the stack.
    let pushed_stack = mkCorePairExpr (Var arg_id) (Var stack_id)
        core_body    = bindNonRec arg_id core_arg
                           (mkCorePairExpr (mkBigCoreVarTup env_ids') pushed_stack)
    core_map <- matchEnvStack env_ids stack_id core_body
    let arg_fvs = exprFreeIds core_arg `intersectVarSet` local_vars
    return ( do_premap ids
                 (envStackType env_ids stack_ty)
                 (envStackType env_ids' pushed_ty)
                 res_ty
                 core_map
                 core_cmd
           , free_vars `unionVarSet` arg_fvs )
-- Command lambda (single match, single unguarded right-hand side): the
-- parameters are popped off the stack and matched against the patterns.
dsCmd ids local_vars stack_ty res_ty
        (HsCmdLam (MG { mg_alts = [L _ (Match _ pats _
                                        (GRHSs [L _ (GRHS [] body)] _ ))] }))
        env_ids = do
    (core_body, free_vars, env_ids')
        <- dsfixCmd ids inner_locals stack_ty' res_ty body
    param_ids <- mapM newSysLocalDs pat_tys
    stack_id' <- newSysLocalDs stack_ty'
    let in_ty'    = envStackType env_ids' stack_ty'
        core_expr = buildEnvStack env_ids' stack_id'
    fail_expr  <- mkFailExpr LambdaExpr in_ty'
    match_code <- matchSimplys (map Var param_ids) LambdaExpr pats core_expr fail_expr
    (stack_id, param_code) <- matchVarStack param_ids stack_id' match_code
    select_code <- matchEnvStack env_ids stack_id param_code
    let in_ty = envStackType env_ids stack_ty
    return ( do_premap ids in_ty in_ty' res_ty select_code core_body
           , free_vars `minusVarSet` pat_vars )
  where
    pat_vars             = mkVarSet (collectPatsBinders pats)
    inner_locals         = pat_vars `unionVarSet` local_vars
    -- One stack slot per pattern; stack_ty' is what remains below them.
    (pat_tys, stack_ty') = splitTypeAt (length pats) stack_ty
-- Parenthesised command: desugar the command inside the parentheses.
dsCmd ids local_vars stack_ty res_ty (HsCmdPar inner) env_ids =
    dsCmd ids local_vars stack_ty res_ty (unLoc inner) env_ids
-- Conditional command: the condition selects a Left (then-branch input)
-- or Right (else-branch input), and the two branches are combined with
-- the choice method.
dsCmd ids local_vars stack_ty res_ty (HsCmdIf mb_fun cond then_cmd else_cmd)
        env_ids = do
    core_cond <- dsLExpr cond
    (core_then, fvs_then, then_ids) <- dsfixCmd ids local_vars stack_ty res_ty then_cmd
    (core_else, fvs_else, else_ids) <- dsfixCmd ids local_vars stack_ty res_ty else_cmd
    stack_id   <- newSysLocalDs stack_ty
    either_con <- dsLookupTyCon eitherTyConName
    left_con   <- dsLookupDataCon leftDataConName
    right_con  <- dsLookupDataCon rightDataConName
    let then_ty = envStackType then_ids stack_ty
        else_ty = envStackType else_ids stack_ty
        sum_ty  = mkTyConApp either_con [then_ty, else_ty]
        inject con payload =
            mkCoreConApps con [Type then_ty, Type else_ty, payload]
        core_left  = inject left_con  (buildEnvStack then_ids stack_id)
        core_right = inject right_con (buildEnvStack else_ids stack_id)
    -- A rebindable-syntax 'if' function, when present, replaces the
    -- built-in conditional.
    branch_expr <- case mb_fun of
        Just fun -> do
            core_fun <- dsExpr fun
            return (mkCoreApps core_fun [core_cond, core_left, core_right])
        Nothing  -> return (mkIfThenElse core_cond core_left core_right)
    core_if <- matchEnvStack env_ids stack_id branch_expr
    let in_ty    = envStackType env_ids stack_ty
        fvs_cond = exprFreeIds core_cond `intersectVarSet` local_vars
        branches = do_choice ids then_ty else_ty res_ty core_then core_else
    return ( do_premap ids in_ty sum_ty res_ty core_if branches
           , fvs_cond `unionVarSet` fvs_then `unionVarSet` fvs_else )
-- Case command: every leaf command is desugared separately; the leaves
-- are merged pairwise (via 'foldb') into a nested Either, and the
-- original case expression is rebuilt with each leaf replaced by an
-- expression injecting that leaf's input into the nested sum.
dsCmd ids local_vars stack_ty res_ty
        (HsCmdCase exp (MG { mg_alts = matches, mg_arg_tys = arg_tys, mg_origin = origin }))
        env_ids = do
    stack_id <- newSysLocalDs stack_ty
    let make_branch (leaf, bound_vars) = do
            (core_leaf, _fvs, leaf_ids) <-
                dsfixCmd ids (bound_vars `unionVarSet` local_vars) stack_ty res_ty leaf
            return ( [mkHsEnvStackExpr leaf_ids stack_id]
                   , envStackType leaf_ids stack_ty
                   , core_leaf )
    branches   <- mapM make_branch (concatMap leavesMatch matches)
    either_con <- dsLookupTyCon eitherTyConName
    left_con   <- dsLookupDataCon leftDataConName
    right_con  <- dsLookupDataCon rightDataConName
    let -- Apply a type-instantiated Either constructor to an expression.
        inject con ty1 ty2 e =
            noLoc $ HsApp (noLoc $ HsWrap (mkWpTyApps [ty1, ty2])
                                          (HsVar (dataConWrapId con))) e
        merge_branches (builds1, in_ty1, core_exp1)
                       (builds2, in_ty2, core_exp2)
            = ( map (inject left_con in_ty1 in_ty2) builds1 ++
                map (inject right_con in_ty1 in_ty2) builds2
              , mkTyConApp either_con [in_ty1, in_ty2]
              , do_choice ids in_ty1 in_ty2 res_ty core_exp1 core_exp2 )
        (leaves', sum_ty, core_choices) = foldb merge_branches branches
        (_, matches') = mapAccumL (replaceLeavesMatch res_ty) leaves' matches
        in_ty = envStackType env_ids stack_ty
    core_body <- dsExpr (HsCase exp (MG { mg_alts = matches', mg_arg_tys = arg_tys
                                        , mg_res_ty = sum_ty, mg_origin = origin }))
    core_matches <- matchEnvStack env_ids stack_id core_body
    return ( do_premap ids in_ty sum_ty res_ty core_matches core_choices
           , exprFreeIds core_body `intersectVarSet` local_vars )
-- Let command: the bindings are evaluated in a pure pre-processing step
-- that rebuilds the environment needed by the body.
dsCmd ids local_vars stack_ty res_ty (HsCmdLet binds body) env_ids = do
    (core_body, _free_vars, env_ids')
        <- dsfixCmd ids body_locals stack_ty res_ty body
    stack_id   <- newSysLocalDs stack_ty
    core_binds <- dsLocalBinds binds (buildEnvStack env_ids' stack_id)
    core_map   <- matchEnvStack env_ids stack_id core_binds
    let in_ty  = envStackType env_ids stack_ty
        mid_ty = envStackType env_ids' stack_ty
    return ( do_premap ids in_ty mid_ty res_ty core_map core_body
           , exprFreeIds core_binds `intersectVarSet` local_vars )
  where
    body_locals = mkVarSet (collectLocalBinders binds) `unionVarSet` local_vars
-- Do command: the statements take only the environment as input, so the
-- stack component is projected away first.
dsCmd ids local_vars stack_ty res_ty (HsCmdDo stmts _) env_ids = do
    (core_stmts, stmts_fvs) <- dsCmdDo ids local_vars res_ty stmts env_ids
    core_fst <- mkFstExpr env_ty stack_ty
    let in_ty = mkCorePairTy env_ty stack_ty
    return ( do_premap ids in_ty env_ty res_ty core_fst core_stmts
           , stmts_fvs )
  where
    env_ty = mkBigCoreVarTupTy env_ids
-- Arrow form: the operator is instantiated at the environment tuple type
-- and applied to each argument command, desugared by 'dsTrimCmdArg'.
dsCmd _ids local_vars _stack_ty _res_ty (HsCmdArrForm op _ args) env_ids = do
    core_op <- dsLExpr op
    (core_args, fv_sets) <- mapAndUnzipM (dsTrimCmdArg local_vars env_ids) args
    let instantiated_op = App core_op (Type (mkBigCoreVarTupTy env_ids))
    return (instantiated_op `mkApps` core_args, unionVarSets fv_sets)
-- Cast command: desugar the underlying command, then wrap the resulting
-- arrow in the coercion.
dsCmd ids local_vars stack_ty res_ty (HsCmdCast coercion cmd) env_ids = do
    (core_cmd, cmd_fvs) <- dsCmd ids local_vars stack_ty res_ty cmd env_ids
    let cast_wrapper = mkWpCast coercion
    wrapped_cmd <- dsHsWrapper cast_wrapper core_cmd
    return (wrapped_cmd, cmd_fvs)
-- Catch-all for command forms not handled by the equations above.
-- Report the offending command (the fifth argument).  The sixth argument
-- is the environment variable list, which is tied lazily through
-- 'trimInput' and so must not be forced while building the message.
dsCmd _ _ _ _ cmd _ = pprPanic "dsCmd" (ppr cmd)
-- | Desugar a top-level command argument of an arrow form.
--
-- The command is desugared with its own arrow-method table.  If the
-- environment it actually needs differs from the one supplied, a trimming
-- function converting one to the other is pre-composed onto it.
dsTrimCmdArg
        :: IdSet                -- set of local vars available to this command
        -> [Id]                 -- list of vars in the input to this command
        -> LHsCmdTop Id         -- command argument to desugar
        -> DsM (CoreExpr,       -- desugared expression
                IdSet)          -- subset of local vars that occur free
dsTrimCmdArg local_vars env_ids (L _ (HsCmdTop cmd stack_ty cmd_ty ids)) = do
    (meth_binds, meth_ids) <- mkCmdEnv ids
    (core_cmd, free_vars, env_ids')
        <- dsfixCmd meth_ids local_vars stack_ty cmd_ty cmd
    stack_id  <- newSysLocalDs stack_ty
    trim_code <- matchEnvStack env_ids stack_id (buildEnvStack env_ids' stack_id)
    let trimmed = do_premap meth_ids
                      (envStackType env_ids stack_ty)
                      (envStackType env_ids' stack_ty)
                      cmd_ty
                      trim_code
                      core_cmd
        -- Kept lazy: env_ids must not be inspected until the result is used.
        arg_code
          | env_ids' == env_ids = core_cmd
          | otherwise           = trimmed
    return (mkLets meth_binds arg_code, free_vars)
-- | Desugar a located command, feeding the list of variables it needs
-- back in as its own input environment (see 'trimInput').
dsfixCmd
        :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this command
        -> Type                 -- type of the stack
        -> Type                 -- return type of the command
        -> LHsCmd Id            -- command to desugar
        -> DsM (CoreExpr,       -- desugared expression
                IdSet,          -- subset of local vars that occur free
                [Id])           -- the same local vars as a list, fed back
dsfixCmd ids local_vars stk_ty cmd_ty cmd = trimInput build
  where
    build = dsLCmd ids local_vars stk_ty cmd_ty cmd
-- | Tie the knot on an arrow builder's input environment.
--
-- The builder is given, lazily, the list form of the very variable set it
-- returns; that list is also returned as the third component.  The
-- builder must therefore not force its argument.
trimInput
        :: ([Id] -> DsM (CoreExpr, IdSet))
        -> DsM (CoreExpr,       -- desugared expression
                IdSet,          -- subset of local vars that occur free
                [Id])           -- same local vars as a list, fed back to
                                -- the inner function
trimInput build_arrow = fixDs knot
  where
    -- The lazy pattern is essential: env_ids comes from this very result.
    knot ~(_, _, env_ids) = do
        (core_cmd, free_vars) <- build_arrow env_ids
        return (core_cmd, free_vars, varSetElems free_vars)
-- | Desugar the statements of a do-command into an arrow whose input is
-- the environment tuple.  Panics on an empty statement list.
dsCmdDo :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this statement
        -> Type                 -- return type of the statement
        -> [CmdLStmt Id]        -- statements to desugar
        -> [Id]                 -- list of vars in the input to this statement
        -> DsM (CoreExpr,       -- desugared expression
                IdSet)          -- subset of local vars that occur free
dsCmdDo _ _ _ [] _ = panic "dsCmdDo"

-- Final statement: pair the environment with a unit stack and run the
-- body command.
dsCmdDo ids local_vars res_ty [L _ (LastStmt body _)] env_ids = do
    (core_body, body_fvs) <- dsLCmd ids local_vars unitTy res_ty body env_ids
    env_var <- newSysLocalDs env_ty
    let add_unit = Lam env_var (mkCorePairExpr (Var env_var) mkCoreUnitExpr)
    return ( do_premap ids env_ty (mkCorePairTy env_ty unitTy) res_ty
                 add_unit
                 core_body
           , body_fvs )
  where
    env_ty = mkBigCoreVarTupTy env_ids

-- Leading statement: desugar the remaining statements first, so that the
-- environment they need is known, then compose this statement in front.
dsCmdDo ids local_vars res_ty (stmt:stmts) env_ids = do
    (core_rest, _, rest_ids) <- trimInput (dsCmdDo ids rest_locals res_ty stmts)
    (core_stmt, fv_stmt) <- dsCmdLStmt ids local_vars rest_ids stmt env_ids
    let in_ty  = mkBigCoreVarTupTy env_ids
        mid_ty = mkBigCoreVarTupTy rest_ids
    return ( do_compose ids in_ty mid_ty res_ty core_stmt core_rest
           , fv_stmt )
  where
    rest_locals = mkVarSet (collectLStmtBinders stmt) `unionVarSet` local_vars
-- | Located-statement wrapper around 'dsCmdStmt': strips the source
-- location and passes every other argument through unchanged.
dsCmdLStmt :: DsCmdEnv -> IdSet -> [Id] -> CmdLStmt Id -> [Id]
           -> DsM (CoreExpr, IdSet)
dsCmdLStmt ids local_vars out_ids located_stmt env_ids =
    dsCmdStmt ids local_vars out_ids bare_stmt env_ids
  where
    bare_stmt = unLoc located_stmt
-- | Desugar one statement of a do-command into an arrow from the tuple of
-- @env_ids@ to the tuple of @out_ids@.
dsCmdStmt
        :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this statement
        -> [Id]                 -- list of vars in the output of this statement
        -> CmdStmt Id           -- statement to desugar
        -> [Id]                 -- list of vars in the input to this statement
        -> DsM (CoreExpr,       -- desugared expression
                IdSet)          -- subset of local vars that occur free

-- Body statement: run the command on its own (environment, unit) input
-- beside the output tuple, then discard the command's result.
dsCmdStmt ids local_vars out_ids (BodyStmt cmd _ _ c_ty) env_ids = do
    (core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars unitTy c_ty cmd
    let cmd_input = mkCorePairExpr (mkBigCoreVarTup env_ids1) mkCoreUnitExpr
    core_mux <- matchEnv env_ids
                    (mkCorePairExpr cmd_input (mkBigCoreVarTup out_ids))
    snd_fn <- mkSndExpr c_ty out_ty
    let in_ty1      = mkCorePairTy (mkBigCoreVarTupTy env_ids1) unitTy
        before_c_ty = mkCorePairTy in_ty1 out_ty
        after_c_ty  = mkCorePairTy c_ty out_ty
        run_cmd     = do_first ids in_ty1 c_ty out_ty core_cmd
        drop_result = do_arr ids after_c_ty out_ty snd_fn
    return ( do_premap ids in_ty before_c_ty out_ty core_mux
                 (do_compose ids before_c_ty after_c_ty out_ty run_cmd drop_result)
           , extendVarSetList fv_cmd out_ids )
  where
    in_ty  = mkBigCoreVarTupTy env_ids
    out_ty = mkBigCoreVarTupTy out_ids
-- Bind statement: run the command beside the output variables that the
-- pattern does not bind, then match the command's result against the
-- pattern to produce the output tuple.
dsCmdStmt ids local_vars out_ids (BindStmt pat cmd _ _) env_ids = do
    (core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars unitTy pat_ty cmd
    -- Multiplexer: environment -> ((cmd environment, unit), kept vars).
    core_mux <- matchEnv env_ids
                    (mkCorePairExpr
                        (mkCorePairExpr (mkBigCoreVarTup env_ids1) mkCoreUnitExpr)
                        (mkBigCoreVarTup env_ids2))
    -- Projection: (cmd result, kept vars) -> output tuple.
    env_id <- newSysLocalDs env_ty2
    uniqs  <- newUniqueSupply
    let body_expr = coreCaseTuple uniqs env_id env_ids2 (mkBigCoreVarTup out_ids)
    fail_expr  <- mkFailExpr (StmtCtxt DoExpr) out_ty
    pat_id     <- selectSimpleMatchVarL pat
    match_code <- matchSimply (Var pat_id) (StmtCtxt DoExpr) pat body_expr fail_expr
    pair_id    <- newSysLocalDs after_c_ty
    let proj_expr   = Lam pair_id (coreCasePair pair_id pat_id env_id match_code)
        in_ty1      = mkCorePairTy (mkBigCoreVarTupTy env_ids1) unitTy
        before_c_ty = mkCorePairTy in_ty1 env_ty2
        run_cmd     = do_first ids in_ty1 pat_ty env_ty2 core_cmd
        project     = do_arr ids after_c_ty out_ty proj_expr
    return ( do_premap ids in_ty before_c_ty out_ty core_mux
                 (do_compose ids before_c_ty after_c_ty out_ty run_cmd project)
           , fv_cmd `unionVarSet` kept_vars )
  where
    pat_ty     = hsLPatType pat
    pat_vars   = mkVarSet (collectPatBinders pat)
    -- Output variables that must be carried past the command unchanged.
    kept_vars  = mkVarSet out_ids `minusVarSet` pat_vars
    env_ids2   = varSetElems kept_vars
    env_ty2    = mkBigCoreVarTupTy env_ids2
    after_c_ty = mkCorePairTy pat_ty env_ty2
    in_ty      = mkBigCoreVarTupTy env_ids
    out_ty     = mkBigCoreVarTupTy out_ids
-- Let statement: a pure arrow that evaluates the bindings and builds the
-- output tuple.
dsCmdStmt ids local_vars out_ids (LetStmt binds) env_ids = do
    core_binds <- dsLocalBinds binds (mkBigCoreVarTup out_ids)
    core_map   <- matchEnv env_ids core_binds
    let in_ty  = mkBigCoreVarTupTy env_ids
        out_ty = mkBigCoreVarTupTy out_ids
    return ( do_arr ids in_ty out_ty core_map
           , exprFreeIds core_binds `intersectVarSet` local_vars )
-- Rec statement: the recursive block (see 'dsRecCmd') runs beside the
-- output variables it does not define (@env2_ids@); afterwards the later
-- variables and those carried variables are reassembled into the output.
dsCmdStmt ids local_vars out_ids
        (RecStmt { recS_stmts = stmts
                 , recS_later_ids = later_ids, recS_rec_ids = rec_ids
                 , recS_later_rets = later_rets, recS_rec_rets = rec_rets })
        env_ids = do
    -- Post-processing: (later vars, carried vars) -> output tuple.
    uniqs   <- newUniqueSupply
    env2_id <- newSysLocalDs env2_ty
    let post_loop_body = coreCaseTuple uniqs env2_id env2_ids (mkBigCoreVarTup out_ids)
    post_loop_fn <- matchEnvStack later_ids env2_id post_loop_body

    (core_loop, env1_id_set, env1_ids)
        <- dsRecCmd ids local_vars stmts later_ids later_rets rec_ids rec_rets

    -- Pre-processing: environment -> (loop input, carried vars).
    let env1_ty       = mkBigCoreVarTupTy env1_ids
        pre_pair_ty   = mkCorePairTy env1_ty env2_ty
        pre_loop_body = mkCorePairExpr (mkBigCoreVarTup env1_ids)
                                       (mkBigCoreVarTup env2_ids)
    pre_loop_fn <- matchEnv env_ids pre_loop_body

    let run_loop  = do_first ids env1_ty later_ty env2_ty core_loop
        rebuild   = do_arr ids post_pair_ty out_ty post_loop_fn
        core_body = do_premap ids env_ty pre_pair_ty out_ty
                        pre_loop_fn
                        (do_compose ids pre_pair_ty post_pair_ty out_ty
                            run_loop
                            rebuild)
    return (core_body, env1_id_set `unionVarSet` env2_id_set)
  where
    env2_id_set  = mkVarSet out_ids `minusVarSet` mkVarSet later_ids
    env2_ids     = varSetElems env2_id_set
    env2_ty      = mkBigCoreVarTupTy env2_ids
    later_ty     = mkBigCoreVarTupTy later_ids
    post_pair_ty = mkCorePairTy later_ty env2_ty
    env_ty       = mkBigCoreVarTupTy env_ids
    out_ty       = mkBigCoreVarTupTy out_ids
-- Catch-all for statement forms not handled by the equations above.
-- Report the offending statement (the fourth argument).  The fifth
-- argument is the environment variable list, which is tied lazily through
-- 'trimInput' and so must not be forced while building the message.
dsCmdStmt _ _ _ stmt _ = pprPanic "dsCmdStmt" (ppr stmt)
-- | Desugar the body of a rec statement into a @loop@.
--
-- The loop body takes a pair of (outside environment, recursive
-- variables), squashes it into the environment the statements need, runs
-- the statements, and finally splits their output into (later values,
-- recursive values).  Returns the loop, the set of outside variables it
-- needs, and the same variables as a list.
dsRecCmd
        :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this statement
        -> [CmdLStmt Id]        -- list of statements inside the RecCmd
        -> [Id]                 -- list of vars defined here and used later
        -> [HsExpr Id]          -- expressions corresponding to later_ids
        -> [Id]                 -- list of vars fed back through the loop
        -> [HsExpr Id]          -- expressions corresponding to rec_ids
        -> DsM (CoreExpr,       -- desugared statement
                IdSet,          -- subset of local vars that occur free
                [Id])           -- same local vars as a list
dsRecCmd ids local_vars stmts later_ids later_rets rec_ids rec_rets = do
    -- Function splitting the statements' output into the result pair.
    core_later_rets <- mapM dsExpr later_rets
    core_rec_rets   <- mapM dsExpr rec_rets
    let all_rets = core_later_rets ++ core_rec_rets
        out_ids  = varSetElems (unionVarSets (map exprFreeIds all_rets))
        out_ty   = mkBigCoreVarTupTy out_ids
        out_pair = mkCorePairExpr (mkBigCoreTup core_later_rets)
                                  (mkBigCoreTup core_rec_rets)
    mk_pair_fn <- matchEnv out_ids out_pair

    -- The statements themselves.
    (core_stmts, fv_stmts, env_ids) <- dsfixCmdStmts ids stmt_locals out_ids stmts

    -- Function squashing (outside env, rec vars) into the statements' input.
    rec_id <- newSysLocalDs rec_ty
    let env1_id_set = fv_stmts `minusVarSet` rec_id_set
        env1_ids    = varSetElems env1_id_set
        env1_ty     = mkBigCoreVarTupTy env1_ids
        in_pair_ty  = mkCorePairTy env1_ty rec_ty
        -- Recursive variables are selected out of the rec tuple variable
        -- with a tuple selector; the others are bound by the environment
        -- match.
        selectVar v
            | v `elemVarSet` rec_id_set
              = mkTupleSelector rec_ids v rec_id (Var rec_id)
            | otherwise = Var v
        core_body   = mkBigCoreTup (map selectVar env_ids)
    squash_pair_fn <- matchEnvStack env1_ids rec_id core_body

    -- Assemble: loop (premap squash (stmts >>> arr mk_pair)).
    let env_ty    = mkBigCoreVarTupTy env_ids
        split_out = do_arr ids out_ty out_pair_ty mk_pair_fn
        loop_body = do_premap ids in_pair_ty env_ty out_pair_ty
                        squash_pair_fn
                        (do_compose ids env_ty out_ty out_pair_ty
                            core_stmts
                            split_out)
        core_loop = do_loop ids env1_ty later_ty rec_ty loop_body
    return (core_loop, env1_id_set, env1_ids)
  where
    rec_id_set  = mkVarSet rec_ids
    stmt_locals = rec_id_set `unionVarSet` mkVarSet later_ids `unionVarSet` local_vars
    later_ty    = mkBigCoreVarTupTy later_ids
    rec_ty      = mkBigCoreVarTupTy rec_ids
    out_pair_ty = mkCorePairTy later_ty rec_ty
-- | Desugar a list of statements, feeding the list of variables they need
-- back in as their own input environment (see 'trimInput').
dsfixCmdStmts
        :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this statement
        -> [Id]                 -- output vars of these statements
        -> [CmdLStmt Id]        -- statements to desugar
        -> DsM (CoreExpr,       -- desugared expression
                IdSet,          -- subset of local vars that occur free
                [Id])           -- same local vars as a list
dsfixCmdStmts ids local_vars out_ids stmts = trimInput build
  where
    build = dsCmdStmts ids local_vars out_ids stmts
-- | Desugar a non-empty list of statements into a single arrow from the
-- tuple of @env_ids@ to the tuple of @out_ids@.  Panics on an empty list.
dsCmdStmts
        :: DsCmdEnv             -- arrow combinators
        -> IdSet                -- set of local vars available to this statement
        -> [Id]                 -- output vars of these statements
        -> [CmdLStmt Id]        -- statements to desugar
        -> [Id]                 -- list of vars in the input to these statements
        -> DsM (CoreExpr,       -- desugared expression
                IdSet)          -- subset of local vars that occur free

-- A single statement produces the outputs directly.
dsCmdStmts ids local_vars out_ids [stmt] env_ids =
    dsCmdLStmt ids local_vars out_ids stmt env_ids

-- Otherwise desugar the tail first, so that the environment it needs is
-- known, and compose the head statement in front of it.
dsCmdStmts ids local_vars out_ids (stmt:stmts) env_ids = do
    (core_rest, _fv_rest, rest_ids) <- dsfixCmdStmts ids rest_locals out_ids stmts
    (core_stmt, fv_stmt) <- dsCmdLStmt ids local_vars rest_ids stmt env_ids
    let in_ty  = mkBigCoreVarTupTy env_ids
        mid_ty = mkBigCoreVarTupTy rest_ids
        out_ty = mkBigCoreVarTupTy out_ids
    return ( do_compose ids in_ty mid_ty out_ty core_stmt core_rest
           , fv_stmt )
  where
    rest_locals = mkVarSet (collectLStmtBinders stmt) `unionVarSet` local_vars

dsCmdStmts _ _ _ [] _ = panic "dsCmdStmts []"
-- | Match a list of expressions against a list of patterns, left to
-- right, with the first pattern as the outermost match.  Panics if the
-- two lists differ in length.
matchSimplys :: [CoreExpr]              -- scrutinees
             -> HsMatchContext Name     -- match kind
             -> [LPat Id]               -- patterns they should match
             -> CoreExpr                -- result if all of them match
             -> CoreExpr                -- result if any of them fails
             -> DsM CoreExpr
matchSimplys scruts ctxt pats result_expr fail_expr = case (scruts, pats) of
    ([], []) -> return result_expr
    (scrut : more_scruts, pat : more_pats) -> do
        inner <- matchSimplys more_scruts ctxt more_pats result_expr fail_expr
        matchSimply scrut ctxt pat inner fail_expr
    _ -> panic "matchSimplys"
-- | List the bodies of a match's guarded right-hand sides, each paired
-- with the variables in scope there: the binders of the match's patterns
-- and local bindings, plus the binders of that alternative's guard
-- statements.
leavesMatch :: LMatch Id (Located (body Id)) -> [(Located (body Id), IdSet)]
leavesMatch (L _ (Match _ pats _ (GRHSs grhss binds)))
  = [ (leaf, mkVarSet (collectLStmtsBinders guards) `unionVarSet` match_vars)
    | L _ (GRHS guards leaf) <- grhss ]
  where
    match_vars = mkVarSet (collectPatsBinders pats)
                   `unionVarSet`
                 mkVarSet (collectLocalBinders binds)
-- | Replace the bodies of a match's guarded right-hand sides, in order,
-- with leaves drawn from the front of the supplied list.  Returns the
-- unused leaves together with the rebuilt match.  The type argument is
-- ignored.
replaceLeavesMatch
        :: Type                                 -- new result type (unused)
        -> [Located (body' Id)]                 -- replacement leaf expressions
        -> LMatch Id (Located (body Id))        -- the match to update
        -> ([Located (body' Id)],               -- remaining leaf expressions
            LMatch Id (Located (body' Id)))     -- updated match
replaceLeavesMatch _res_ty leaves (L loc (Match mf pat mt (GRHSs grhss binds)))
  = (remaining, L loc (Match mf pat mt (GRHSs new_grhss binds)))
  where
    (remaining, new_grhss) = mapAccumL replaceLeavesGRHS leaves grhss
-- | Replace the body of one guarded right-hand side with the first leaf
-- of the list, returning the rest of the list.  Panics if no leaves are
-- left.
replaceLeavesGRHS
        :: [Located (body' Id)]                 -- replacement leaf expressions
        -> LGRHS Id (Located (body Id))         -- rhs of a case command
        -> ([Located (body' Id)],               -- remaining leaf expressions
            LGRHS Id (Located (body' Id)))      -- updated GRHS
replaceLeavesGRHS [] _ = panic "replaceLeavesGRHS []"
replaceLeavesGRHS (new_leaf : remaining) (L loc (GRHS guards _))
  = (remaining, L loc (GRHS guards new_leaf))
-- | Balanced fold of a non-empty list: adjacent elements are combined
-- pairwise, pass after pass, until a single value is left (an odd element
-- at the end of a pass is carried over unchanged).  Calls 'error' on an
-- empty list.
foldb :: (a -> a -> a) -> [a] -> a
foldb combine = go
  where
    go []  = error "foldb of empty list"
    go [x] = x
    go xs  = go (pairUp xs)

    -- One pass: combine neighbours two at a time.
    pairUp (a : b : rest) = combine a b : pairUp rest
    pairUp short          = short
-- | The variables bound by a pattern, as collected by 'collectl'
-- (which, unlike a plain binder collection, also includes evidence
-- binders found under 'ConPatOut').
collectPatBinders :: LPat Id -> [Id]
collectPatBinders = (`collectl` [])
-- | The variables bound by a list of patterns, collected with 'collectl'
-- from right to left into a single list.
collectPatsBinders :: [LPat Id] -> [Id]
collectPatsBinders = foldr collectl []
-- | Prepend the variables bound by a pattern to an accumulator list.
--
-- For 'ConPatOut' the evidence binders of the pattern are included as
-- well.  Splice and quasi-quote patterns cause a panic.
collectl :: LPat Id -> [Id] -> [Id]
collectl (L _ pat) bndrs = case pat of
    VarPat var            -> var : bndrs
    WildPat _             -> bndrs
    LazyPat p             -> collectl p bndrs
    BangPat p             -> collectl p bndrs
    AsPat (L _ a) p       -> a : collectl p bndrs
    ParPat p              -> collectl p bndrs

    ListPat ps _ _        -> foldr collectl bndrs ps
    PArrPat ps _          -> foldr collectl bndrs ps
    TuplePat ps _ _       -> foldr collectl bndrs ps

    ConPatIn _ args       -> foldr collectl bndrs (hsConPatArgs args)
    ConPatOut {pat_args=args, pat_binds=ev_binds} ->
        collectEvBinders ev_binds
          ++ foldr collectl bndrs (hsConPatArgs args)

    LitPat _              -> bndrs
    NPat _ _ _            -> bndrs
    NPlusKPat (L _ n) _ _ _ -> n : bndrs

    SigPatIn p _          -> collectl p bndrs
    SigPatOut p _         -> collectl p bndrs
    CoPat _ p _           -> collectl (noLoc p) bndrs
    ViewPat _ p _         -> collectl p bndrs

    SplicePat {}          -> pprPanic "collectl/go" (ppr pat)
    QuasiQuotePat {}      -> pprPanic "collectl/go" (ppr pat)
-- | The Id binders of a set of evidence bindings.  Only the 'EvBinds'
-- form is handled; 'TcEvBinds' causes a panic.
collectEvBinders :: TcEvBinds -> [Id]
collectEvBinders ev_binds = case ev_binds of
    EvBinds bs   -> foldrBag add_ev_bndr [] bs
    TcEvBinds {} -> panic "ToDo: collectEvBinders"
-- | Prepend the binder of an evidence binding to the accumulator, but
-- only when that binder satisfies 'isId'.
add_ev_bndr :: EvBind -> [Id] -> [Id]
add_ev_bndr (EvBind bndr _) acc =
    if isId bndr then bndr : acc else acc
-- | The binders of a list of located statements, concatenated in order.
collectLStmtsBinders :: [LStmt Id body] -> [Id]
collectLStmtsBinders stmts = concatMap collectLStmtBinders stmts
-- | The binders of a located statement (the location is discarded).
collectLStmtBinders :: LStmt Id body -> [Id]
collectLStmtBinders located_stmt = collectStmtBinders (unLoc located_stmt)
-- | The variables bound by a statement.  Pattern binders come from this
-- module's 'collectPatBinders'; body and last statements bind nothing;
-- a rec statement contributes its later ids.
collectStmtBinders :: Stmt Id body -> [Id]
collectStmtBinders stmt = case stmt of
    BindStmt pat _ _ _ -> collectPatBinders pat
    LetStmt binds      -> collectLocalBinders binds
    BodyStmt {}        -> []
    LastStmt {}        -> []
    ParStmt xs _ _     -> collectLStmtsBinders
                            $ [ s | ParStmtBlock ss _ _ <- xs, s <- ss]
    TransStmt { trS_stmts = stmts }        -> collectLStmtsBinders stmts
    RecStmt { recS_later_ids = later_ids } -> later_ids