module RnExpr (
rnLExpr, rnExpr, rnStmts
) where
#include "HsVersions.h"
import RnBinds ( rnLocalBindsAndThen, rnLocalValBindsLHS, rnLocalValBindsRHS,
rnMatchGroup, rnGRHS, makeMiniFixityEnv)
import HsSyn
import TcRnMonad
import Module ( getModule )
import RnEnv
import RnSplice ( rnBracket, rnSpliceExpr, checkThLocalName )
import RnTypes
import RnPat
import DynFlags
import PrelNames
import BasicTypes
import Name
import NameSet
import RdrName
import UniqSet
import Data.List
import Util
import ListSetOps ( removeDups )
import ErrUtils
import Outputable
import SrcLoc
import FastString
import Control.Monad
import TysWiredIn ( nilDataConName )
import qualified GHC.LanguageExtensions as LangExt
import Data.Ord
import Data.Array
-- | Rename a list of located expressions from left to right, returning the
-- renamed expressions together with the union of their free variables.
rnExprs :: [LHsExpr RdrName] -> RnM ([LHsExpr Name], FreeVars)
rnExprs = go emptyUniqSet
  where
    -- The running free-variable set is forced with 'seq' before every
    -- recursive step so that it is not carried along as a chain of
    -- unevaluated unions.
    go fvs_so_far [] = return ([], fvs_so_far)
    go fvs_so_far (e : es) = do
      (e', fvs_e) <- rnLExpr e
      let fvs_next = fvs_so_far `plusFV` fvs_e
      (es', fvs_all) <- fvs_next `seq` go fvs_next es
      return (e' : es', fvs_all)
-- | Rename a located expression by running 'rnExpr' underneath the location
-- wrapper (via 'wrapLocFstM').
rnLExpr :: LHsExpr RdrName -> RnM (LHsExpr Name, FreeVars)
rnLExpr located_expr = wrapLocFstM rnExpr located_expr
-- | Rename a single (unlocated) expression, yielding the renamed expression
-- paired with its free variables.  The defining equations, one per 'HsExpr'
-- constructor, follow the two helper functions below.
rnExpr :: HsExpr RdrName -> RnM (HsExpr Name, FreeVars)
-- | Build the final 'HsVar' for an already-resolved name.  Names for which
-- 'nameIsLocalOrFrom' holds with respect to the current module are first
-- passed to 'checkThLocalName'.  The free-variable set is the name itself.
finishHsVar :: Located Name -> RnM (HsExpr Name, FreeVars)
finishHsVar (L loc nm) = do
  cur_mod <- getModule
  when (nameIsLocalOrFrom cur_mod nm) (checkThLocalName nm)
  return (HsVar (L loc nm), unitFV nm)
-- | Handle a variable occurrence that the lookup could not resolve.
--
-- * Unqualified names become an 'HsUnboundVar': a 'TrueExprHole' when the
--   occurrence name starts with an underscore, otherwise an 'OutOfScope'
--   carrying the current global reader environment.
-- * Qualified names are reported straight away through 'reportUnboundName'.
rnUnboundVar :: RdrName -> RnM (HsExpr Name, FreeVars)
rnUnboundVar rdr
  | isUnqual rdr = do
      let occ_name = rdrNameOcc rdr
      unbound <-
        if startsWithUnderscore occ_name
          then return (TrueExprHole occ_name)
          else OutOfScope occ_name <$> getGlobalRdrEnv
      return (HsUnboundVar unbound, emptyFVs)
  | otherwise = do
      resolved <- reportUnboundName rdr
      return (HsVar (noLoc resolved), emptyFVs)
-- Variable occurrence.  The name is looked up with
-- lookupOccRn_overloaded (which is told whether DuplicateRecordFields is
-- enabled) and the result decides which expression is produced.
rnExpr (HsVar (L l v))
= do { opt_DuplicateRecordFields <- xoptM LangExt.DuplicateRecordFields
; mb_name <- lookupOccRn_overloaded opt_DuplicateRecordFields v
; case mb_name of {
-- Lookup found nothing: hand over to rnUnboundVar.
Nothing -> rnUnboundVar v ;
-- An ordinary name.
Just (Left name)
-- The nil data constructor is re-renamed as an empty ExplicitList, so it
-- takes the same path (including the OverloadedLists check) as a
-- literal list.
| name == nilDataConName
-> rnExpr (ExplicitList placeHolderType Nothing [])
| otherwise
-> finishHsVar (L l name) ;
-- Exactly one candidate record field: build an HsRecFld from it, re-using
-- the source location l of the occurrence; the selector is the free
-- variable.
Just (Right [f@(FieldOcc (L _ fn) s)]) ->
return (HsRecFld (ambiguousFieldOcc (FieldOcc (L l fn) s))
, unitFV (selectorFieldOcc f)) ;
-- Two or more candidate fields: record the occurrence as Ambiguous and
-- count every candidate selector as a free variable.
Just (Right fs@(_:_:_)) -> return (HsRecFld (Ambiguous (L l v)
PlaceHolder)
, mkFVs (map selectorFieldOcc fs));
-- An empty candidate list is treated as impossible.
Just (Right []) -> error "runExpr/HsVar" } }
-- Implicit parameters and overloaded labels carry no names to resolve.
rnExpr (HsIPVar ip)
  = return (HsIPVar ip, emptyFVs)
rnExpr (HsOverLabel lbl)
  = return (HsOverLabel lbl, emptyFVs)
-- A string literal is re-renamed as an overloaded literal when
-- OverloadedStrings is on; otherwise it is checked with 'rnLit' like any
-- other literal.
rnExpr (HsLit str_lit@(HsString src_text str)) = do
  overloaded <- xoptM LangExt.OverloadedStrings
  if overloaded
    then rnExpr (HsOverLit (mkHsIsString src_text str placeHolderType))
    else do
      rnLit str_lit
      return (HsLit str_lit, emptyFVs)
rnExpr (HsLit other_lit) = do
  rnLit other_lit
  return (HsLit other_lit, emptyFVs)
rnExpr (HsOverLit ov_lit) = do
  (ov_lit', fvs_lit) <- rnOverLit ov_lit
  return (HsOverLit ov_lit', fvs_lit)
-- Function application and visible type application.
rnExpr (HsApp f x) = do
  (f', fvs_f) <- rnLExpr f
  (x', fvs_x) <- rnLExpr x
  return (HsApp f' x', fvs_f `plusFV` fvs_x)
rnExpr (HsAppType f ty) = do
  (f', fvs_f) <- rnLExpr f
  (ty', fvs_ty) <- rnHsWcType HsTypeCtx ty
  return (HsAppType f' ty', fvs_f `plusFV` fvs_ty)
-- Infix application: both operands are renamed before the operator, then the
-- operator's fixity is looked up and 'mkOpAppRn' builds the final tree.
rnExpr (OpApp lhs op _ rhs) = do
  (lhs', fvs_lhs) <- rnLExpr lhs
  (rhs', fvs_rhs) <- rnLExpr rhs
  (op', fvs_op) <- rnLExpr op
  op_fixity <- case op' of
    L _ (HsVar (L _ op_name)) -> lookupFixityRn op_name
    L _ (HsRecFld fld) -> lookupFieldFixityRn fld
    -- Anything else gets a fallback fixity: minPrecedence, InfixL.
    _ -> return (Fixity (show minPrecedence) minPrecedence InfixL)
  rearranged <- mkOpAppRn lhs' op' op_fixity rhs'
  return (rearranged, fvs_lhs `plusFV` fvs_op `plusFV` fvs_rhs)
-- Prefix negation: the negate operator is found through 'lookupSyntaxName'.
rnExpr (NegApp arg _) = do
  (arg', fvs_arg) <- rnLExpr arg
  (negate_op, fvs_negate) <- lookupSyntaxName negateName
  negated <- mkNegAppRn arg' negate_op
  return (negated, fvs_arg `plusFV` fvs_negate)
-- Template Haskell brackets and splices are delegated to RnSplice.
rnExpr bracket_expr@(HsBracket body) = rnBracket bracket_expr body
rnExpr (HsSpliceE spl) = rnSpliceExpr spl
-- A section directly inside parentheses is renamed with 'rnSection'.
rnExpr (HsPar (L loc sec@(SectionL {}))) = do
  (sec', fvs_sec) <- rnSection sec
  return (HsPar (L loc sec'), fvs_sec)
rnExpr (HsPar (L loc sec@(SectionR {}))) = do
  (sec', fvs_sec) <- rnSection sec
  return (HsPar (L loc sec'), fvs_sec)
rnExpr (HsPar inner) = do
  (inner', fvs_inner) <- rnLExpr inner
  return (HsPar inner', fvs_inner)
-- A section reaching this point was not matched by the HsPar equations
-- above: report 'sectionErr', then rename it anyway.
rnExpr sec@(SectionL {}) = do
  addErr (sectionErr sec)
  rnSection sec
rnExpr sec@(SectionR {}) = do
  addErr (sectionErr sec)
  rnSection sec
-- Pragma wrappers: only the wrapped expression is renamed.
rnExpr (HsCoreAnn src ann body) = do
  (body', fvs_body) <- rnLExpr body
  return (HsCoreAnn src ann body', fvs_body)
rnExpr (HsSCC src lbl body) = do
  (body', fvs_body) <- rnLExpr body
  return (HsSCC src lbl body', fvs_body)
rnExpr (HsTickPragma src info srcInfo body) = do
  (body', fvs_body) <- rnLExpr body
  return (HsTickPragma src info srcInfo body', fvs_body)
-- Lambda, lambda-case and case all go through 'rnMatchGroup'.
rnExpr (HsLam mg) = do
  (mg', fvs_mg) <- rnMatchGroup LambdaExpr rnLExpr mg
  return (HsLam mg', fvs_mg)
rnExpr (HsLamCase _ mg) = do
  (mg', fvs_mg) <- rnMatchGroup CaseAlt rnLExpr mg
  return (HsLamCase placeHolderType mg', fvs_mg)
rnExpr (HsCase scrut mg) = do
  (scrut', fvs_scrut) <- rnLExpr scrut
  (mg', fvs_mg) <- rnMatchGroup CaseAlt rnLExpr mg
  return (HsCase scrut' mg', fvs_scrut `plusFV` fvs_mg)
-- let: the body is renamed inside the continuation of
-- 'rnLocalBindsAndThen'.
rnExpr (HsLet (L loc binds) body) =
  rnLocalBindsAndThen binds $ \binds' _ -> do
    (body', fvs_body) <- rnLExpr body
    return (HsLet (L loc binds') body', fvs_body)
-- do-like blocks: renamed with the ApplicativeDo post-processing hook; the
-- "thing inside" is trivial because nothing follows the statements.
rnExpr (HsDo ctxt (L loc stmts) _) = do
  ((stmts', _), fvs_stmts) <-
    rnStmtsWithPostProcessing ctxt rnLExpr
      postProcessStmtsForApplicativeDo stmts
      (const (return ((), emptyFVs)))
  return (HsDo ctxt (L loc stmts') placeHolderType, fvs_stmts)
-- Explicit list: under OverloadedLists the 'fromListNName' syntax name is
-- looked up and attached to the renamed list.
rnExpr (ExplicitList _ _ elems) = do
  overloaded_lists <- xoptM LangExt.OverloadedLists
  (elems', fvs_elems) <- rnExprs elems
  if overloaded_lists
    then do
      (from_list_n, fvs_from) <- lookupSyntaxName fromListNName
      return ( ExplicitList placeHolderType (Just from_list_n) elems'
             , fvs_elems `plusFV` fvs_from )
    else return (ExplicitList placeHolderType Nothing elems', fvs_elems)
rnExpr (ExplicitPArr _ elems) = do
  (elems', fvs_elems) <- rnExprs elems
  return (ExplicitPArr placeHolderType elems', fvs_elems)
-- Explicit tuple: 'checkTupleSection' and 'checkTupSize' run first, then
-- each argument is renamed.
rnExpr (ExplicitTuple tup_args boxity) = do
  checkTupleSection tup_args
  checkTupSize (length tup_args)
  (tup_args', fvss) <- mapAndUnzipM rn_arg tup_args
  return (ExplicitTuple tup_args' boxity, plusFVs fvss)
  where
    rn_arg (L loc (Present e)) = do
      (e', fvs_e) <- rnLExpr e
      return (L loc (Present e'), fvs_e)
    -- A missing argument has nothing to rename.
    rn_arg (L loc (Missing _)) =
      return (L loc (Missing placeHolderType), emptyFVs)
-- Record construction: resolve the constructor, rename the field bindings
-- with 'rnHsRecFields', then rename every field argument.  The constructor
-- name is added to the free variables.
rnExpr (RecordCon { rcon_con_name = con_id
                  , rcon_flds = rec_binds@(HsRecFields { rec_dotdot = dd }) }) = do
  located_con@(L _ con) <- lookupLocatedOccRn con_id
  (flds, fvs_flds) <- rnHsRecFields (HsRecFieldCon con) mk_var rec_binds
  (flds', fvs_args) <- mapAndUnzipM rn_arg flds
  let renamed_flds = HsRecFields { rec_flds = flds', rec_dotdot = dd }
  return ( RecordCon { rcon_con_name = located_con
                     , rcon_flds = renamed_flds
                     , rcon_con_expr = noPostTcExpr
                     , rcon_con_like = PlaceHolder }
         , fvs_flds `plusFV` plusFVs fvs_args `addOneFV` con )
  where
    -- Passed to 'rnHsRecFields' to build a variable expression from a
    -- location and a name.
    mk_var loc name = HsVar (L loc name)
    rn_arg (L loc fld) = do
      (arg', fvs_arg) <- rnLExpr (hsRecFieldArg fld)
      return (L loc (fld { hsRecFieldArg = arg' }), fvs_arg)
-- Record update: rename the scrutinised expression and the update fields;
-- all remaining slots are placeholders.
rnExpr (RecordUpd { rupd_expr = scrut, rupd_flds = upd_flds }) = do
  (scrut', fvs_scrut) <- rnLExpr scrut
  (upd_flds', fvs_upd) <- rnHsRecUpdFields upd_flds
  return ( RecordUpd { rupd_expr = scrut'
                     , rupd_flds = upd_flds'
                     , rupd_cons = PlaceHolder
                     , rupd_in_tys = PlaceHolder
                     , rupd_out_tys = PlaceHolder
                     , rupd_wrap = PlaceHolder }
         , fvs_scrut `plusFV` fvs_upd )
-- Expression with a type signature: the signature is renamed first, and the
-- expression is renamed with the signature's scoped type variables bound.
rnExpr (ExprWithTySig body sig_ty) = do
  (sig_ty', fvs_ty) <- rnHsSigWcType ExprWithTySigCtx sig_ty
  (body', fvs_body) <-
    bindSigTyVarsFV (hsWcScopedTvs sig_ty') (rnLExpr body)
  return (ExprWithTySig body' sig_ty', fvs_body `plusFV` fvs_ty)
-- if-then-else: 'lookupIfThenElse' supplies the optional if-then-else
-- syntax expression stored in the result.
rnExpr (HsIf _ cond then_e else_e) = do
  (cond', fvs_cond) <- rnLExpr cond
  (then_e', fvs_then) <- rnLExpr then_e
  (else_e', fvs_else) <- rnLExpr else_e
  (mb_ite, fvs_ite) <- lookupIfThenElse
  return ( HsIf mb_ite cond' then_e' else_e'
         , plusFVs [fvs_ite, fvs_cond, fvs_then, fvs_else] )
rnExpr (HsMultiIf _ alts) = do
  (alts', fvs_alts) <- mapFvRn (rnGRHS IfAlt rnLExpr) alts
  return (HsMultiIf placeHolderType alts', fvs_alts)
-- Arithmetic sequence: under OverloadedLists the 'fromListName' syntax name
-- is looked up and attached.
rnExpr (ArithSeq _ _ seq_info) = do
  overloaded_lists <- xoptM LangExt.OverloadedLists
  (seq_info', fvs_seq) <- rnArithSeq seq_info
  if overloaded_lists
    then do
      (from_list, fvs_from) <- lookupSyntaxName fromListName
      return ( ArithSeq noPostTcExpr (Just from_list) seq_info'
             , fvs_seq `plusFV` fvs_from )
    else return (ArithSeq noPostTcExpr Nothing seq_info', fvs_seq)
rnExpr (PArrSeq _ seq_info) = do
  (seq_info', fvs_seq) <- rnArithSeq seq_info
  return (PArrSeq noPostTcExpr seq_info', fvs_seq)
-- Pattern-only syntax found in expression position.  A wildcard becomes an
-- expression hole; the other forms are reported through 'patSynErr'.
rnExpr EWildPat = return (hsHoleExpr, emptyFVs)
rnExpr as_pat@(EAsPat {}) =
  patSynErr as_pat (text "Did you mean to enable TypeApplications?")
rnExpr view_pat@(EViewPat {}) = patSynErr view_pat empty
rnExpr lazy_pat@(ELazyPat {}) = patSynErr lazy_pat empty
-- Static form.  Three things happen, in this order:
--   1. an error is reported when the compilation target is HscInterpreted;
--   2. the body is renamed;
--   3. depending on the Template Haskell stage, the body's free variables
--      are checked.
-- The renamed form is returned even when errors were added.
rnExpr e@(HsStatic expr) = do
target <- fmap hscTarget getDynFlags
case target of
HscInterpreted -> addErr $ sep
[ text "The static form is not supported in interpreted mode."
, text "Please use -fobject-code."
]
_ -> return ()
(expr',fvExpr) <- rnLExpr expr
stage <- getStage
case stage of
-- Inside a bracket no check on the free variables is made.
Brack _ _ -> return ()
-- Inside a splice the static form is rejected outright.
Splice _ -> addErr $ sep
[ text "static forms cannot be used in splices:"
, nest 2 $ ppr e
]
-- Any other stage: every free variable must be an external or wired-in
-- name; unbound names are also let through here.  Anything else is
-- listed in the error message.
_ -> do
let isTopLevelName n = isExternalName n || isWiredInName n
case nameSetElems $ filterNameSet
(\n -> not (isTopLevelName n || isUnboundName n))
fvExpr of
[] -> return ()
fvNonGlobal -> addErr $ cat
[ text $ "Only identifiers of top-level bindings can "
++ "appear in the body of the static form:"
, nest 2 $ ppr e
, text "but the following identifiers were found instead:"
, nest 2 $ vcat $ map ppr fvNonGlobal
]
return (HsStatic expr', fvExpr)
-- proc expression: a fresh arrow scope is opened, the pattern is renamed,
-- and the command body is renamed inside the pattern's continuation.
rnExpr (HsProc pat body) =
  newArrowScope $
    rnPat ProcExpr pat $ \pat' -> do
      (body', fvs_body) <- rnCmdTop body
      return (HsProc pat' body', fvs_body)
-- Arrow command syntax in expression position is an error.
rnExpr cmd_expr@(HsArrApp {}) = arrowFail cmd_expr
rnExpr cmd_expr@(HsArrForm {}) = arrowFail cmd_expr
-- Every constructor not handled by an earlier equation ends up here.
rnExpr unexpected = pprPanic "rnExpr: unexpected expression" (ppr unexpected)
-- | The expression hole named @_@, used as a stand-in result (for example
-- by 'arrowFail' and for a wildcard in expression position).
hsHoleExpr :: HsExpr id
hsHoleExpr = HsUnboundVar (TrueExprHole underscore_occ)
  where
    underscore_occ = mkVarOcc "_"
-- | Report an arrow command that appeared where an expression was expected,
-- and return an expression hole with no free variables in its place.
arrowFail :: HsExpr RdrName -> RnM (HsExpr Name, FreeVars)
arrowFail cmd = do
  addErr msg
  return (hsHoleExpr, emptyFVs)
  where
    msg = vcat [ text "Arrow command found where an expression was expected:"
               , nest 2 (ppr cmd) ]
-- | Rename an operator section.  The sub-expressions are renamed in the
-- order they appear in the constructor, then 'checkSectionPrec' is run on
-- the renamed operator and operand.  Any non-section argument is a panic.
rnSection :: HsExpr RdrName -> RnM (HsExpr Name, FreeVars)
rnSection sec@(SectionR op arg) = do
  (op', fvs_op) <- rnLExpr op
  (arg', fvs_arg) <- rnLExpr arg
  checkSectionPrec InfixR sec op' arg'
  return (SectionR op' arg', fvs_op `plusFV` fvs_arg)
rnSection sec@(SectionL arg op) = do
  (arg', fvs_arg) <- rnLExpr arg
  (op', fvs_op) <- rnLExpr op
  checkSectionPrec InfixL sec op' arg'
  return (SectionL arg' op', fvs_op `plusFV` fvs_arg)
rnSection not_a_section = pprPanic "rnSection" (ppr not_a_section)
-- | Rename a list of top-level commands from left to right, combining their
-- free variables.
rnCmdArgs :: [LHsCmdTop RdrName] -> RnM ([LHsCmdTop Name], FreeVars)
rnCmdArgs [] = return ([], emptyFVs)
rnCmdArgs (c : cs) = do
  (c', fvs_c) <- rnCmdTop c
  (cs', fvs_cs) <- rnCmdArgs cs
  return (c' : cs', fvs_c `plusFV` fvs_cs)
-- | Rename a top-level arrow command.  After the command itself is renamed,
-- the arrow operations it needs ('arrAName', 'composeAName', 'firstAName'
-- plus whatever 'methodNamesCmd' reports) are resolved with
-- 'lookupSyntaxNames' and stored, paired with their standard names, in the
-- resulting 'HsCmdTop'.
rnCmdTop :: LHsCmdTop RdrName -> RnM (LHsCmdTop Name, FreeVars)
rnCmdTop = wrapLocFstM rn_top
  where
    rn_top (HsCmdTop cmd _ _ _) = do
      (cmd', fvs_cmd) <- rnLCmd cmd
      let needed = [arrAName, composeAName, firstAName]
                     ++ nameSetElems (methodNamesCmd (unLoc cmd'))
      (needed', fvs_needed) <- lookupSyntaxNames needed
      return ( HsCmdTop cmd' placeHolderType placeHolderType
                        (zip needed needed')
             , fvs_cmd `plusFV` fvs_needed )
-- | Rename a located command by running 'rnCmd' underneath the location
-- wrapper (via 'wrapLocFstM').
rnLCmd :: LHsCmd RdrName -> RnM (LHsCmd Name, FreeVars)
rnLCmd located_cmd = wrapLocFstM rnCmd located_cmd
-- | Rename a single arrow command, returning it with its free variables.
rnCmd :: HsCmd RdrName -> RnM (HsCmd Name, FreeVars)
-- Arrow application: for a first-order application the arrow expression is
-- renamed under 'escapeArrowScope'; for a higher-order one it is renamed in
-- the current scope.
rnCmd (HsCmdArrApp arrow arg _ ho rtl) = do
  (arrow', fvs_arrow) <- in_arrow_scope (rnLExpr arrow)
  (arg', fvs_arg) <- rnLExpr arg
  return ( HsCmdArrApp arrow' arg' placeHolderType ho rtl
         , fvs_arrow `plusFV` fvs_arg )
  where
    in_arrow_scope action = case ho of
      HsHigherOrderApp -> action
      HsFirstOrderApp -> escapeArrowScope action
-- Infix command form with exactly two arguments: the operator's fixity is
-- looked up and 'mkOpFormRn' builds the final command.
rnCmd (HsCmdArrForm op (Just _) [left, right]) = do
  (op', fvs_op) <- escapeArrowScope (rnLExpr op)
  -- Lazy pattern binding: the renamed operator is taken to be a variable.
  let L _ (HsVar (L _ op_name)) = op'
  (left', fvs_left) <- rnCmdTop left
  (right', fvs_right) <- rnCmdTop right
  op_fixity <- lookupFixityRn op_name
  rearranged <- mkOpFormRn left' op' op_fixity right'
  return (rearranged, fvs_left `plusFV` fvs_op `plusFV` fvs_right)
-- Any other command form: rename the operator and all argument commands.
rnCmd (HsCmdArrForm op fixity cmds) = do
  (op', fvs_op) <- escapeArrowScope (rnLExpr op)
  (cmds', fvs_cmds) <- rnCmdArgs cmds
  return (HsCmdArrForm op' fixity cmds', fvs_op `plusFV` fvs_cmds)
-- Command applied to an expression argument.
rnCmd (HsCmdApp f x) = do
  (f', fvs_f) <- rnLCmd f
  (x', fvs_x) <- rnLExpr x
  return (HsCmdApp f' x', fvs_f `plusFV` fvs_x)
rnCmd (HsCmdLam mg) = do
  (mg', fvs_mg) <- rnMatchGroup LambdaExpr rnLCmd mg
  return (HsCmdLam mg', fvs_mg)
rnCmd (HsCmdPar inner) = do
  (inner', fvs_inner) <- rnLCmd inner
  return (HsCmdPar inner', fvs_inner)
rnCmd (HsCmdCase scrut mg) = do
  (scrut', fvs_scrut) <- rnLExpr scrut
  (mg', fvs_mg) <- rnMatchGroup CaseAlt rnLCmd mg
  return (HsCmdCase scrut' mg', fvs_scrut `plusFV` fvs_mg)
-- Conditional command: the condition is an expression, the branches are
-- commands.
rnCmd (HsCmdIf _ cond then_c else_c) = do
  (cond', fvs_cond) <- rnLExpr cond
  (then_c', fvs_then) <- rnLCmd then_c
  (else_c', fvs_else) <- rnLCmd else_c
  (mb_ite, fvs_ite) <- lookupIfThenElse
  return ( HsCmdIf mb_ite cond' then_c' else_c'
         , plusFVs [fvs_ite, fvs_cond, fvs_then, fvs_else] )
rnCmd (HsCmdLet (L loc binds) body) =
  rnLocalBindsAndThen binds $ \binds' _ -> do
    (body', fvs_body) <- rnLCmd body
    return (HsCmdLet (L loc binds') body', fvs_body)
-- Command do-block: statements are renamed in the 'ArrowExpr' context with
-- a trivial "thing inside".
rnCmd (HsCmdDo (L loc stmts) _) = do
  ((stmts', _), fvs_stmts) <-
    rnStmts ArrowExpr rnLCmd stmts (const (return ((), emptyFVs)))
  return (HsCmdDo (L loc stmts') placeHolderType, fvs_stmts)
-- Wrapped commands are a panic here.
rnCmd wrapped@(HsCmdWrap {}) = pprPanic "rnCmd" (ppr wrapped)
-- | The set of names a command needs.  The methodNames* functions below
-- fill it with arrow-operation names such as appAName, choiceAName and
-- loopAName.
type CmdNeeds = FreeVars
-- | 'methodNamesCmd' for a located command.
methodNamesLCmd :: LHsCmd Name -> CmdNeeds
methodNamesLCmd located_cmd = methodNamesCmd (unLoc located_cmd)
-- | Collect the arrow-operation names a renamed command needs beyond the
-- ones 'rnCmdTop' always adds: 'appAName' for a higher-order application,
-- 'choiceAName' for conditionals and case commands, and whatever the
-- sub-commands need.
methodNamesCmd :: HsCmd Name -> CmdNeeds
methodNamesCmd cmd = case cmd of
  HsCmdArrApp _ _ _ HsFirstOrderApp _ -> emptyFVs
  HsCmdArrApp _ _ _ HsHigherOrderApp _ -> unitFV appAName
  HsCmdArrForm {} -> emptyFVs
  HsCmdWrap _ inner -> methodNamesCmd inner
  HsCmdPar inner -> methodNamesLCmd inner
  HsCmdIf _ _ then_c else_c ->
    methodNamesLCmd then_c `plusFV` methodNamesLCmd else_c `addOneFV` choiceAName
  HsCmdLet _ body -> methodNamesLCmd body
  HsCmdDo (L _ stmts) _ -> methodNamesStmts stmts
  HsCmdApp fun _ -> methodNamesLCmd fun
  HsCmdLam mg -> methodNamesMatch mg
  HsCmdCase _ mg -> methodNamesMatch mg `addOneFV` choiceAName
-- | Names needed by the right-hand sides of every alternative in a match
-- group of commands.
methodNamesMatch :: MatchGroup Name (LHsCmd Name) -> FreeVars
methodNamesMatch (MG { mg_alts = L _ alts })
  = plusFVs [ methodNamesGRHSs grhss | L _ (Match _ _ _ grhss) <- alts ]
-- | Names needed by all guarded right-hand sides of one alternative; the
-- local bindings component is ignored.
methodNamesGRHSs :: GRHSs Name (LHsCmd Name) -> FreeVars
methodNamesGRHSs (GRHSs guarded _) = plusFVs [ methodNamesGRHS g | g <- guarded ]
-- | Names needed by the command on one guarded right-hand side; the guards
-- themselves are ignored.
methodNamesGRHS :: Located (GRHS Name (LHsCmd Name)) -> CmdNeeds
methodNamesGRHS (L _ (GRHS _ body)) = methodNamesLCmd body
-- | Union of the names needed by each statement of a command do-block.
methodNamesStmts :: [Located (StmtLR Name Name (LHsCmd Name))] -> FreeVars
methodNamesStmts stmts = plusFVs [ methodNamesLStmt s | s <- stmts ]
-- | 'methodNamesStmt' for a located statement.
methodNamesLStmt :: Located (StmtLR Name Name (LHsCmd Name)) -> FreeVars
methodNamesLStmt located_stmt = methodNamesStmt (unLoc located_stmt)
-- | Names needed by one statement of a command do-block.  Statements that
-- carry a command contribute that command's needs; a rec block also needs
-- 'loopAName'; all other statement forms contribute nothing.
methodNamesStmt :: StmtLR Name Name (LHsCmd Name) -> FreeVars
methodNamesStmt stmt = case stmt of
  LastStmt cmd _ _ -> methodNamesLCmd cmd
  BodyStmt cmd _ _ _ -> methodNamesLCmd cmd
  BindStmt _ cmd _ _ _ -> methodNamesLCmd cmd
  RecStmt { recS_stmts = rec_stmts } ->
    methodNamesStmts rec_stmts `addOneFV` loopAName
  LetStmt {} -> emptyFVs
  ParStmt {} -> emptyFVs
  TransStmt {} -> emptyFVs
  ApplicativeStmt {} -> emptyFVs
-- | Rename the component expressions of an arithmetic sequence from left to
-- right and combine their free variables.
rnArithSeq :: ArithSeqInfo RdrName -> RnM (ArithSeqInfo Name, FreeVars)
rnArithSeq (From start) = do
  (start', fvs_start) <- rnLExpr start
  return (From start', fvs_start)
rnArithSeq (FromThen start next) = do
  (start', fvs_start) <- rnLExpr start
  (next', fvs_next) <- rnLExpr next
  return (FromThen start' next', fvs_start `plusFV` fvs_next)
rnArithSeq (FromTo start end) = do
  (start', fvs_start) <- rnLExpr start
  (end', fvs_end) <- rnLExpr end
  return (FromTo start' end', fvs_start `plusFV` fvs_end)
rnArithSeq (FromThenTo start next end) = do
  (start', fvs_start) <- rnLExpr start
  (next', fvs_next) <- rnLExpr next
  (end', fvs_end) <- rnLExpr end
  return ( FromThenTo start' next' end'
         , plusFVs [fvs_start, fvs_next, fvs_end] )
-- | Rename a list of statements and then the "thing inside" that is in
-- their scope, with no post-processing of the renamed statements.
rnStmts :: Outputable (body RdrName)
        => HsStmtContext Name
        -> (Located (body RdrName) -> RnM (Located (body Name), FreeVars))
        -> [LStmt RdrName (Located (body RdrName))]
        -> ([Name] -> RnM (thing, FreeVars))
        -> RnM (([LStmt Name (Located (body Name))], thing), FreeVars)
rnStmts ctxt rnBody stmts thing_inside =
  rnStmtsWithPostProcessing ctxt rnBody noPostProcessStmts stmts thing_inside
-- | Like 'rnStmts', but the renamed statements (each still paired with its
-- free variables) are handed to a post-processing function before being
-- returned.  The free variables reported by that function are added to
-- those of the renaming itself.
rnStmtsWithPostProcessing
        :: Outputable (body RdrName)
        => HsStmtContext Name
        -> (Located (body RdrName) -> RnM (Located (body Name), FreeVars))
        -> (HsStmtContext Name
              -> [(LStmt Name (Located (body Name)), FreeVars)]
              -> RnM ([LStmt Name (Located (body Name))], FreeVars))
        -> [LStmt RdrName (Located (body RdrName))]
        -> ([Name] -> RnM (thing, FreeVars))
        -> RnM (([LStmt Name (Located (body Name))], thing), FreeVars)
rnStmtsWithPostProcessing ctxt rnBody ppStmts stmts thing_inside = do
  ((stmts_with_fvs, result), fvs_renamed) <-
    rnStmtsWithFreeVars ctxt rnBody stmts thing_inside
  (final_stmts, fvs_post) <- ppStmts ctxt stmts_with_fvs
  return ((final_stmts, result), fvs_renamed `plusFV` fvs_post)
-- | Post-processing hook for do-blocks: when ApplicativeDo is enabled and
-- the context is exactly 'DoExpr', rearrange the statements with
-- 'rearrangeForApplicativeDo'; in every other case just drop the attached
-- free-variable sets.
postProcessStmtsForApplicativeDo
  :: HsStmtContext Name
  -> [(ExprLStmt Name, FreeVars)]
  -> RnM ([ExprLStmt Name], FreeVars)
postProcessStmtsForApplicativeDo ctxt stmts = do
  ado_is_on <- xoptM LangExt.ApplicativeDo
  if ado_is_on && is_do_expr
    then rearrangeForApplicativeDo ctxt stmts
    else noPostProcessStmts ctxt stmts
  where
    is_do_expr = case ctxt of
      DoExpr -> True
      _ -> False
-- | The trivial post-processor: discard the free-variable set attached to
-- each statement and report no additional free variables.
noPostProcessStmts
  :: HsStmtContext Name
  -> [(LStmt Name (Located (body Name)), FreeVars)]
  -> RnM ([LStmt Name (Located (body Name))], FreeVars)
noPostProcessStmts _ stmts_with_fvs =
  return ([ stmt | (stmt, _) <- stmts_with_fvs ], emptyNameSet)
-- | Rename a list of statements, then run thing_inside with the names the
-- statements bind.  Every renamed statement is returned paired with a
-- free-variable set (the one that rnStmt attaches to it).
rnStmtsWithFreeVars :: Outputable (body RdrName)
=> HsStmtContext Name
-> (Located (body RdrName) -> RnM (Located (body Name), FreeVars))
-> [LStmt RdrName (Located (body RdrName))]
-> ([Name] -> RnM (thing, FreeVars))
-> RnM ( ([(LStmt Name (Located (body Name)), FreeVars)], thing)
, FreeVars)
-- No statements: checkEmptyStmts decides whether that is acceptable in this
-- context, then thing_inside runs with no binders.
rnStmtsWithFreeVars ctxt _ [] thing_inside
= do { checkEmptyStmts ctxt
; thing <- return ()
; (thing, fvs) <- thing_inside []
; return (([], thing), fvs) }
-- mdo: all statements but the last are wrapped in a single RecStmt (built
-- with mkRecStmt) and renamed as such; the last statement is checked and
-- renamed inside that RecStmt's continuation.
rnStmtsWithFreeVars MDoExpr rnBody stmts thing_inside
=
do { ((stmts1, (stmts2, thing)), fvs)
<- rnStmt MDoExpr rnBody (noLoc $ mkRecStmt all_but_last) $ \ _ ->
do { last_stmt' <- checkLastStmt MDoExpr last_stmt
; rnStmt MDoExpr rnBody last_stmt' thing_inside }
; return (((stmts1 ++ stmts2), thing), fvs) }
where
-- The empty list is handled by the equation above, so the Just pattern
-- is expected to match here.
Just (all_but_last, last_stmt) = snocView stmts
-- General case: rename the head statement; the tail is renamed inside its
-- continuation, so the head's binders scope over the rest.
rnStmtsWithFreeVars ctxt rnBody (lstmt@(L loc _) : lstmts) thing_inside
-- Final statement: validated with checkLastStmt, which may return a
-- different statement to rename.
| null lstmts
= setSrcSpan loc $
do { lstmt' <- checkLastStmt ctxt lstmt
; rnStmt ctxt rnBody lstmt' thing_inside }
-- Non-final statement: validated with checkStmt; thing_inside finally
-- receives the binders of this statement followed by those of the rest.
| otherwise
= do { ((stmts1, (stmts2, thing)), fvs)
<- setSrcSpan loc $
do { checkStmt ctxt lstmt
; rnStmt ctxt rnBody lstmt $ \ bndrs1 ->
rnStmtsWithFreeVars ctxt rnBody lstmts $ \ bndrs2 ->
thing_inside (bndrs1 ++ bndrs2) }
; return (((stmts1 ++ stmts2), thing), fvs) }
-- | Rename one statement and then run thing_inside with the names that the
-- statement binds.  The result pairs each produced statement with a
-- free-variable set of its own, alongside the overall free variables.
rnStmt :: Outputable (body RdrName)
=> HsStmtContext Name
-> (Located (body RdrName) -> RnM (Located (body Name), FreeVars))
-> LStmt RdrName (Located (body RdrName))
-> ([Name] -> RnM (thing, FreeVars))
-> RnM ( ([(LStmt Name (Located (body Name)), FreeVars)], thing)
, FreeVars)
-- Last statement: binds nothing; the return operator is looked up through
-- lookupStmtName.
rnStmt ctxt rnBody (L loc (LastStmt body noret _)) thing_inside
= do { (body', fv_expr) <- rnBody body
; (ret_op, fvs1) <- lookupStmtName ctxt returnMName
; (thing, fvs3) <- thing_inside []
; return (([(L loc (LastStmt body' noret ret_op), fv_expr)], thing),
fv_expr `plusFV` fvs1 `plusFV` fvs3) }
-- Body statement: binds nothing.  The guard operator is only looked up when
-- isListCompExpr holds for the context; otherwise noSyntaxExpr is used.
rnStmt ctxt rnBody (L loc (BodyStmt body _ _ _)) thing_inside
= do { (body', fv_expr) <- rnBody body
; (then_op, fvs1) <- lookupStmtName ctxt thenMName
; (guard_op, fvs2) <- if isListCompExpr ctxt
then lookupStmtName ctxt guardMName
else return (noSyntaxExpr, emptyFVs)
; (thing, fvs3) <- thing_inside []
; return (([(L loc (BodyStmt body'
then_op guard_op placeHolderType), fv_expr)], thing),
fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }
-- Bind statement: the body is renamed before the pattern, so the pattern's
-- binders are not in scope in the body.
rnStmt ctxt rnBody (L loc (BindStmt pat body _ _ _)) thing_inside
= do { (body', fv_expr) <- rnBody body
; (bind_op, fvs1) <- lookupStmtName ctxt bindMName
-- Which fail name is used depends on the MonadFailDesugaring flag.
; xMonadFailEnabled <- fmap (xopt LangExt.MonadFailDesugaring) getDynFlags
; let failFunction | xMonadFailEnabled = failMName
| otherwise = failMName_preMFP
; (fail_op, fvs2) <- lookupSyntaxName failFunction
-- thing_inside runs inside rnPat's continuation and receives the
-- binders of the renamed pattern.
; rnPat (StmtCtxt ctxt) pat $ \ pat' -> do
{ (thing, fvs3) <- thing_inside (collectPatBinders pat')
; return (( [( L loc (BindStmt pat' body' bind_op fail_op PlaceHolder)
, fv_expr )]
, thing),
fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }}
-- Let statement: thing_inside runs inside rnLocalBindsAndThen's continuation
-- with the binders of the renamed bindings.  The statement is paired with
-- bind_fvs; the overall free variables are those returned by thing_inside.
rnStmt _ _ (L loc (LetStmt (L l binds))) thing_inside
= do { rnLocalBindsAndThen binds $ \binds' bind_fvs -> do
{ (thing, fvs) <- thing_inside (collectLocalBinders binds')
; return (([(L loc (LetStmt (L l binds')), bind_fvs)], thing), fvs) } }
-- Rec statement: look up return, mfix and bind, rename the block with
-- rnRecStmtsAndThen, and let segmentRecStmts turn the resulting segments
-- back into statements.
rnStmt ctxt rnBody (L loc (RecStmt { recS_stmts = rec_stmts })) thing_inside
= do { (return_op, fvs1) <- lookupStmtName ctxt returnMName
; (mfix_op, fvs2) <- lookupStmtName ctxt mfixName
; (bind_op, fvs3) <- lookupStmtName ctxt bindMName
-- Template RecStmt carrying the three operators; segmentRecStmts fills
-- in the statements and id lists.
; let empty_rec_stmt = emptyRecStmtName { recS_ret_fn = return_op
, recS_mfix_fn = mfix_op
, recS_bind_fn = bind_op }
; rnRecStmtsAndThen rnBody rec_stmts $ \ segs -> do
-- The binders handed to thing_inside are the union of the first
-- component (the definitions) of every segment.
{ let bndrs = nameSetElemsStable $
foldr (unionNameSet . (\(ds,_,_,_) -> ds))
emptyNameSet
segs
; (thing, fvs_later) <- thing_inside bndrs
; let (rec_stmts', fvs) = segmentRecStmts loc ctxt empty_rec_stmt segs fvs_later
-- Each resulting statement is paired with an empty free-variable set.
; return ( ((zip rec_stmts' (repeat emptyNameSet)), thing)
, fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
-- Parallel statement: look up mzip, bind and return, then rename the
-- branches with rnParallelStmts in a ParStmtCtxt.
rnStmt ctxt _ (L loc (ParStmt segs _ _ _)) thing_inside
= do { (mzip_op, fvs1) <- lookupStmtNamePoly ctxt mzipName
; (bind_op, fvs2) <- lookupStmtName ctxt bindMName
; (return_op, fvs3) <- lookupStmtName ctxt returnMName
; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) return_op segs thing_inside
; return ( ([(L loc (ParStmt segs' mzip_op bind_op placeHolderType), fvs4)], thing)
, fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
-- Transform statement.
rnStmt ctxt _ (L loc (TransStmt { trS_stmts = stmts, trS_by = by, trS_form = form
, trS_using = using })) thing_inside
= do {
-- The 'using' expression is renamed first, before the statements bring
-- their binders into scope.
(using', fvs1) <- rnLExpr using
-- The optional 'by' expression and thing_inside are renamed inside the
-- continuation of the statements.
; ((stmts', (by', used_bndrs, thing)), fvs2)
<- rnStmts (TransStmtCtxt ctxt) rnLExpr stmts $ \ bndrs ->
do { (by', fvs_by) <- mapMaybeFvRn rnLExpr by
; (thing, fvs_thing) <- thing_inside bndrs
-- A binder counts as used when it occurs free in the 'by'
-- expression or in the thing inside.
; let fvs = fvs_by `plusFV` fvs_thing
used_bndrs = filter (`elemNameSet` fvs) bndrs
; return ((by', used_bndrs, thing), fvs) }
; (return_op, fvs3) <- lookupStmtName ctxt returnMName
; (bind_op, fvs4) <- lookupStmtName ctxt bindMName
-- No fmap operator is needed for ThenForm; every other form looks one up.
; (fmap_op, fvs5) <- case form of
ThenForm -> return (noExpr, emptyFVs)
_ -> lookupStmtNamePoly ctxt fmapName
; let all_fvs = fvs1 `plusFV` fvs2 `plusFV` fvs3
`plusFV` fvs4 `plusFV` fvs5
-- Each used binder is mapped to itself.
bndr_map = used_bndrs `zip` used_bndrs
; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
; return (([(L loc (TransStmt { trS_stmts = stmts', trS_bndrs = bndr_map
, trS_by = by', trS_using = using', trS_form = form
, trS_ret = return_op, trS_bind = bind_op
, trS_bind_arg_ty = PlaceHolder
, trS_fmap = fmap_op }), fvs2)], thing), all_fvs) }
-- An ApplicativeStmt reaching rnStmt is a panic.
rnStmt _ _ (L _ ApplicativeStmt{}) _ =
panic "rnStmt: ApplicativeStmt"
-- | Rename the branches of a parallel statement and then the thing inside.
-- The local reader environment is captured on entry and re-installed before
-- each following branch is renamed, so the binders of one branch are not in
-- scope while a later branch is being renamed.
rnParallelStmts :: forall thing. HsStmtContext Name
-> SyntaxExpr Name
-> [ParStmtBlock RdrName RdrName]
-> ([Name] -> RnM (thing, FreeVars))
-> RnM (([ParStmtBlock Name Name], thing), FreeVars)
rnParallelStmts ctxt return_op segs thing_inside
= do { orig_lcl_env <- getLocalRdrEnv
; rn_segs orig_lcl_env [] segs }
where
-- Arguments: the environment captured on entry, the binders collected from
-- the branches renamed so far, and the branches still to do.
rn_segs :: LocalRdrEnv
-> [Name] -> [ParStmtBlock RdrName RdrName]
-> RnM (([ParStmtBlock Name Name], thing), FreeVars)
-- All branches done: report binders that share an occurrence name, then run
-- thing_inside with the de-duplicated binders brought into scope.
rn_segs _ bndrs_so_far []
= do { let (bndrs', dups) = removeDups cmpByOcc bndrs_so_far
; mapM_ dupErr dups
; (thing, fvs) <- bindLocalNames bndrs' (thing_inside bndrs')
; return (([], thing), fvs) }
-- One branch: rename its statements; inside their continuation, reset the
-- local environment to env and carry on with the remaining branches.
rn_segs env bndrs_so_far (ParStmtBlock stmts _ _ : segs)
= do { ((stmts', (used_bndrs, segs', thing)), fvs)
<- rnStmts ctxt rnLExpr stmts $ \ bndrs ->
setLocalRdrEnv env $ do
{ ((segs', thing), fvs) <- rn_segs env (bndrs ++ bndrs_so_far) segs
-- Only binders that occur free in what follows are recorded.
; let used_bndrs = filter (`elemNameSet` fvs) bndrs
; return ((used_bndrs, segs', thing), fvs) }
; let seg' = ParStmtBlock stmts' used_bndrs return_op
; return ((seg':segs', thing), fvs) }
-- Binders are compared by occurrence name.
cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
-- Reports the first name of a group of duplicates.
dupErr vs = addErr (text "Duplicate binding in parallel list comprehension for:"
<+> quotes (ppr (head vs)))
-- | Find the syntax expression for a statement operator.  Contexts for which
-- 'rebindableContext' holds go through 'lookupSyntaxName'; all others use
-- the given name directly and contribute no free variables.
lookupStmtName :: HsStmtContext Name -> Name -> RnM (SyntaxExpr Name, FreeVars)
lookupStmtName ctxt name =
  if rebindableContext ctxt
    then lookupSyntaxName name
    else return (mkRnSyntaxExpr name, emptyFVs)
-- | Like 'lookupStmtName', but yields a plain variable expression.  The name
-- in scope (looked up by the given name's 'RdrName') is used only when the
-- context is rebindable and RebindableSyntax is enabled; otherwise the given
-- name is used as is, with no free variables.
lookupStmtNamePoly :: HsStmtContext Name -> Name -> RnM (HsExpr Name, FreeVars)
lookupStmtNamePoly ctxt name
  | not (rebindableContext ctxt) = standard
  | otherwise = do
      rebindable_on <- xoptM LangExt.RebindableSyntax
      if rebindable_on
        then do
          user_name <- lookupOccRn (nameRdrName name)
          return (HsVar (noLoc user_name), unitFV user_name)
        else standard
  where
    standard = return (HsVar (noLoc name), emptyFVs)
-- | Is this a statement context whose monad operations are resolved through
-- the rebindable-syntax lookup?  'lookupStmtName' and 'lookupStmtNamePoly'
-- consult this to choose between 'lookupSyntaxName' / the name in scope and
-- the fixed standard name.
--
-- Comprehension-like, arrow and pattern-guard contexts are not rebindable;
-- do, mdo and monad comprehensions are.  Nested contexts defer to their
-- parent.
rebindableContext :: HsStmtContext Name -> Bool
rebindableContext ctxt = case ctxt of
  ListComp        -> False
  PArrComp        -> False
  ArrowExpr       -> False
  PatGuard {}     -> False
  DoExpr          -> True
  MDoExpr         -> True
  MonadComp       -> True
  -- Statements typed at the GHCi prompt are not rebindable, so they always
  -- use the standard monad operations.  This used to answer True.
  -- NOTE(review): False is what upstream GHC uses here, as far as I recall;
  -- confirm against upstream before merging.
  GhciStmtCtxt    -> False
  -- Look inside to the parent context.
  ParStmtCtxt c   -> rebindableContext c
  TransStmtCtxt c -> rebindableContext c
-- | Names that a segment refers to ahead of their binding: either bound by
-- the segment itself (see the BindStmt case of rn_rec_stmt) or bound by a
-- later segment (see addFwdRefs).
type FwdRefs = NameSet
-- | A piece of a rec block together with its dependency information.
type Segment stmts = (Defs,
-- ^ (above) the names the segment binds
Uses,
-- ^ (above) the names the segment mentions
FwdRefs,
-- ^ (above) the forward references, as described on FwdRefs
stmts)
-- ^ (above) the payload: one statement in rn_rec_stmt's output, a list of
-- statements after glomSegments
-- | Rename the statements of a rec block and pass the resulting segments
-- (one or more per statement, not yet grouped) to the continuation.  All
-- binders of the block are in scope for the right-hand sides and for the
-- continuation.
rnRecStmtsAndThen :: Outputable (body RdrName) =>
(Located (body RdrName)
-> RnM (Located (body Name), FreeVars))
-> [LStmt RdrName (Located (body RdrName))]
-> ([Segment (LStmt Name (Located (body Name)))]
-> RnM (a, FreeVars))
-> RnM (a, FreeVars)
rnRecStmtsAndThen rnBody s cont
= do {
-- Build a fixity environment from the fixity signatures found in the
-- block's let statements.
fix_env <- makeMiniFixityEnv (collectRecStmtsFixities s)
-- Rename the binding sides of all statements first.
; new_lhs_and_fv <- rn_rec_stmts_lhs fix_env s
; let bound_names = collectLStmtsBinders (map fst new_lhs_and_fv)
-- Extra names that are counted as used when warning about unused
-- binders below.
implicit_uses = lStmtsImplicits (map fst new_lhs_and_fv)
-- With the binders and their fixities in scope, rename the right-hand
-- sides and run the continuation.
; bindLocalNamesFV bound_names $
addLocalFixities fix_env bound_names $ do
{ segs <- rn_rec_stmts rnBody bound_names new_lhs_and_fv
; (res, fvs) <- cont segs
; warnUnusedLocalBinds bound_names (fvs `unionNameSet` implicit_uses)
; return (res, fvs) }}
-- | Gather, in source order, every fixity signature that appears among the
-- signatures of a let statement with 'ValBindsIn' bindings.  Other
-- statements and other kinds of signature are skipped.
collectRecStmtsFixities :: [LStmtLR RdrName RdrName body] -> [LFixitySig RdrName]
collectRecStmtsFixities stmts =
  [ L loc fix_sig
  | L _ (LetStmt (L _ (HsValBinds (ValBindsIn _ sigs)))) <- stmts
  , L loc (FixSig fix_sig) <- sigs
  ]
-- | Rename only the binding side of one statement of a rec block, leaving
-- its body un-renamed.  Each result is paired with the free variables found
-- while renaming that binding side (non-empty only for bind statements).
rn_rec_stmt_lhs :: Outputable body => MiniFixityEnv
                -> LStmt RdrName body
                -> RnM [(LStmtLR Name RdrName body, FreeVars)]
-- Body and last statements bind nothing and are passed through.
rn_rec_stmt_lhs _ (L loc (BodyStmt body x1 x2 x3)) =
  return [(L loc (BodyStmt body x1 x2 x3), emptyFVs)]
rn_rec_stmt_lhs _ (L loc (LastStmt body noret x1)) =
  return [(L loc (LastStmt body noret x1), emptyFVs)]
rn_rec_stmt_lhs fix_env (L loc (BindStmt pat body x1 x2 x3)) = do
  (pat', fvs_pat) <- rnBindPat (localRecNameMaker fix_env) pat
  return [(L loc (BindStmt pat' body x1 x2 x3), fvs_pat)]
-- Implicit-parameter bindings are rejected.
rn_rec_stmt_lhs _ (L _ (LetStmt (L _ ip_binds@(HsIPBinds _)))) =
  failWith (badIpBinds (text "an mdo expression") ip_binds)
rn_rec_stmt_lhs fix_env (L loc (LetStmt (L bloc (HsValBinds val_binds)))) = do
  (_, val_binds') <- rnLocalValBindsLHS fix_env val_binds
  return [(L loc (LetStmt (L bloc (HsValBinds val_binds'))), emptyFVs)]
-- A nested rec block contributes the binding sides of its statements.
rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = nested })) =
  rn_rec_stmts_lhs fix_env nested
-- The remaining statement forms are panics here.
rn_rec_stmt_lhs _ bad@(L _ (ParStmt {})) =
  pprPanic "rn_rec_stmt" (ppr bad)
rn_rec_stmt_lhs _ bad@(L _ (TransStmt {})) =
  pprPanic "rn_rec_stmt" (ppr bad)
rn_rec_stmt_lhs _ bad@(L _ (ApplicativeStmt {})) =
  pprPanic "rn_rec_stmt" (ppr bad)
rn_rec_stmt_lhs _ (L _ (LetStmt (L _ EmptyLocalBinds))) =
  panic "rn_rec_stmt LetStmt EmptyLocalBinds"
-- | Rename the binding sides of all statements of a rec block and check the
-- collected binders with 'checkDupNames'.
rn_rec_stmts_lhs :: Outputable body => MiniFixityEnv
                 -> [LStmt RdrName body]
                 -> RnM [(LStmtLR Name RdrName body, FreeVars)]
rn_rec_stmts_lhs fix_env stmts = do
  lhs_stmts <- concatMapM (rn_rec_stmt_lhs fix_env) stmts
  checkDupNames (collectLStmtsBinders (map fst lhs_stmts))
  return lhs_stmts
-- | Rename the body of one statement whose binding side has already been
-- renamed, producing a singleton list holding its segment: the names it
-- defines, the names it uses, its forward references, and the statement.
rn_rec_stmt :: (Outputable (body RdrName)) =>
               (Located (body RdrName) -> RnM (Located (body Name), FreeVars))
            -> [Name]
            -> (LStmtLR Name RdrName (Located (body RdrName)), FreeVars)
            -> RnM [Segment (LStmt Name (Located (body Name)))]
rn_rec_stmt rnBody _ (L loc (LastStmt body noret _), _) = do
  (body', fvs_body) <- rnBody body
  (ret_op, fvs_ret) <- lookupSyntaxName returnMName
  return [ ( emptyNameSet
           , fvs_body `plusFV` fvs_ret
           , emptyNameSet
           , L loc (LastStmt body' noret ret_op) ) ]
rn_rec_stmt rnBody _ (L loc (BodyStmt body _ _ _), _) = do
  (body', fvs_body) <- rnBody body
  (then_op, fvs_then) <- lookupSyntaxName thenMName
  return [ ( emptyNameSet
           , fvs_body `plusFV` fvs_then
           , emptyNameSet
           , L loc (BodyStmt body' then_op noSyntaxExpr placeHolderType) ) ]
-- Bind statement: the free variables of the already-renamed pattern are
-- included in the uses, and the forward references are the binders that are
-- also among those uses.
rn_rec_stmt rnBody _ (L loc (BindStmt pat' body _ _ _), fvs_pat) = do
  (body', fvs_body) <- rnBody body
  (bind_op, fvs_bind) <- lookupSyntaxName bindMName
  -- Which fail name is used depends on the MonadFailDesugaring flag.
  monad_fail_on <- fmap (xopt LangExt.MonadFailDesugaring) getDynFlags
  let fail_name | monad_fail_on = failMName
                | otherwise = failMName_preMFP
  (fail_op, fvs_fail) <- lookupSyntaxName fail_name
  let binders = mkNameSet (collectPatBinders pat')
      uses = fvs_body `plusFV` fvs_pat `plusFV` fvs_bind `plusFV` fvs_fail
  return [ ( binders
           , uses
           , binders `intersectNameSet` uses
           , L loc (BindStmt pat' body' bind_op fail_op PlaceHolder) ) ]
rn_rec_stmt _ _ (L _ (LetStmt (L _ ip_binds@(HsIPBinds _))), _) =
  failWith (badIpBinds (text "an mdo expression") ip_binds)
-- Let statement: the right-hand sides are renamed with all binders of the
-- rec block passed in as a name set.
rn_rec_stmt _ all_bndrs (L loc (LetStmt (L bloc (HsValBinds lhs_binds))), _) = do
  (rhs_binds, du_binds) <- rnLocalValBindsRHS (mkNameSet all_bndrs) lhs_binds
  let uses = allUses du_binds
  return [ ( duDefs du_binds
           , uses
           , emptyNameSet
           , L loc (LetStmt (L bloc (HsValBinds rhs_binds))) ) ]
-- The remaining statement forms are panics here.
rn_rec_stmt _ _ bad@(L _ (RecStmt {}), _) =
  pprPanic "rn_rec_stmt: RecStmt" (ppr bad)
rn_rec_stmt _ _ bad@(L _ (ParStmt {}), _) =
  pprPanic "rn_rec_stmt: ParStmt" (ppr bad)
rn_rec_stmt _ _ bad@(L _ (TransStmt {}), _) =
  pprPanic "rn_rec_stmt: TransStmt" (ppr bad)
rn_rec_stmt _ _ (L _ (LetStmt (L _ EmptyLocalBinds)), _) =
  panic "rn_rec_stmt: LetStmt EmptyLocalBinds"
rn_rec_stmt _ _ bad@(L _ (ApplicativeStmt {}), _) =
  pprPanic "rn_rec_stmt: ApplicativeStmt" (ppr bad)
-- | Apply 'rn_rec_stmt' to every statement in order and concatenate the
-- resulting segments.
rn_rec_stmts :: Outputable (body RdrName) =>
                (Located (body RdrName) -> RnM (Located (body Name), FreeVars))
             -> [Name]
             -> [(LStmtLR Name RdrName (Located (body RdrName)), FreeVars)]
             -> RnM [Segment (LStmt Name (Located (body Name)))]
rn_rec_stmts rnBody bndrs stmts =
  concat <$> mapM (rn_rec_stmt rnBody bndrs) stmts
-- | Turn the segments of a rec block back into statements, given the free
-- variables of whatever follows the block.  Returns the statements and the
-- free variables of the block together with what follows.
segmentRecStmts :: SrcSpan -> HsStmtContext Name
-> Stmt Name body
-> [Segment (LStmt Name body)] -> FreeVars
-> ([LStmt Name body], FreeVars)
segmentRecStmts loc ctxt empty_rec_stmt segs fvs_later
-- Nothing to do for an empty block.
| null segs
= ([], fvs_later)
-- mdo: the segments are grouped and segsToStmts decides, per group,
-- whether a RecStmt is needed.
| MDoExpr <- ctxt
= segsToStmts empty_rec_stmt grouped_segs fvs_later
-- Any other context: a single RecStmt holding all statements.  Its
-- later_ids are the definitions used after the block, its rec_ids the
-- definitions used within the block.
| otherwise
= ([ L loc $
empty_rec_stmt { recS_stmts = ss
, recS_later_ids = nameSetElemsStable
(defs `intersectNameSet` fvs_later)
, recS_rec_ids = nameSetElemsStable
(defs `intersectNameSet` uses) }]
, uses `plusFV` fvs_later)
where
(defs_s, uses_s, _, ss) = unzip4 segs
-- Definitions and uses of the whole block.
defs = plusFVs defs_s
uses = plusFVs uses_s
-- First record, per segment, its references to later segments; then merge
-- segments with glomSegments.
segs_w_fwd_refs = addFwdRefs segs
grouped_segs = glomSegments ctxt segs_w_fwd_refs
-- | Extend the forward references of every segment with those of its uses
-- that are defined by a later segment.  The list is processed from the
-- right, accumulating the definitions of all segments seen so far.
addFwdRefs :: [Segment a] -> [Segment a]
addFwdRefs = fst . foldr step ([], emptyNameSet)
  where
    step (defs, uses, fwds, stmts) (done, defs_after) =
      ( (defs, uses, fwds', stmts) : done
      , defs_after `unionNameSet` defs )
      where
        fwds' = fwds `unionNameSet` (uses `intersectNameSet` defs_after)
-- | Merge single-statement segments into larger ones.  Working from the
-- right, each segment absorbs the already-grouped segments that follow it,
-- up to and including the last one that defines something it uses.
glomSegments :: HsStmtContext Name
-> [Segment (LStmt Name body)]
-> [Segment [LStmt Name body]]
glomSegments _ [] = []
glomSegments ctxt ((defs,uses,fwds,stmt) : segs)
= (seg_defs, seg_uses, seg_fwds, seg_stmts) : others
where
-- The rest of the list is grouped first.
segs' = glomSegments ctxt segs
-- extras are merged into this segment; others stay as they are.
(extras, others) = grab uses segs'
(ds, us, fs, ss) = unzip4 extras
seg_defs = plusFVs ds `plusFV` defs
seg_uses = plusFVs us `plusFV` uses
seg_fwds = plusFVs fs `plusFV` fwds
seg_stmts = stmt : concat ss
-- Split a list of segments in two without reordering it: the longest
-- suffix in which no segment defines any of the given uses goes into the
-- second component, everything before it into the first.
grab :: NameSet
-> [Segment a]
-> ([Segment a],
[Segment a])
grab uses dus
= (reverse yeses, reverse noes)
where
-- The span runs over the reversed list, so noes is taken from the end.
(noes, yeses) = span not_needed (reverse dus)
not_needed (defs,_,_,_) = not (intersectsNameSet defs uses)
-- | Turn grouped segments into statements, accumulating uses from the end
-- of the list towards the start.  The first argument is the template used
-- for every RecStmt that is built.
segsToStmts :: Stmt Name body
-> [Segment [LStmt Name body]]
-> FreeVars
-> ([LStmt Name body], FreeVars)
segsToStmts _ [] fvs_later = ([], fvs_later)
segsToStmts empty_rec_stmt ((defs, uses, fwds, ss) : segs) fvs_later
-- Each segment must hold at least one statement; head ss below relies
-- on it.
= ASSERT( not (null ss) )
(new_stmt : later_stmts, later_uses `plusFV` uses)
where
(later_stmts, later_uses) = segsToStmts empty_rec_stmt segs fvs_later
-- A non-recursive segment is emitted as its one statement; anything else
-- is wrapped in a RecStmt located at the segment's first statement.
new_stmt | non_rec = head ss
| otherwise = L (getLoc (head ss)) rec_stmt
rec_stmt = empty_rec_stmt { recS_stmts = ss
, recS_later_ids = nameSetElemsStable used_later
, recS_rec_ids = nameSetElemsStable fwds }
-- Non-recursive means: a single statement with no forward references.
non_rec = isSingleton ss && isEmptyNameSet fwds
-- Definitions of this segment that are used by what follows it.
used_later = defs `intersectNameSet` later_uses
-- | The names looked up by rearrangeForApplicativeDo for returnMName and
-- pureAName, bundled for passing to stmtTreeToStmts.
data MonadNames = MonadNames { return_name, pure_name :: Name }
-- | Rearrange the statements of a do-block for ApplicativeDo.  A block with
-- zero or one statement is returned unchanged.  Otherwise a statement tree
-- is built from all statements but the last and converted back to
-- statements by stmtTreeToStmts, with the last statement as the tail.
rearrangeForApplicativeDo
:: HsStmtContext Name
-> [(ExprLStmt Name, FreeVars)]
-> RnM ([ExprLStmt Name], FreeVars)
rearrangeForApplicativeDo _ [] = return ([], emptyNameSet)
rearrangeForApplicativeDo _ [(one,_)] = return ([one], emptyNameSet)
rearrangeForApplicativeDo ctxt stmts0 = do
-- The Opt_OptimalApplicativeDo flag selects the tree-building algorithm.
optimal_ado <- goptM Opt_OptimalApplicativeDo
let stmt_tree | optimal_ado = mkStmtTreeOptimal stmts
| otherwise = mkStmtTreeHeuristic stmts
-- These two local variables share their names with the MonadNames record
-- fields they are stored in just below.
return_name <- lookupSyntaxName' returnMName
pure_name <- lookupSyntaxName' pureAName
let monad_names = MonadNames { return_name = return_name
, pure_name = pure_name }
stmtTreeToStmts monad_names ctxt stmt_tree [last] last_fvs
where
-- Split off the final statement and its free variables.  The local name
-- last shadows the Prelude function within this where clause.
(stmts,(last,last_fvs)) = findLast stmts0
-- The empty case is not reachable from here: the equations above handle
-- lists of length zero and one.
findLast [] = error "findLast"
findLast [last] = ([],last)
findLast (x:xs) = (x:rest,last) where (rest,last) = findLast xs
-- | A tree of statements; 'flattenStmtTree' lists the leaves from left to
-- right.
data StmtTree a
  = StmtTreeOne a
    -- ^ a single leaf
  | StmtTreeBind (StmtTree a) (StmtTree a)
    -- ^ two subtrees in sequence: the first comes before the second
  | StmtTreeApplicative [StmtTree a]
    -- ^ a group of subtrees that 'stmtTreeToStmts' turns into the
    -- arguments of one applicative statement
-- | List the leaves of a statement tree from left to right.  An
-- accumulator is threaded through, so no intermediate lists are appended.
flattenStmtTree :: StmtTree a -> [a]
flattenStmtTree tree = collect tree []
  where
    collect node acc = case node of
      StmtTreeOne x            -> x : acc
      StmtTreeBind l r         -> collect l (collect r acc)
      StmtTreeApplicative kids -> foldr collect acc kids
-- | A statement tree whose leaves pair each statement with its free
-- variables.
type ExprStmtTree = StmtTree (ExprLStmt Name, FreeVars)
-- | The cost that 'mkStmtTreeOptimal' assigns to a statement tree.
type Cost = Int
-- | Build a statement tree using 'segments' and 'splitSegment'.
--
-- A single statement becomes a leaf.  Otherwise the statements are cut
-- into segments: one segment is split into a bind of two recursively
-- built trees, several segments become the children of an applicative
-- node.
mkStmtTreeHeuristic :: [(ExprLStmt Name, FreeVars)] -> ExprStmtTree
mkStmtTreeHeuristic stmts
  | [single]   <- stmts = StmtTreeOne single
  | [only_seg] <- segs  = treeOfSegment only_seg
  | otherwise           = StmtTreeApplicative (map treeOfSegment segs)
  where
    segs = segments stmts

    treeOfSegment seg = case seg of
      [single] -> StmtTreeOne single
      _        ->
        let (before, after) = splitSegment seg
        in StmtTreeBind (mkStmtTreeHeuristic before)
                        (mkStmtTreeHeuristic after)
-- | Build the statement tree of minimal cost by dynamic programming over
-- all contiguous sub-sequences @lo..hi@ of the statements.
--
-- Costs: a leaf costs 1, a bind costs the sum of its two halves, an
-- applicative node costs the maximum of its children.
--
-- The argument must be non-empty (checked by the ASSERT).
mkStmtTreeOptimal :: [(ExprLStmt Name, FreeVars)] -> ExprStmtTree
mkStmtTreeOptimal stmts =
  ASSERT(not (null stmts))
  fst (arr ! (0,n))
  where
    -- index of the last statement
    n = length stmts - 1
    stmt_arr = listArray (0,n) stmts

    -- Table of best trees for every sub-sequence with lo <= hi.  Entries
    -- refer to other entries through 'split', so this relies on the
    -- array elements being evaluated lazily.
    arr :: Array (Int,Int) (ExprStmtTree, Cost)
    arr = array ((0,0),(n,n))
             [ ((lo,hi), tree lo hi)
             | lo <- [0..n]
             , hi <- [lo..n] ]

    -- Best tree and cost for the statements lo..hi.
    tree lo hi
      | hi == lo = (StmtTreeOne (stmt_arr ! lo), 1)
      | otherwise =
         case segments [ stmt_arr ! i | i <- [lo..hi] ] of
           [] -> panic "mkStmtTree"
           [_one] -> split lo hi
           segs -> (StmtTreeApplicative trees, maximum costs)
             where
               -- (first, last) index of each segment, starting right
               -- after a dummy entry whose end is lo-1
               bounds = scanl (\(_,prev_hi) seg ->
                                  (prev_hi + 1, prev_hi + length seg))
                              (0, lo - 1) segs
               (trees,costs) = unzip (map (uncurry split) (tail bounds))

    -- Best way to cut lo..hi into two halves joined by a bind.
    split :: Int -> Int -> (ExprStmtTree, Cost)
    split lo hi
      | hi == lo = (StmtTreeOne (stmt_arr ! lo), 1)
      | otherwise = (StmtTreeBind before after, c1+c2)
      where
        -- If peeling off the last statement and peeling off the first
        -- statement give different costs, take the cheaper one.  Only
        -- when the two costs are equal is every split point tried.
        ((before,c1),(after,c2))
          | hi - lo == 1
          = ((StmtTreeOne (stmt_arr ! lo), 1),
             (StmtTreeOne (stmt_arr ! hi), 1))
          | left_cost < right_cost
          = ((left,left_cost), (StmtTreeOne (stmt_arr ! hi), 1))
          | left_cost > right_cost
          = ((StmtTreeOne (stmt_arr ! lo), 1), (right,right_cost))
          | otherwise = minimumBy (comparing cost) alternatives
          where
            (left, left_cost) = arr ! (lo, hi - 1)
            (right, right_cost) = arr ! (lo + 1, hi)
            cost ((_,cl),(_,cr)) = cl + cr
            alternatives = [ (arr ! (lo,k), arr ! (k+1,hi))
                           | k <- [lo .. hi - 1] ]
-- | Turn a statement tree back into a list of statements placed in front
-- of the given tail, and return the free variables gathered from the
-- name lookups made on the way.
--
-- The arguments are, in order: the return/pure names, the statement
-- context, the tree, the statements that follow it, and the free
-- variables of those following statements.
stmtTreeToStmts
  :: MonadNames
  -> HsStmtContext Name
  -> ExprStmtTree
  -> [ExprLStmt Name]
  -> FreeVars
  -> RnM ( [ExprLStmt Name]
         , FreeVars )
stmtTreeToStmts monad_names ctxt tree rest_stmts rest_fvs = case tree of
    -- A lone bind with an irrefutable pattern, followed by nothing that
    -- needs a join, becomes a one-argument applicative statement.
    StmtTreeOne (L _ (BindStmt pat rhs _ _ _), _)
      | isIrrefutableHsPat pat
      , (False, rest') <- needJoin monad_names rest_stmts
      -> mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False rest'

    -- Any other leaf is simply put in front of the tail.
    StmtTreeOne (s, _) ->
      return (s : rest_stmts, emptyNameSet)

    -- Convert the second half first; its result is the tail of the
    -- first half.
    StmtTreeBind before after -> do
      (after_stmts, after_fvs) <-
        stmtTreeToStmts monad_names ctxt after rest_stmts rest_fvs
      let before_tail_fvs =
            unionNameSets (rest_fvs : map snd (flattenStmtTree after))
      (all_stmts, before_fvs) <-
        stmtTreeToStmts monad_names ctxt before after_stmts before_tail_fvs
      return (all_stmts, after_fvs `plusFV` before_fvs)

    StmtTreeApplicative trees -> do
      (args, arg_fvss) <- fmap unzip (mapM (stmtTreeArg ctxt rest_fvs) trees)
      let (need_join, rest') = needJoin monad_names rest_stmts
      (stmts, fvs) <- mkApplicativeStmt ctxt args need_join rest'
      return (stmts, unionNameSets (fvs : arg_fvss))
  where
    -- Build one argument of an applicative statement from a subtree.
    stmtTreeArg _ctxt _tail_fvs (StmtTreeOne (L _ (BindStmt pat rhs _ _ _), _)) =
      return (ApplicativeArgOne pat rhs, emptyFVs)
    stmtTreeArg arg_ctxt tail_fvs subtree = do
      let leaves  = flattenStmtTree subtree
          -- binders of the subtree that the following statements use
          pvarset = mkNameSet (concatMap (collectStmtBinders . unLoc . fst) leaves)
                      `intersectNameSet` tail_fvs
          pvars   = nameSetElemsStable pvarset
          pat     = mkBigLHsVarPatTup pvars
          tup     = mkBigLHsVarTup pvars
      (sub_stmts, sub_fvs) <-
        stmtTreeToStmts monad_names arg_ctxt subtree [] pvarset
      (final_expr, ret_fvs) <- case last sub_stmts of
        L _ ApplicativeStmt{} ->
          return (unLoc tup, emptyNameSet)
        _ -> do
          (ret, fvs) <- lookupStmtNamePoly arg_ctxt returnMName
          return (HsApp (noLoc ret) tup, fvs)
      return ( ApplicativeArgMany sub_stmts final_expr pat
             , ret_fvs `plusFV` sub_fvs )
-- | Cut a list of statements into consecutive segments.
--
-- The list is walked from the last statement backwards.  A segment
-- starts at a statement and keeps absorbing earlier statements for as
-- long as some variable bound in this statement list is referenced but
-- not yet bound within the segment.  Afterwards, neighbouring segments
-- are joined whenever either of them consists only of let statements.
segments
  :: [(ExprLStmt Name, FreeVars)]
  -> [[(ExprLStmt Name, FreeVars)]]
segments stmts
  = map fst . merge . reverse . map reverse . walk . reverse $ stmts
  where
    -- every variable bound by some statement of the input
    allvars = mkNameSet (concatMap (collectStmtBinders . unLoc . fst) stmts)

    -- Join neighbours, working from the right; the Bool records whether
    -- a (merged) segment consists of let statements only.
    merge segs = foldr glue [] segs

    glue seg merged = case merged of
        [] -> [(seg, seg_lets)]
        (next, next_lets) : more
          | seg_lets || next_lets
          -> (seg ++ next, seg_lets && next_lets) : more
        _ -> (seg, seg_lets) : merged
      where
        seg_lets = all (isLetStmt . fst) seg

    -- Input and output are both in reverse statement order.
    walk :: [(ExprLStmt Name, FreeVars)] -> [[(ExprLStmt Name, FreeVars)]]
    walk [] = []
    walk (item@(stmt, fvs) : earlier) = (item : seg) : walk leftover
      where
        (seg, leftover) = chunter wanted earlier
        (_, wanted)     = stmtRefs stmt fvs

    -- Take statements while the set of outstanding variables is not empty.
    chunter _ [] = ([], [])
    chunter vars pending@(item@(stmt, fvs) : earlier)
      | isEmptyNameSet vars = ([], pending)
      | otherwise           = (item : chunk, leftover)
      where
        (chunk, leftover) = chunter vars' earlier
        (pvars, evars)    = stmtRefs stmt fvs
        vars'             = (vars `minusNameSet` pvars) `unionNameSet` evars

    -- (binders of the statement, variables of this statement list that it
    -- refers to); for a let statement its own binders are removed from
    -- the second component.
    stmtRefs stmt fvs = (pvars, refs)
      where
        local_fvs = fvs `intersectNameSet` allvars
        pvars     = mkNameSet (collectStmtBinders (unLoc stmt))
        refs | isLetStmt stmt = local_fvs `minusNameSet` pvars
             | otherwise      = local_fvs
-- | True exactly for a located 'LetStmt'.
isLetStmt :: LStmt a b -> Bool
isLetStmt stmt = case stmt of
  L _ LetStmt{} -> True
  _             -> False
-- | Split a segment into two parts.
--
-- Two statements are split down the middle.  Otherwise, if
-- 'slurpIndependentStmts' finds a leading independent group, the lets of
-- that group go on the left when there are any, and the binds go on the
-- left when there are none.  Failing that, the first statement is
-- separated from the rest.
splitSegment
  :: [(ExprLStmt Name, FreeVars)]
  -> ( [(ExprLStmt Name, FreeVars)]
     , [(ExprLStmt Name, FreeVars)] )
splitSegment [one, two] = ([one], [two])
splitSegment stmts = case slurpIndependentStmts stmts of
    Just (lets, binds, rest)
      | null lets -> (lets ++ binds, rest)
      | otherwise -> (lets, binds ++ rest)
    Nothing -> case stmts of
      (x : xs) -> ([x], xs)
      _        -> (stmts, [])
-- | Collect the leading bind and let statements that do not mention any
-- variable bound by the bind statements collected so far.
--
-- Returns @Just (lets, binds, rest)@ with each group in its original
-- order, or 'Nothing' when fewer than two bind statements were collected.
slurpIndependentStmts
   :: [(LStmt Name (Located (body Name)), FreeVars)]
   -> Maybe ( [(LStmt Name (Located (body Name)), FreeVars)]
            , [(LStmt Name (Located (body Name)), FreeVars)]
            , [(LStmt Name (Located (body Name)), FreeVars)] )
slurpIndependentStmts stmts = go [] [] emptyNameSet stmts
  where
    -- lets and indep are accumulated in reverse; bndrs holds the pattern
    -- binders of the bind statements seen so far.
    go lets indep bndrs (item@(L _ (BindStmt pat _ _ _ _), fvs) : rest)
      | independent bndrs fvs
      = go lets (item : indep)
           (bndrs `unionNameSet` mkNameSet (collectPatBinders pat)) rest
    go lets indep bndrs (item@(L _ (LetStmt _), fvs) : rest)
      | independent bndrs fvs
      = go (item : lets) indep bndrs rest
    go _    []    _ _         = Nothing
    go _    [_]   _ _         = Nothing
    go lets indep _ remaining = Just (reverse lets, reverse indep, remaining)

    independent bound fvs = not (intersectsNameSet bound fvs)
-- | Build an applicative statement from the given arguments and put it in
-- front of the body statements.
--
-- The first argument is paired with the operator looked up for
-- @fmapName@ and every later one with the operator for @apAName@.  The
-- operator for @joinMName@ is looked up only when the flag asks for a
-- join.  The free variables of all lookups are returned.
mkApplicativeStmt
  :: HsStmtContext Name
  -> [ApplicativeArg Name Name]
  -> Bool
  -> [ExprLStmt Name]
  -> RnM ([ExprLStmt Name], FreeVars)
mkApplicativeStmt ctxt args need_join body_stmts = do
  (fmap_op, fmap_fvs) <- lookupStmtName ctxt fmapName
  (ap_op, ap_fvs)     <- lookupStmtName ctxt apAName
  (mb_join, join_fvs) <-
    if need_join
      then do (join_op, fvs) <- lookupStmtName ctxt joinMName
              return (Just join_op, fvs)
      else return (Nothing, emptyNameSet)
  let ops = fmap_op : repeat ap_op
      applicative_stmt =
        noLoc (ApplicativeStmt (zip ops args) mb_join placeHolderType)
  return ( applicative_stmt : body_stmts
         , fmap_fvs `plusFV` ap_fvs `plusFV` join_fvs )
-- | Decide whether the statements after an applicative statement need a
-- join, and return them in possibly adjusted form.
--
-- No join is needed for an empty list, nor for a single last statement
-- that is a return/pure application; in the latter case the statement is
-- rebuilt around the argument with its flag set to True.
needJoin :: MonadNames
         -> [ExprLStmt Name]
         -> (Bool, [ExprLStmt Name])
needJoin monad_names stmts = case stmts of
    [] -> (False, [])
    [L loc (LastStmt e _ t)]
      | Just arg <- isReturnApp monad_names e
      -> (False, [L loc (LastStmt arg True t)])
    _ -> (True, stmts)
-- | Recognise @return e@, @pure e@, @return $ e@ and @pure $ e@, looking
-- through parentheses around the whole expression, and return @e@.
-- The function position may itself be parenthesised or carry a type
-- application.
isReturnApp :: MonadNames
            -> LHsExpr Name
            -> Maybe (LHsExpr Name)
isReturnApp monad_names (L _ expr) = case expr of
    HsPar inner                                -> isReturnApp monad_names inner
    OpApp l op _ r | is_return l, is_dollar op -> Just r
    HsApp f arg    | is_return f               -> Just arg
    _                                          -> Nothing
  where
    -- Does the expression boil down to a variable satisfying the predicate?
    is_var p (L _ e) = case e of
      HsPar inner       -> is_var p inner
      HsAppType inner _ -> is_var p inner
      HsVar (L _ name)  -> p name
      _                 -> False

    is_return = is_var (\n -> n == return_name monad_names
                           || n == pure_name monad_names)
    is_dollar = is_var (`hasKey` dollarIdKey)
-- | Report an error for an empty statement list unless the context
-- permits one (see 'okEmpty').
checkEmptyStmts :: HsStmtContext Name -> RnM ()
checkEmptyStmts ctxt
  | okEmpty ctxt = return ()
  | otherwise    = addErr (emptyErr ctxt)
-- | Only a pattern-guard context may have an empty statement list.
okEmpty :: HsStmtContext a -> Bool
okEmpty ctxt = case ctxt of
  PatGuard {} -> True
  _           -> False
-- | The error message for an empty statement list in the given context.
emptyErr :: HsStmtContext Name -> SDoc
emptyErr ctxt = case ctxt of
  ParStmtCtxt {}   -> text "Empty statement group in parallel comprehension"
  TransStmtCtxt {} -> text "Empty statement group preceding 'group' or 'then'"
  _                -> text "Empty" <+> pprStmtContext ctxt
-- | Check the last statement of a statement list and return it, possibly
-- converted.
--
-- In do-like contexts a body statement is turned into a last statement;
-- anything that is neither a body nor a last statement is reported as an
-- error.  In comprehension contexts a last statement is required and
-- anything else is a panic.  In every other context the statement is
-- validated with 'checkStmt'.
checkLastStmt :: Outputable (body RdrName) => HsStmtContext Name
              -> LStmt RdrName (Located (body RdrName))
              -> RnM (LStmt RdrName (Located (body RdrName)))
checkLastStmt ctxt lstmt@(L loc stmt) = checker_for ctxt
  where
    checker_for ListComp  = check_comp
    checker_for MonadComp = check_comp
    checker_for PArrComp  = check_comp
    checker_for ArrowExpr = check_do
    checker_for DoExpr    = check_do
    checker_for MDoExpr   = check_do
    checker_for _         = check_other

    check_do = case stmt of
      BodyStmt e _ _ _ -> return (L loc (mkLastStmt e))
      LastStmt {}      -> return lstmt
      _                -> addErr (hang last_error 2 (ppr stmt)) >> return lstmt

    last_error = text "The last statement in" <+> pprAStmtContext ctxt
                 <+> text "must be an expression"

    check_comp = case stmt of
      LastStmt {} -> return lstmt
      _           -> pprPanic "checkLastStmt" (ppr lstmt)

    check_other = checkStmt ctxt lstmt >> return lstmt
-- | Report an error if the statement is not allowed in the given context
-- under the current flags.  Any extra hint supplied by 'okStmt' is
-- printed below the main message.
checkStmt :: HsStmtContext Name
          -> LStmt RdrName (Located (body RdrName))
          -> RnM ()
checkStmt ctxt (L _ stmt) = do
  dflags <- getDynFlags
  case okStmt dflags ctxt stmt of
    IsValid        -> return ()
    NotValid extra -> addErr (headline $$ extra)
  where
    headline = sep
      [ text "Unexpected" <+> pprStmtCat stmt <+> ptext (sLit "statement")
      , text "in" <+> pprAStmtContext ctxt ]
-- | A short name for the kind of statement, for use in error messages.
-- Panics on an 'ApplicativeStmt'.
pprStmtCat :: Stmt a body -> SDoc
pprStmtCat stmt = case stmt of
  TransStmt {}       -> text "transform"
  LastStmt {}        -> text "return expression"
  BodyStmt {}        -> text "body"
  BindStmt {}        -> text "binding"
  LetStmt {}         -> text "let"
  RecStmt {}         -> text "rec"
  ParStmt {}         -> text "parallel"
  ApplicativeStmt {} -> panic "pprStmtCat: ApplicativeStmt"
-- | An invalid result that carries no extra explanation.
emptyInvalid :: Validity
emptyInvalid = NotValid Outputable.empty
-- Shared signature of the per-context statement validity checks.
okStmt, okDoStmt, okCompStmt, okParStmt, okPArrStmt
   :: DynFlags -> HsStmtContext Name
   -> Stmt RdrName (Located (body RdrName)) -> Validity

-- | Is the statement allowed in the given context?  Dispatches on the
-- context to one of the specialised checks.
okStmt dflags ctxt stmt = dispatch ctxt
  where
    dispatch PatGuard {}           = okPatGuardStmt stmt
    -- a parallel branch is checked against the enclosed context
    dispatch (ParStmtCtxt inner)   = okParStmt dflags inner stmt
    dispatch DoExpr                = okDoStmt dflags ctxt stmt
    dispatch MDoExpr               = okDoStmt dflags ctxt stmt
    dispatch ArrowExpr             = okDoStmt dflags ctxt stmt
    dispatch GhciStmtCtxt          = okDoStmt dflags ctxt stmt
    dispatch ListComp              = okCompStmt dflags ctxt stmt
    dispatch MonadComp             = okCompStmt dflags ctxt stmt
    dispatch PArrComp              = okPArrStmt dflags ctxt stmt
    -- a transform context defers to the enclosed context
    dispatch (TransStmtCtxt inner) = okStmt dflags inner stmt
-- | Pattern guards admit body, bind and let statements only.
okPatGuardStmt :: Stmt RdrName (Located (body RdrName)) -> Validity
okPatGuardStmt BodyStmt {} = IsValid
okPatGuardStmt BindStmt {} = IsValid
okPatGuardStmt LetStmt {}  = IsValid
okPatGuardStmt _           = emptyInvalid
-- A let statement holding implicit-parameter bindings is rejected;
-- everything else is checked against the given context.
okParStmt dflags ctxt stmt
  | LetStmt (L _ (HsIPBinds {})) <- stmt = emptyInvalid
  | otherwise                            = okStmt dflags ctxt stmt
-- Bind, let and body statements are always accepted.  A rec statement is
-- accepted when RecursiveDo is on or the context is an arrow expression;
-- otherwise the user is told to enable RecursiveDo.
okDoStmt dflags ctxt stmt = case stmt of
    RecStmt {}
      | rec_allowed -> IsValid
      | otherwise   -> NotValid (text "Use RecursiveDo")
    BindStmt {} -> IsValid
    LetStmt {}  -> IsValid
    BodyStmt {} -> IsValid
    _           -> emptyInvalid
  where
    rec_allowed = xopt LangExt.RecursiveDo dflags || in_arrow
    in_arrow = case ctxt of
      ArrowExpr -> True
      _         -> False
-- Bind, let and body statements are always accepted; parallel and
-- transform statements need their respective extensions.
okCompStmt _ _ BindStmt {} = IsValid
okCompStmt _ _ LetStmt {}  = IsValid
okCompStmt _ _ BodyStmt {} = IsValid
okCompStmt dflags _ ParStmt {}
  | xopt LangExt.ParallelListComp dflags = IsValid
  | otherwise = NotValid (text "Use ParallelListComp")
okCompStmt dflags _ TransStmt {}
  | xopt LangExt.TransformListComp dflags = IsValid
  | otherwise = NotValid (text "Use TransformListComp")
okCompStmt _ _ RecStmt {}         = emptyInvalid
okCompStmt _ _ LastStmt {}        = emptyInvalid
okCompStmt _ _ ApplicativeStmt {} = emptyInvalid
-- Bind, let and body statements are always accepted; parallel statements
-- need ParallelListComp; transform statements are never accepted here.
okPArrStmt _ _ BindStmt {} = IsValid
okPArrStmt _ _ LetStmt {}  = IsValid
okPArrStmt _ _ BodyStmt {} = IsValid
okPArrStmt dflags _ ParStmt {}
  | xopt LangExt.ParallelListComp dflags = IsValid
  | otherwise = NotValid (text "Use ParallelListComp")
okPArrStmt _ _ TransStmt {}       = emptyInvalid
okPArrStmt _ _ RecStmt {}         = emptyInvalid
okPArrStmt _ _ LastStmt {}        = emptyInvalid
okPArrStmt _ _ ApplicativeStmt {} = emptyInvalid
-- | Report an error when some tuple argument is missing and the
-- TupleSections extension is off.
checkTupleSection :: [LHsTupArg RdrName] -> RnM ()
checkTupleSection args = do
  sections_on <- xoptM LangExt.TupleSections
  let acceptable = sections_on || all tupArgPresent args
  checkErr acceptable (text "Illegal tuple section: use TupleSections")
-- | Error message for a section that is not wrapped in parentheses; it
-- shows the expression in parenthesised form.
sectionErr :: HsExpr RdrName -> SDoc
sectionErr expr = hang headline 2 suggestion
  where
    headline   = text "A section must be enclosed in parentheses"
    suggestion = text "thus:" <+> parens (ppr expr)
-- | Report pattern syntax used where an expression is expected, with the
-- given explanation printed below, and return a wildcard pattern
-- expression with no free variables.
patSynErr :: HsExpr RdrName -> SDoc -> RnM (HsExpr Name, FreeVars)
patSynErr e explanation = do
  addErr (headline $$ explanation)
  return (EWildPat, emptyFVs)
  where
    headline = sep [ text "Pattern syntax in expression context:"
                   , nest 4 (ppr e) ]
-- | Error message for implicit-parameter bindings in a place (described
-- by the first argument) where they are not allowed.
badIpBinds :: Outputable a => SDoc -> a -> SDoc
badIpBinds what binds = hang headline 2 (ppr binds)
  where
    headline = text "Implicit-parameter bindings illegal in" <+> what