module Check (
checkSingle, checkMatches, dsPmWarn,
computeNoGuards,
isAnyPmCheckEnabled,
warnManyGuards,
maximum_failing_guards,
genCaseTmCs1, genCaseTmCs2
) where
#include "HsVersions.h"
import TmOracle
import DynFlags
import HsSyn
import TcHsSyn
import Id
import ConLike
import DataCon
import Name
import TysWiredIn
import TyCon
import SrcLoc
import Util
import Outputable
import FastString
import DsMonad
import TcSimplify
import TcType
import Bag
import ErrUtils
import MonadUtils
import Var
import Type
import UniqSupply
import DsGRHSs
import Data.List
import Data.Maybe
import Control.Monad
import Coercion
import TcEvidence
-- | The monad the pattern-match checker runs in (a synonym for 'DsM').
type PmM a = DsM a

-- | Constraints accumulated on value abstractions during checking.
data PmConstraint = TmConstraint PmExpr PmExpr -- term equation, handed to 'tmOracle'
                  | TyConstraint [EvVar]       -- type constraints, handed to 'tyOracle'
                  | BtConstraint Id            -- the variable is bottom (x ~ _|_),
                                               -- introduced by the divergence matcher
-- | Index distinguishing patterns ('PAT') from value abstractions ('VA').
data PatTy = PAT | VA

-- | Patterns and value abstractions share one representation. 'PmNLit'
-- exists only as a value abstraction and 'PmGrd' only as a pattern.
data PmPat :: PatTy -> * where
  -- constructor application
  PmCon  :: { pm_con_con     :: DataCon
            , pm_con_arg_tys :: [Type]      -- type arguments of the data type
            , pm_con_tvs     :: [TyVar]     -- existential type variables
            , pm_con_dicts   :: [EvVar]     -- evidence bound by the constructor
            , pm_con_args    :: [PmPat t] } -> PmPat t
  PmVar  :: { pm_var_id   :: Id } -> PmPat t
  PmLit  :: { pm_lit_lit  :: PmLit } -> PmPat t
  -- a variable known to differ from every literal in pm_lit_not
  PmNLit :: { pm_lit_id   :: Id
            , pm_lit_not  :: [PmLit] } -> PmPat 'VA
  -- a guard: match pattern vector pm_grd_pv against expression pm_grd_expr
  PmGrd  :: { pm_grd_pv   :: PatVec
            , pm_grd_expr :: PmExpr } -> PmPat 'PAT

type Pattern   = PmPat 'PAT
type ValAbs    = PmPat 'VA
type PatVec    = [Pattern]
type ValVecAbs = [ValAbs]
-- | A symbolic set of value vectors.
data ValSetAbs
  = Empty                                -- no vectors
  | Union ValSetAbs ValSetAbs            -- union of two sets
  | Singleton                            -- end of a vector (the set holding just [])
  | Constraint [PmConstraint] ValSetAbs  -- the vectors of the set, under constraints
  | Cons ValAbs ValSetAbs                -- prefix every vector with a value abstraction

-- | Checker result: (redundant clauses, clauses with inaccessible right-hand
-- sides, uncovered value vectors together with residual term equations).
type PmResult = ( [[LPat Id]]
                , [[LPat Id]]
                , [([PmExpr], [ComplexEq])] )
-- | Check a single pattern, matched against the variable @var@, as a one-clause
-- match. The clause is reported as redundant if it covers nothing, or as having
-- an inaccessible right-hand side if it covers nothing but may diverge.
checkSingle :: Id -> Pat Id -> DsM PmResult
checkSingle var p = do
  pat_vec <- liftUs (translatePat p)
  initial <- initial_uncovered [var]
  (covers, diverges, rest) <- patVectProc False (pat_vec, []) initial
  missing <- pruneVSA rest
  let clause = [noLoc p]
      (redundant, inaccessible)
        | covers    = ([], [])
        | diverges  = ([], [clause])
        | otherwise = ([clause], [])
  return (redundant, inaccessible, missing)
-- | Check the clauses of a match, in order, against the given variables.
--
-- Each clause is checked against the values left uncovered by the previous
-- clauses. A clause that covers nothing is redundant, unless it may diverge;
-- in that case its right-hand side is only inaccessible. @oversimplify@ is
-- passed on to 'patVectProc' (and from there to 'process_guards').
checkMatches :: Bool -> [Id] -> [LMatch Id (LHsExpr Id)] -> DsM PmResult
checkMatches oversimplify vars matches = case matches of
  [] -> return ([], [], [])
  _  -> do
    start <- initial_uncovered vars
    (redundant, inaccessible, missing) <- loop matches start
    return ( map hsLMatchPats redundant
           , map hsLMatchPats inaccessible
           , missing )
  where
    -- Walk the clauses, threading the still-uncovered set.
    loop [] remaining = do
      pruned <- pruneVSA remaining
      return ([], [], pruned)
    loop (m:ms) remaining = do
      clause <- liftUs (translateMatch m)
      (covers, diverges, left) <- patVectProc oversimplify clause remaining
      (rs, is, missing) <- loop ms left
      let result
            | covers    = (rs, is, missing)
            | diverges  = (rs, m:is, missing)
            | otherwise = (m:rs, is, missing)
      return result
-- | The initial uncovered set: one vector of the given variables, constrained
-- by whatever type and term constraints are in scope (see 'getCsPmM').
initial_uncovered :: [Id] -> DsM ValSetAbs
initial_uncovered vars = wrap <$> getCsPmM
  where
    vector = foldr (Cons . PmVar) Singleton vars
    wrap []  = vector
    wrap cs  = mkConstraint cs vector
-- | Gather the in-scope constraints from the desugarer environment.
--
-- Dictionaries become a single 'TyConstraint', omitted when there are none.
-- Each recorded (variable, expression) pair becomes a 'TmConstraint'.
getCsPmM :: DsM [PmConstraint]
getCsPmM = do
  dicts  <- getDictsDs
  tm_eqs <- getTmCsDs
  let term_cs = [ TmConstraint (PmExprVar x) e | (x, e) <- bagToList tm_eqs ]
  return $ case bagToList dicts of
    []    -> term_cs
    ty_cs -> TyConstraint ty_cs : term_cs
-- | Count the guard patterns, across all clauses, that may fail
-- (those for which 'cantFailPattern' is False).
noFailingGuards :: [(PatVec,[PatVec])] -> Int
noFailingGuards clauses =
  length [ () | (_, gvs) <- clauses
              , gv <- gvs
              , p <- gv
              , not (cantFailPattern p) ]
-- | Translate the clauses and count their possibly-failing guards
-- (see 'noFailingGuards').
computeNoGuards :: [LMatch Id (LHsExpr Id)] -> PmM Int
computeNoGuards matches =
  noFailingGuards <$> mapM (liftUs . translateMatch) matches
-- | Threshold on the number of possibly-failing guards.
-- NOTE(review): this is exported and not used in this module. Presumably
-- callers compare it with 'computeNoGuards' to decide whether to
-- oversimplify (see 'warnManyGuards'); confirm at the call site.
maximum_failing_guards :: Int
maximum_failing_guards = 20
-- | A constructor pattern with no existential type variables and no
-- dictionaries.
vanillaConPattern :: DataCon -> [Type] -> PatVec -> Pattern
vanillaConPattern con arg_tys args =
  PmCon { pm_con_con = con, pm_con_arg_tys = arg_tys
        , pm_con_tvs = [], pm_con_dicts = [], pm_con_args = args }

-- | A constructor pattern with no type arguments and no sub-patterns.
nullaryConPattern :: DataCon -> Pattern
nullaryConPattern con = vanillaConPattern con [] []

-- | The pattern @True@.
truePattern :: Pattern
truePattern = nullaryConPattern trueDataCon

-- | A guard the checker cannot reason about: @True <- <unknown expression>@.
-- Used to approximate constructs that are not handled precisely.
fake_pat :: Pattern
fake_pat = PmGrd { pm_grd_pv = [truePattern], pm_grd_expr = PmExprOther EWildPat }

-- | The pattern @[]@ at element type @ty@.
nilPattern :: Type -> Pattern
nilPattern ty = vanillaConPattern nilDataCon [ty] []

-- | A @(:)@ pattern at element type @ty@ whose arguments are @xs ++ ys@.
mkListPatVec :: Type -> PatVec -> PatVec -> PatVec
mkListPatVec ty xs ys = [vanillaConPattern consDataCon [ty] (xs ++ ys)]

-- | A simple (non-overloaded) literal pattern.
mkLitPattern :: HsLit -> Pattern
mkLitPattern = PmLit . PmSLit
-- | Translate a typechecked pattern into a pattern vector for the checker.
--
-- Constructs the checker does not reason about precisely (pattern synonyms,
-- overloaded list patterns at a non-list type, and view patterns whose inner
-- pattern can fail) become a fresh variable followed by 'fake_pat'.
-- Renamer-only forms (ConPatIn, SigPatIn, SplicePat) are not expected here
-- and panic.
translatePat :: Pat Id -> UniqSM PatVec
translatePat pat = case pat of
  WildPat ty  -> mkPmVarsSM [ty]
  VarPat  id  -> return [PmVar (unLoc id)]
  ParPat p    -> translatePat (unLoc p)
  -- a lazy pattern is treated as a fresh variable of the pattern's type
  LazyPat _   -> mkPmVarsSM [hsPatType pat]
  BangPat p   -> translatePat (unLoc p)
  -- x@p: translate p, then bind x to p's value with a guard.
  -- The lazy [e] binding assumes p yields exactly one value abstraction.
  AsPat lid p -> do
     ps <- translatePat (unLoc p)
     let [e] = map valAbsToPmExpr (coercePatVec ps)
         g   = PmGrd [PmVar (unLoc lid)] e
     return (ps ++ [g])

  SigPatOut p _ty -> translatePat (unLoc p)

  -- a trivial wrapper is ignored; otherwise match p against the wrapped
  -- fresh variable with a guard
  CoPat wrapper p ty
    | isIdHsWrapper wrapper                   -> translatePat p
    | WpCast co <-  wrapper, isReflexiveCo co -> translatePat p
    | otherwise -> do
        ps      <- translatePat p
        (xp,xe) <- mkPmId2FormsSM ty
        let g = mkGuard ps (HsWrap wrapper (unLoc xe))
        return [xp,g]

  -- n+k: a fresh x; guard 1 requires (x `ge` k) to be True,
  -- guard 2 binds n to (x `minus` k)
  NPlusKPat (L _ n) k ge minus -> do
    (xp, xe) <- mkPmId2FormsSM (idType n)
    let ke = L (getLoc k) (HsOverLit (unLoc k))
        g1 = mkGuard [truePattern] (OpApp xe (noLoc ge)    no_fixity ke)
        g2 = mkGuard [PmVar n]     (OpApp xe (noLoc minus) no_fixity ke)
    return [xp, g1, g2]

  -- view pattern: if the inner pattern cannot fail, use a fresh x and the
  -- guard (ps <- lexpr x); otherwise approximate with fake_pat
  ViewPat lexpr lpat arg_ty -> do
    ps <- translatePat (unLoc lpat)
    case all cantFailPattern ps of
      True  -> do
        (xp,xe) <- mkPmId2FormsSM arg_ty
        let g = mkGuard ps (HsApp lexpr xe)
        return [xp,g]
      False -> do
        var <- mkPmVarSM arg_ty
        return [var, fake_pat]

  -- ordinary list pattern: nested (:) ending in []
  ListPat ps ty Nothing -> do
    foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec (map unLoc ps)

  -- overloaded list pattern: precise only when the pattern type really is
  -- a list of elem_ty
  ListPat lpats elem_ty (Just (pat_ty, _to_list))
    | Just e_ty <- splitListTyConApp_maybe pat_ty, elem_ty `eqType` e_ty ->
        translatePat (ListPat lpats e_ty Nothing)
    | otherwise -> do
        var <- mkPmVarSM pat_ty
        return [var, fake_pat]

  -- pattern synonyms are approximated
  ConPatOut { pat_con = L _ (PatSynCon _) } -> do
    var <- mkPmVarSM (hsPatType pat)
    return [var,fake_pat]

  ConPatOut { pat_con     = L _ (RealDataCon con)
            , pat_arg_tys = arg_tys
            , pat_tvs     = ex_tvs
            , pat_dicts   = dicts
            , pat_args    = ps } -> do
    args <- translateConPatVec arg_tys ex_tvs con ps
    return [PmCon { pm_con_con     = con
                  , pm_con_arg_tys = arg_tys
                  , pm_con_tvs     = ex_tvs
                  , pm_con_dicts   = dicts
                  , pm_con_args    = args }]

  NPat (L _ ol) mb_neg _eq -> translateNPat ol mb_neg

  -- a string literal becomes a list of character literal patterns
  LitPat lit
    | HsString src s <- lit ->
        foldr (mkListPatVec charTy) [nilPattern charTy] <$>
          translatePatVec (map (LitPat . HsChar src) (unpackFS s))
    | otherwise -> return [mkLitPattern lit]

  PArrPat ps ty -> do
    tidy_ps <- translatePatVec (map unLoc ps)
    let fake_con = parrFakeCon (length ps)
    return [vanillaConPattern fake_con [ty] (concat tidy_ps)]

  TuplePat ps boxity tys -> do
    tidy_ps <- translatePatVec (map unLoc ps)
    let tuple_con = tupleDataCon boxity (length ps)
    return [vanillaConPattern tuple_con tys (concat tidy_ps)]

  -- not expected after typechecking
  ConPatIn  {} -> panic "Check.translatePat: ConPatIn"
  SplicePat {} -> panic "Check.translatePat: SplicePat"
  SigPatIn  {} -> panic "Check.translatePat: SigPatIn"
-- | Translate an overloaded literal pattern (NPat).
--
-- Two cases become ordinary literal patterns that the checker can compare
-- precisely:
--   * non-negated string literals at type String;
--   * integral literals at type Int or Word, negated when @mb_neg@ is Just.
-- Everything else stays an overloaded literal ('PmOLit'), which records
-- whether it was negated.
translateNPat :: HsOverLit Id -> Maybe (SyntaxExpr Id) -> UniqSM PatVec
translateNPat (OverLit val False _ ty) mb_neg
  | isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg
  = translatePat (LitPat (HsString src s))
  | isIntTy ty, HsIntegral src i <- val
  = translatePat (mk_num_lit HsInt src i)
  | isWordTy ty, HsIntegral src i <- val
  = translatePat (mk_num_lit HsWordPrim src i)
  where
    -- A negated pattern such as (-1) must carry the negated value; otherwise
    -- the checker would treat -1 and 1 as the same literal.
    mk_num_lit c src i = LitPat $ case mb_neg of
      Nothing -> c src i
      Just _  -> c src (negate i)
translateNPat ol mb_neg
  = return [PmLit { pm_lit_lit = PmOLit (isJust mb_neg) ol }]
-- | Translate each pattern independently, keeping one vector per pattern.
translatePatVec :: [Pat Id] -> UniqSM [PatVec]
translatePatVec = traverse translatePat
-- | Translate the argument patterns of a data-constructor pattern.
--
-- Prefix and infix arguments are translated positionally. For record
-- patterns:
--   * no fields mentioned, or a constructor without field labels:
--     every argument becomes a fresh variable;
--   * the mentioned labels form an in-order subsequence of the constructor's
--     labels ('subsetOf'): mentioned fields get their translated pattern and
--     the other fields get fresh variables;
--   * otherwise (fields written out of order): every argument becomes a fresh
--     variable, and each mentioned field becomes a guard that matches its
--     pattern against the corresponding variable.
translateConPatVec :: [Type] -> [TyVar]
                   -> DataCon -> HsConPatDetails Id -> UniqSM PatVec
translateConPatVec _univ_tys _ex_tvs _ (PrefixCon ps)
  = concat <$> translatePatVec (map unLoc ps)
translateConPatVec _univ_tys _ex_tvs _ (InfixCon p1 p2)
  = concat <$> translatePatVec (map unLoc [p1,p2])
translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
  | null fs        = mkPmVarsSM arg_tys
  | null orig_lbls = ASSERT(null matched_lbls) mkPmVarsSM arg_tys
  | matched_lbls `subsetOf` orig_lbls
  = ASSERT(length orig_lbls == length arg_tys)
    let translateOne (lbl, ty) = case lookup lbl matched_pats of
          Just p  -> translatePat p
          Nothing -> mkPmVarsSM [ty]
    in  concatMapM translateOne (zip orig_lbls arg_tys)
  | otherwise = do
      arg_var_pats    <- mkPmVarsSM arg_tys
      translated_pats <- forM matched_pats $ \(x,pat) -> do
        pvec <- translatePat pat
        return (x, pvec)
      -- pair each field label with its fresh argument variable
      let zipped = zip orig_lbls [ x | PmVar x <- arg_var_pats ]
          guards = map (\(name,pvec) -> case lookup name zipped of
                          Just x  -> PmGrd pvec (PmExprVar x)
                          Nothing -> panic "translateConPatVec: lookup")
                       translated_pats
      return (arg_var_pats ++ guards)
  where
    -- argument types instantiated at the universal and existential types
    arg_tys      = dataConInstOrigArgTys c (univ_tys ++ mkTyVarTys ex_tvs)
    -- field selectors in declaration order
    orig_lbls    = map flSelector $ dataConFieldLabels c
    -- (field name, pattern) as written in the record pattern
    matched_pats = [ (getName (unLoc (hsRecFieldId x)), unLoc (hsRecFieldArg x))
                   | L _ x <- fs]
    matched_lbls = [ name | (name, _pat) <- matched_pats ]
-- | @subsetOf xs ys@ holds when @xs@ is a subsequence of @ys@: every element
-- of @xs@ occurs in @ys@, in the same order.
subsetOf :: Eq a => [a] -> [a] -> Bool
subsetOf needles haystack = case needles of
  []     -> True
  (n:ns) -> case dropWhile (\y -> not (n == y)) haystack of
    []       -> False
    (_:rest) -> subsetOf ns rest
-- | Translate one clause: its argument patterns, concatenated into a single
-- pattern vector, and one guard vector per guarded right-hand side.
translateMatch :: LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec])
translateMatch (L _ (Match _ lpats _ grhss)) =
  (,) <$> (concat <$> translatePatVec (map unLoc lpats))
      <*> traverse (translateGuards . guardStmts) (grhssGRHSs grhss)
  where
    guardStmts :: LGRHS Id (LHsExpr Id) -> [GuardStmt Id]
    guardStmts (L _ (GRHS gs _)) = map unLoc gs
-- | Translate a sequence of guard statements into one guard vector.
--
-- If any resulting pattern is one the checker cannot handle cheaply, all
-- such patterns are dropped and a single 'fake_pat' is put at the front.
translateGuards :: [GuardStmt Id] -> UniqSM PatVec
translateGuards guards = simplify . concat <$> mapM translateGuard guards
  where
    simplify :: PatVec -> PatVec
    simplify gv
      | all keep gv = gv
      | otherwise   = fake_pat : filter keep gv

    -- Kept: variables; constructors of single-constructor types whose
    -- arguments are all kept; and guards whose patterns are all kept or
    -- whose expression is not PmExprOther.
    keep :: Pattern -> Bool
    keep pat = case pat of
      PmVar {} -> True
      PmCon { pm_con_con = con, pm_con_args = args } ->
        length (allConstructors con) == 1 && all keep args
      PmGrd { pm_grd_pv = pv, pm_grd_expr = e } ->
        all keep pv || isNotPmExprOther e
      _ -> False
-- | A pattern that cannot fail to match: a variable; a constructor of a
-- single-constructor type whose arguments cannot fail; or a guard whose
-- patterns cannot fail. Literals can always fail.
cantFailPattern :: Pattern -> Bool
cantFailPattern pat = case pat of
  PmVar {} -> True
  PmCon { pm_con_con = con, pm_con_args = args } ->
    length (allConstructors con) == 1 && all cantFailPattern args
  PmGrd { pm_grd_pv = pv } -> all cantFailPattern pv
  _ -> False
-- | Translate one guard statement. Only boolean, let and bind statements are
-- expected in guards; any other statement panics.
translateGuard :: GuardStmt Id -> UniqSM PatVec
translateGuard stmt = case stmt of
  BodyStmt e _ _ _   -> translateBoolGuard e
  LetStmt binds      -> translateLet (unLoc binds)
  BindStmt p e _ _   -> translateBind p e
  LastStmt {}        -> panic "translateGuard LastStmt"
  ParStmt {}         -> panic "translateGuard ParStmt"
  TransStmt {}       -> panic "translateGuard TransStmt"
  RecStmt {}         -> panic "translateGuard RecStmt"
  ApplicativeStmt {} -> panic "translateGuard ApplicativeLastStmt"
-- | Let-bindings in guards contribute no patterns.
translateLet :: HsLocalBinds Id -> UniqSM PatVec
translateLet = const (return [])

-- | @p <- e@ becomes a single guard matching @p@'s pattern vector against @e@.
translateBind :: LPat Id -> LHsExpr Id -> UniqSM PatVec
translateBind (L _ p) e = (\ps -> [mkGuard ps (unLoc e)]) <$> translatePat p

-- | A boolean guard becomes @True <- e@. Guards that 'isTrueLHsExpr'
-- recognises as always true produce nothing.
translateBoolGuard :: LHsExpr Id -> UniqSM PatVec
translateBoolGuard e = return $ case isTrueLHsExpr e of
  Just _  -> []
  Nothing -> [mkGuard [truePattern] (unLoc e)]
-- | Compute the (covered, uncovered, divergent) value sets produced by a
-- clause's guard vectors. Each guard vector is checked against what the
-- previous guard vectors left uncovered, starting from 'Singleton'.
--
-- * No guards: everything is covered; nothing is uncovered or divergent.
-- * Some guard vector is empty (for example a guard that
--   'translateBoolGuard' recognises as always true): the result is
--   (Singleton, Empty, Singleton).
-- * @oversimplify@: all guards are replaced by a single 'fake_pat' guard.
process_guards :: UniqSupply -> Bool -> [PatVec]
               -> (ValSetAbs, ValSetAbs, ValSetAbs)
process_guards _us _oversimplify [] = (Singleton, Empty, Empty)
process_guards us  oversimplify  gs
  | any null gs  = (Singleton, Empty, Singleton)
  | oversimplify = go us Singleton [[fake_pat]]
  | otherwise    = go us Singleton gs
  where
    -- Thread the still-uncovered set through the guard vectors and union the
    -- covered and divergent parts of each one.
    go _usupply missing []       = (Empty, missing, Empty)
    go usupply  missing (gv:gvs) = (mkUnion cs css, uss, mkUnion ds dss)
      where
        (us1, us2, us3, us4) = splitUniqSupply4 usupply
        cs = covered   us1 Singleton gv missing
        us = uncovered us2 Empty     gv missing
        ds = divergent us3 Empty     gv missing
        (css, uss, dss) = go us4 us gvs
-- | The type of the value a pattern or value abstraction stands for.
-- For a guard, this is the type of the single non-guard pattern in its
-- pattern vector; the ASSERT checks there is exactly one.
pmPatType :: PmPat p -> Type
pmPatType (PmCon { pm_con_con = con, pm_con_arg_tys = tys })
  = mkTyConApp (dataConTyCon con) tys
pmPatType (PmVar  { pm_var_id  = x }) = idType x
pmPatType (PmLit  { pm_lit_lit = l }) = pmLitType l
pmPatType (PmNLit { pm_lit_id  = x }) = idType x
pmPatType (PmGrd  { pm_grd_pv  = pv })
  = ASSERT(patVecArity pv == 1) (pmPatType p)
  where Just p = find ((==1) . patternArity) pv
-- | @mkOneConFull x usupply con@ builds a value abstraction for @con@ at the
-- type of @x@. Its arguments are fresh variables and its existential type
-- variables are freshened. It is returned together with:
--   * the term constraint @x ~ <that constructor value>@;
--   * a 'TyConstraint' holding fresh evidence variables for the
--     constructor's instantiated equality and theta constraints, if any.
mkOneConFull :: Id -> UniqSupply -> DataCon -> (ValAbs, [PmConstraint])
mkOneConFull x usupply con = (con_abs, constraints)
  where
    (usupply1, usupply2, usupply3) = splitUniqSupply3 usupply

    res_ty = idType x
    (univ_tvs, ex_tvs, eq_spec, thetas, arg_tys, _) = dataConFullSig con
    data_tc = dataConTyCon con
    -- x's type must be an application of con's type constructor
    tc_args = case splitTyConApp_maybe res_ty of
                Just (tc, tys) -> ASSERT( tc == data_tc ) tys
                Nothing -> pprPanic "mkOneConFull: Not TyConApp:" (ppr res_ty)
    -- instantiate the universals from x's type; clone the existentials
    subst1  = zipTopTCvSubst univ_tvs tc_args
    (subst, ex_tvs') = cloneTyVarBndrs subst1 ex_tvs usupply1

    arguments  = mkConVars usupply2 (substTys subst arg_tys)
    theta_cs   = substTheta subst (eqSpecPreds eq_spec ++ thetas)
    evvars     = zipWith (nameType "pm") (listSplitUniqSupply usupply3) theta_cs
    con_abs    = PmCon { pm_con_con     = con
                       , pm_con_arg_tys = tc_args
                       , pm_con_tvs     = ex_tvs'
                       , pm_con_dicts   = evvars
                       , pm_con_args    = arguments }

    -- omit an empty TyConstraint
    constraints
      | null evvars = [ TmConstraint (PmExprVar x) (valAbsToPmExpr con_abs) ]
      | otherwise   = [ TmConstraint (PmExprVar x) (valAbsToPmExpr con_abs)
                      , TyConstraint evvars ]
-- | One fresh variable value abstraction per type, each with its own split of
-- the supply.
mkConVars :: UniqSupply -> [Type] -> [ValAbs]
mkConVars supply = zipWith mkPmVar (listSplitUniqSupply supply)
-- | Drop the first value abstraction of every vector in the set.
-- Panics if some vector is already empty.
tailVSA :: ValSetAbs -> ValSetAbs
tailVSA vsa = case vsa of
  Empty              -> Empty
  Cons _ rest        -> rest
  Constraint cs rest -> mkConstraint cs (tailVSA rest)
  Union l r          -> mkUnion (tailVSA l) (tailVSA r)
  Singleton          -> panic "tailVSA: Singleton"
-- | @wrapK con arg_tys ex_tvs dicts vsa@ takes the first
-- @dataConSourceArity con@ value abstractions of every vector in @vsa@ and
-- replaces them with one 'PmCon' for @con@ that has them as arguments.
-- The matchers use it to re-assemble a constructor after matching its
-- arguments as part of the vector.
wrapK :: DataCon -> [Type] -> [TyVar] -> [EvVar] -> ValSetAbs -> ValSetAbs
wrapK con arg_tys ex_tvs dicts = go (dataConSourceArity con) emptylist
  where
    -- go n args vsa: n arguments are still to be collected; args holds the
    -- ones collected so far, in order.
    go :: Int -> DList ValAbs -> ValSetAbs -> ValSetAbs
    go _ _    Empty = Empty
    go 0 args vsa   = PmCon { pm_con_con  = con, pm_con_arg_tys = arg_tys
                            , pm_con_tvs  = ex_tvs, pm_con_dicts = dicts
                            , pm_con_args = toList args } `mkCons` vsa
    go _ _    Singleton = panic "wrapK: Singleton"
    -- collect one more argument; the counter must go down towards 0
    go n args (Cons vs vsa)       = go (n-1) (args `snoc` vs) vsa
    go n args (Constraint cs vsa) = cs `mkConstraint` go n args vsa
    go n args (Union vsa1 vsa2)   = go n args vsa1 `mkUnion` go n args vsa2
-- | Difference list: appending one element at the end ('snoc') is cheap,
-- which is used to build value vectors left to right.
newtype DList a = DL { unDL :: [a] -> [a] }

-- | Convert to an ordinary list.
toList :: DList a -> [a]
toList dl = unDL dl []

-- | The empty difference list.
emptylist :: DList a
emptylist = DL (\rest -> rest)

-- | Append one element at the end.
infixl 9 `snoc`
snoc :: DList a -> a -> DList a
snoc xs x = DL (\rest -> unDL xs (x : rest))
-- | Smart constructor for 'Constraint'. Constraints on 'Empty' vanish, and
-- nested constraint nodes are merged into one.
mkConstraint :: [PmConstraint] -> ValSetAbs -> ValSetAbs
mkConstraint cs vsa = case vsa of
  Empty              -> Empty
  Constraint cs2 sub -> Constraint (cs ++ cs2) sub
  _                  -> Constraint cs vsa

-- | Smart constructor for 'Union' that drops 'Empty' operands.
mkUnion :: ValSetAbs -> ValSetAbs -> ValSetAbs
mkUnion l r = case (l, r) of
  (Empty, _) -> r
  (_, Empty) -> l
  _          -> Union l r

-- | Smart constructor for 'Cons': prefixing 'Empty' stays 'Empty'.
mkCons :: ValAbs -> ValSetAbs -> ValSetAbs
mkCons va vsa = case vsa of
  Empty -> Empty
  _     -> Cons va vsa
-- | A guard matching pattern vector @pv@ against expression @e@.
mkGuard :: PatVec -> HsExpr Id -> Pattern
mkGuard pv = PmGrd pv . hsExprToPmExpr

-- | A variable pattern or value abstraction with a fresh 'Id' of type @ty@.
mkPmVar :: UniqSupply -> Type -> PmPat p
mkPmVar usupply = PmVar . mkPmId usupply

-- | 'mkPmVar' in 'UniqSM'.
mkPmVarSM :: Type -> UniqSM Pattern
mkPmVarSM ty = do
  usupply <- getUniqueSupplyM
  return (mkPmVar usupply ty)

-- | One fresh variable pattern per type.
mkPmVarsSM :: [Type] -> UniqSM PatVec
mkPmVarsSM = traverse mkPmVarSM

-- | A fresh local 'Id' of the given type, named after its unique.
mkPmId :: UniqSupply -> Type -> Id
mkPmId usupply ty = mkLocalId name ty
  where
    uniq = uniqFromSupply usupply
    name = mkInternalName uniq (mkVarOccFS (fsLit (show uniq))) noSrcSpan

-- | A fresh variable returned in two forms: as a pattern and as an expression.
mkPmId2FormsSM :: Type -> UniqSM (Pattern, LHsExpr Id)
mkPmId2FormsSM ty = toForms . flip mkPmId ty <$> getUniqueSupplyM
  where
    toForms x = (PmVar x, noLoc (HsVar (noLoc x)))
-- | Turn a value abstraction into an expression for the term oracle.
-- A negative literal ('PmNLit') is represented by its variable.
valAbsToPmExpr :: ValAbs -> PmExpr
valAbsToPmExpr va = case va of
  PmCon c _ _ _ args -> PmExprCon c (map valAbsToPmExpr args)
  PmVar x            -> PmExprVar x
  PmLit l            -> PmExprLit l
  PmNLit x _         -> PmExprVar x

-- | View a pattern vector as a value vector by dropping all guards.
coercePatVec :: PatVec -> ValVecAbs
coercePatVec = concatMap coercePmPat

-- | View one pattern as value abstractions: guards disappear, and all other
-- patterns are rebuilt with the same contents, recursing into constructor
-- arguments.
coercePmPat :: Pattern -> [ValAbs]
coercePmPat pat = case pat of
  PmVar x                    -> [PmVar x]
  PmLit l                    -> [PmLit l]
  PmCon con tys tvs ds args  -> [PmCon con tys tvs ds (coercePatVec args)]
  PmGrd {}                   -> []
-- | Placeholder fixity for the 'OpApp's built by the checker; it panics if
-- it is ever forced.
no_fixity :: a
no_fixity = panic "Check: no fixity"

-- | All constructors of the data type that the given constructor belongs to.
allConstructors :: DataCon -> [DataCon]
allConstructors con = tyConDataCons (dataConTyCon con)

-- | An evidence variable with the given name and type.
newEvVar :: Name -> Type -> EvVar
newEvVar name = mkLocalId name . toTcType

-- | A fresh evidence variable named @<name>_<unique>@.
nameType :: String -> UniqSupply -> Type -> EvVar
nameType name usupply = newEvVar idname
  where
    uniq   = uniqFromSupply usupply
    idname = mkInternalName uniq
                            (mkVarOccFS (fsLit (name ++ "_" ++ show uniq)))
                            noSrcSpan
-- | Partition constraints into type evidence, term equations and the (at most
-- one) bottom constraint, keeping their original order.
splitConstraints :: [PmConstraint] -> ([EvVar], [(PmExpr, PmExpr)], Maybe Id)
splitConstraints = foldr step ([], [], Nothing)
  where
    -- the lazy pattern keeps the recursion as lazy as in a where-binding
    step c ~(ty_cs, tm_cs, bot_cs) = case c of
      TyConstraint cs    -> (cs ++ ty_cs, tm_cs, bot_cs)
      TmConstraint e1 e2 -> (ty_cs, (e1,e2):tm_cs, bot_cs)
      BtConstraint cs    -> ASSERT(isNothing bot_cs)
                            (ty_cs, tm_cs, Just cs)
-- | Does the set contain at least one satisfiable vector?
anySatVSA :: ValSetAbs -> PmM Bool
anySatVSA = fmap notNull . pruneVSABound 1

-- | Up to @maximum_output + 1@ satisfiable vectors. The extra one lets the
-- warning printer tell that the output was cut off.
pruneVSA :: ValSetAbs -> PmM [([PmExpr], [ComplexEq])]
pruneVSA = pruneVSABound (maximum_output + 1)
-- | Apply the term substitution to every value abstraction of a vector.
substInValVecAbs :: PmVarEnv -> ValVecAbs -> [PmExpr]
substInValVecAbs subst = map (\va -> exprDeepLookup subst (valAbsToPmExpr va))

-- | Combine two optional bottom constraints. Having both is a panic.
mergeBotCs :: Maybe Id -> Maybe Id -> Maybe Id
mergeBotCs lhs rhs = case (lhs, rhs) of
  (Just _, Just _)  -> panic "mergeBotCs: two (x ~ _|_) constraints"
  (Just x, Nothing) -> Just x
  (Nothing, other)  -> other

-- | Extract the residual equations and the flattened substitution from a
-- term-oracle state.
wrapUpTmState :: TmState -> ([ComplexEq], PmVarEnv)
wrapUpTmState tm_state = case tm_state of
  (residual, (_, subst)) -> (residual, flattenPmVarEnv subst)
-- | Enumerate at most @n@ value vectors of the set whose constraints are
-- satisfiable.
--
-- Term constraints are solved incrementally with 'tmOracle'. A bottom
-- constraint is kept only if 'canDiverge' allows it. Type constraints are
-- collected and checked with 'tyOracle' once a whole vector has been built.
-- Each result is the vector, with the term substitution applied, together
-- with the residual equations.
pruneVSABound :: Int -> ValSetAbs -> PmM [([PmExpr], [ComplexEq])]
pruneVSABound n v = go n init_cs emptylist v
  where
    init_cs :: ([EvVar], TmState, Maybe Id)
    init_cs = ([], initialTmState, Nothing)

    -- go budget constraints prefix vsa: budget = number of vectors still
    -- wanted; prefix = value abstractions collected so far.
    go :: Int -> ([EvVar], TmState, Maybe Id) -> DList ValAbs
       -> ValSetAbs -> PmM [([PmExpr], [ComplexEq])]
    go budget all_cs@(ty_cs, tm_env, bot_ct) vec in_vsa
      | budget <= 0 = return []
      | otherwise = case in_vsa of
          Empty -> return []
          Union vsa1 vsa2 -> do
            vecs1 <- go budget all_cs vec vsa1
            -- only ask the right branch for what the left did not provide
            vecs2 <- go (budget - length vecs1) all_cs vec vsa2
            return (vecs1 ++ vecs2)
          Singleton -> do
            sat <- tyOracle (listToBag ty_cs)
            return $ case sat of
              True  -> let (residual_eqs, subst) = wrapUpTmState tm_env
                           vector = substInValVecAbs subst (toList vec)
                       in  [(vector, residual_eqs)]
              False -> []
          Constraint cs vsa -> case splitConstraints cs of
            (new_ty_cs, new_tm_cs, new_bot_ct) ->
              case tmOracle tm_env new_tm_cs of
                Just new_tm_env ->
                  let bot = mergeBotCs new_bot_ct bot_ct
                      ans = case bot of
                              Nothing -> True
                              Just b  -> canDiverge b new_tm_env
                  in  case ans of
                        True  -> go budget (new_ty_cs++ty_cs,new_tm_env,bot) vec vsa
                        False -> return []
                Nothing -> return []
          Cons va vsa -> go budget all_cs (vec `snoc` va) vsa
-- | Maximum number of equations or uncovered vectors listed in one warning.
-- 'pruneVSA' asks for one more, so that 'dots' can show "..." when the list
-- was cut off.
maximum_output :: Int
maximum_output = 4
-- | Are the given type constraints satisfiable? This is decided by
-- 'tcCheckSatisfiability'. It panics, printing the solver's error messages,
-- if the solver produces no answer.
tyOracle :: Bag EvVar -> PmM Bool
tyOracle evs = do
  ((_warns, errs), res) <- initTcDsForSolver (tcCheckSatisfiability evs)
  maybe (pprPanic "tyOracle" (vcat (pprErrMsgBagWithLoc errs))) return res
type PmArity = Int

-- | Number of values a pattern vector is matched against. Guards do not
-- consume a value and so count as 0.
patVecArity :: PatVec -> PmArity
patVecArity = foldr (\p acc -> patternArity p + acc) 0

-- | 0 for a guard, 1 for every other pattern.
patternArity :: Pattern -> PmArity
patternArity pat = case pat of
  PmGrd {} -> 0
  _        -> 1
-- | Process one clause (its pattern vector and guard vectors) against the
-- values still uncovered.
--
-- Returns whether any vector is covered, whether any vector may diverge, and
-- the new uncovered set. The guard results from 'process_guards' are what
-- each traversal returns once the clause's patterns have been matched.
patVectProc :: Bool -> (PatVec, [PatVec]) -> ValSetAbs
            -> PmM (Bool, Bool, ValSetAbs)
patVectProc oversimplify (vec,gvs) vsa = do
  us <- getUniqueSupplyM
  let (c_def, u_def, d_def) = process_guards us oversimplify gvs
  (usC, usU, usD) <- getUniqueSupplyM3
  mb_c <- anySatVSA (covered   usC c_def vec vsa)
  mb_d <- anySatVSA (divergent usD d_def vec vsa)
  -- the uncovered set is returned unpruned
  let vsa' = uncovered usU u_def vec vsa
  return (mb_c, mb_d, vsa')
-- | 'pmTraverse' specialised to the covered ('cMatcher'), uncovered
-- ('uMatcher') and divergent ('dMatcher') matchers.
covered, uncovered, divergent :: UniqSupply -> ValSetAbs
                              -> PatVec -> ValSetAbs -> ValSetAbs
covered   us gvsa = pmTraverse us gvsa cMatcher
uncovered us gvsa = pmTraverse us gvsa uMatcher
divergent us gvsa = pmTraverse us gvsa dMatcher
-- | Match a pattern vector against a value set, using @rec@ to handle each
-- (pattern, value abstraction) pair.
--
-- When the pattern vector runs out at the end of a vector, the result is
-- @gvsa@ (the result computed for the clause's guards).
pmTraverse :: UniqSupply
           -> ValSetAbs   -- result at the end of the vector (gvsa)
           -> PmMatcher   -- matcher for one pattern vs one value abstraction
           -> PatVec
           -> ValSetAbs
           -> ValSetAbs
pmTraverse _us _gvsa _rec _vec Empty      = Empty
pmTraverse _us  gvsa _rec []   Singleton  = gvsa
pmTraverse _us _gvsa _rec []   (Cons _ _) = panic "pmTraverse: cons"
pmTraverse  us  gvsa  rec vec  (Union vsa1 vsa2)
  = mkUnion (pmTraverse us1 gvsa rec vec vsa1)
            (pmTraverse us2 gvsa rec vec vsa2)
  where (us1, us2) = splitUniqSupply us
pmTraverse us gvsa rec vec (Constraint cs vsa)
  = mkConstraint cs (pmTraverse us gvsa rec vec vsa)
pmTraverse us gvsa rec (p:ps) vsa
  | PmGrd pv e <- p
  = -- Guard: bind a fresh y with y ~ e, match pv followed by ps against
    -- y followed by vsa, then drop y again with tailVSA.
    let (us1, us2) = splitUniqSupply us
        y  = mkPmId us1 (pmPatType p)
        cs = [TmConstraint (PmExprVar y) e]
    in  mkConstraint cs $ tailVSA $
          pmTraverse us2 gvsa rec (pv++ps) (PmVar y `mkCons` vsa)
  | Cons va vsa <- vsa = rec us gvsa p ps va vsa
  | otherwise          = panic "pmTraverse: singleton"
-- | A matcher handles one pattern @p@ (with the remaining patterns @ps@)
-- against one value abstraction @va@ (with the remaining set @vsa@). It is
-- also given a unique supply and the guard result set @gvsa@ that is passed
-- on to 'pmTraverse'.
type PmMatcher =  UniqSupply
               -> ValSetAbs
               -> Pattern -> PatVec
               -> ValAbs -> ValSetAbs
               -> ValSetAbs

-- | The matchers for the covered, uncovered and divergent sets.
cMatcher, uMatcher, dMatcher :: PmMatcher
-- Covered matcher: produces the values that the pattern matches.

-- Variable pattern: always matches; record x ~ va.
cMatcher us gvsa (PmVar x) ps va vsa
  = va `mkCons` (cs `mkConstraint` covered us gvsa ps vsa)
  where cs = [TmConstraint (PmExprVar x) (valAbsToPmExpr va)]
-- Literal pattern vs constructor value: matches under l ~ va.
cMatcher us gvsa (PmLit l) ps (va@(PmCon {})) vsa
  = va `mkCons` (cs `mkConstraint` covered us gvsa ps vsa)
  where cs = [ TmConstraint (PmExprLit l) (valAbsToPmExpr va) ]
-- Constructor pattern vs literal value: replace the literal with a full
-- constructor value for a fresh y (with y ~ l) and match again.
cMatcher us gvsa (p@(PmCon {})) ps (PmLit l) vsa
  = cMatcher us3 gvsa p ps con_abs (mkConstraint cs vsa)
  where
    (us1, us2, us3)   = splitUniqSupply3 us
    y                 = mkPmId us1 (pmPatType p)
    (con_abs, all_cs) = mkOneConFull y us2 (pm_con_con p)
    cs = TmConstraint (PmExprVar y) (PmExprLit l) : all_cs
-- Constructor pattern vs negative literal: refine x to the constructor.
cMatcher us gvsa (p@(PmCon { pm_con_con = con })) ps
                 (PmNLit { pm_lit_id = x }) vsa
  = cMatcher us2 gvsa p ps con_abs (mkConstraint all_cs vsa)
  where
    (us1, us2)        = splitUniqSupply us
    (con_abs, all_cs) = mkOneConFull x us1 con
-- Constructor vs constructor: different constructors cover nothing; the
-- same constructor matches its arguments, then re-wraps them with wrapK.
cMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
                 (PmCon { pm_con_con = c2, pm_con_args = args2 }) vsa
  | c1 /= c2  = Empty
  | otherwise = wrapK c1 (pm_con_arg_tys p)
                         (pm_con_tvs p)
                         (pm_con_dicts p)
                         (covered us gvsa (args1 ++ ps)
                                          (foldr mkCons vsa args2))
-- Literal vs literal: covered only if equal.
cMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
  True  -> va `mkCons` covered us gvsa ps vsa
  False -> Empty
-- Constructor pattern vs variable: refine x to that constructor.
cMatcher us gvsa (p@(PmCon { pm_con_con = con })) ps (PmVar x) vsa
  = cMatcher us2 gvsa p ps con_abs (mkConstraint all_cs vsa)
  where
    (us1, us2)        = splitUniqSupply us
    (con_abs, all_cs) = mkOneConFull x us1 con
-- Literal pattern vs variable: refine x to the literal.
cMatcher us gvsa (p@(PmLit l)) ps (PmVar x) vsa
  = cMatcher us gvsa p ps lit_abs (mkConstraint cs vsa)
  where
    lit_abs = PmLit l
    cs      = [TmConstraint (PmExprVar x) (PmExprLit l)]
-- Literal pattern vs negative literal: possible only if l is not one of
-- the excluded literals.
cMatcher us gvsa (p@(PmLit l)) ps
                 (PmNLit { pm_lit_id = x, pm_lit_not = lits }) vsa
  | all (not . eqPmLit l) lits
  = cMatcher us gvsa p ps lit_abs (mkConstraint cs vsa)
  | otherwise = Empty
  where
    lit_abs = PmLit l
    cs      = [TmConstraint (PmExprVar x) (PmExprLit l)]
-- Guards are handled by pmTraverse and should never reach a matcher.
cMatcher _ _ (PmGrd {}) _ _ _ = panic "Check.cMatcher: Guard"
-- Uncovered matcher: produces the values that the pattern does NOT match.

-- Variable pattern: always matches, so only the rest can leave values
-- uncovered.
uMatcher us gvsa (PmVar x) ps va vsa
  = va `mkCons` (cs `mkConstraint` uncovered us gvsa ps vsa)
  where cs = [TmConstraint (PmExprVar x) (valAbsToPmExpr va)]
-- Literal pattern vs constructor value: treat the literal pattern as a
-- fresh variable y with y ~ l.
uMatcher us gvsa (PmLit l) ps (va@(PmCon {})) vsa
  = uMatcher us2 gvsa (PmVar y) ps va (mkConstraint cs vsa)
  where
    (us1, us2) = splitUniqSupply us
    y  = mkPmId us1 (pmPatType va)
    cs = [TmConstraint (PmExprVar y) (PmExprLit l)]
-- Constructor pattern vs literal value: treat the literal value as a fresh
-- variable y with y ~ l.
uMatcher us gvsa (p@(PmCon {})) ps (PmLit l) vsa
  = uMatcher us2 gvsa p ps (PmVar y) (mkConstraint cs vsa)
  where
    (us1, us2) = splitUniqSupply us
    y  = mkPmId us1 (pmPatType p)
    cs = [TmConstraint (PmExprVar y) (PmExprLit l)]
-- Constructor pattern vs negative literal: treat it as its variable.
uMatcher us gvsa (p@(PmCon {})) ps (PmNLit { pm_lit_id = x }) vsa
  = uMatcher us gvsa p ps (PmVar x) vsa
-- Constructor vs constructor: a different constructor stays uncovered;
-- the same constructor recurses into its arguments and re-wraps them.
uMatcher us gvsa ( p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
                 (va@(PmCon { pm_con_con = c2, pm_con_args = args2 })) vsa
  | c1 /= c2  = va `mkCons` vsa
  | otherwise = wrapK c1 (pm_con_arg_tys p)
                         (pm_con_tvs p)
                         (pm_con_dicts p)
                         (uncovered us gvsa (args1 ++ ps)
                                            (foldr mkCons vsa args2))
-- Literal vs literal: equal literals recurse; different ones stay uncovered.
uMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
  True  -> va `mkCons` uncovered us gvsa ps vsa
  False -> va `mkCons` vsa
-- Constructor pattern vs variable: split x into every constructor of its
-- type and match the pattern against each alternative.
uMatcher us gvsa (p@(PmCon { pm_con_con = con })) ps (PmVar x) vsa
  = uncovered us2 gvsa (p : ps) inst_vsa
  where
    (us1, us2) = splitUniqSupply us
    cons_cs  = zipWith (mkOneConFull x) (listSplitUniqSupply us1)
                                        (allConstructors con)
    add_one (va,cs) valset = mkUnion valset (va `mkCons` mkConstraint cs vsa)
    inst_vsa = foldr add_one Empty cons_cs
-- Literal pattern vs variable: either x is l (continue matching), or x is
-- not l (stays uncovered as a negative literal).
uMatcher us gvsa (p@(PmLit l)) ps (PmVar x) vsa
  = mkUnion (uMatcher us gvsa p ps (PmLit l) (mkConstraint match_cs vsa))
            (non_match_cs `mkConstraint` (PmNLit x [l] `mkCons` vsa))
  where
    match_cs     = [ TmConstraint (PmExprVar x) (PmExprLit l)]
    non_match_cs = [ TmConstraint falsePmExpr
                                  (PmExprEq (PmExprVar x) (PmExprLit l)) ]
-- Literal pattern vs negative literal: as for a variable, but l is added
-- to the excluded literals. If l is already excluded, the value stays
-- uncovered as it is.
uMatcher us gvsa (p@(PmLit l)) ps
                 (va@(PmNLit { pm_lit_id = x, pm_lit_not = lits })) vsa
  | all (not . eqPmLit l) lits
  = mkUnion (uMatcher us gvsa p ps (PmLit l) (mkConstraint match_cs vsa))
            (non_match_cs `mkConstraint` (PmNLit x (l:lits) `mkCons` vsa))
  | otherwise = va `mkCons` vsa
  where
    match_cs     = [ TmConstraint (PmExprVar x) (PmExprLit l)]
    non_match_cs = [ TmConstraint falsePmExpr
                                  (PmExprEq (PmExprVar x) (PmExprLit l)) ]
-- Guards are handled by pmTraverse and should never reach a matcher.
uMatcher _ _ (PmGrd {}) _ _ _ = panic "Check.uMatcher: Guard"
-- Divergent matcher: produces the values on which matching may diverge.
-- Forcing a variable may diverge, and that is recorded with BtConstraint.

-- Variable pattern: no forcing here; continue with the rest.
dMatcher us gvsa (PmVar x) ps va vsa
  = va `mkCons` (cs `mkConstraint` divergent us gvsa ps vsa)
  where cs = [TmConstraint (PmExprVar x) (valAbsToPmExpr va)]
-- Literal pattern vs constructor value: continue under l ~ va.
dMatcher us gvsa (PmLit l) ps (va@(PmCon {})) vsa
  = va `mkCons` (cs `mkConstraint` divergent us gvsa ps vsa)
  where cs = [ TmConstraint (PmExprLit l) (valAbsToPmExpr va) ]
-- Constructor pattern vs literal value: replace the literal with a full
-- constructor value for a fresh y (with y ~ l) and match again.
dMatcher us gvsa (p@(PmCon { pm_con_con = con })) ps (PmLit l) vsa
  = dMatcher us3 gvsa p ps con_abs (mkConstraint cs vsa)
  where
    (us1, us2, us3)   = splitUniqSupply3 us
    y                 = mkPmId us1 (pmPatType p)
    (con_abs, all_cs) = mkOneConFull y us2 con
    cs = TmConstraint (PmExprVar y) (PmExprLit l) : all_cs
-- Constructor pattern vs negative literal: refine x to the constructor.
dMatcher us gvsa (p@(PmCon { pm_con_con = con })) ps
                 (PmNLit { pm_lit_id = x }) vsa
  = dMatcher us2 gvsa p ps con_abs (mkConstraint all_cs vsa)
  where
    (us1, us2)        = splitUniqSupply us
    (con_abs, all_cs) = mkOneConFull x us1 con
-- Constructor vs constructor: a different constructor cannot diverge here;
-- the same constructor recurses into its arguments and re-wraps them.
dMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
                 (PmCon { pm_con_con = c2, pm_con_args = args2 }) vsa
  | c1 /= c2  = Empty
  | otherwise = wrapK c1 (pm_con_arg_tys p)
                         (pm_con_tvs p)
                         (pm_con_dicts p)
                         (divergent us gvsa (args1 ++ ps)
                                            (foldr mkCons vsa args2))
-- Literal vs literal: continue only if equal.
dMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
  True  -> va `mkCons` divergent us gvsa ps vsa
  False -> Empty
-- Constructor pattern vs variable: either x is bottom (diverges here), or
-- refine x to the constructor and continue.
dMatcher us gvsa (p@(PmCon { pm_con_con = con })) ps (PmVar x) vsa
  = mkUnion (PmVar x `mkCons` mkConstraint [BtConstraint x] vsa)
            (dMatcher us2 gvsa p ps con_abs (mkConstraint all_cs vsa))
  where
    (us1, us2)        = splitUniqSupply us
    (con_abs, all_cs) = mkOneConFull x us1 con
-- Literal pattern vs variable: either x is bottom, or x ~ l and continue.
dMatcher us gvsa (PmLit l) ps (PmVar x) vsa
  = mkUnion (PmVar x `mkCons` mkConstraint [BtConstraint x] vsa)
            (dMatcher us gvsa (PmLit l) ps (PmLit l) (mkConstraint cs vsa))
  where
    cs = [TmConstraint (PmExprVar x) (PmExprLit l)]
-- Literal pattern vs negative literal: continue as l, unless l is one of
-- the excluded literals.
dMatcher us gvsa (p@(PmLit l)) ps
                 (PmNLit { pm_lit_id = x, pm_lit_not = lits }) vsa
  | all (not . eqPmLit l) lits
  = dMatcher us gvsa p ps lit_abs (mkConstraint cs vsa)
  | otherwise = Empty
  where
    lit_abs = PmLit l
    cs = [TmConstraint (PmExprVar x) (PmExprLit l)]
-- Guards are handled by pmTraverse and should never reach a matcher.
dMatcher _ _ (PmGrd {}) _ _ _ = panic "Check.dMatcher: Guard"
-- | Term equations for a case alternative with a single pattern @p@ and a
-- single variable @var@: @var@ equals both the pattern (as an expression) and
-- the scrutinee. With no scrutinee the result is empty; any other shape
-- panics.
genCaseTmCs2 :: Maybe (LHsExpr Id)
             -> [Pat Id]
             -> [Id]
             -> DsM (Bag SimpleEq)
genCaseTmCs2 mb_scr pats vars = case (mb_scr, pats, vars) of
  (Nothing, _, _) -> return emptyBag
  (Just scr, [p], [var]) -> liftUs $ do
    [e] <- map valAbsToPmExpr . coercePatVec <$> translatePat p
    let scrutinee = lhsExprToPmExpr scr
    return $ listToBag [(var, e), (var, scrutinee)]
  _ -> panic "genCaseTmCs2: HsCase"

-- | The single equation @var ~ scrutinee@, or nothing when there is no
-- scrutinee. Any other shape panics.
genCaseTmCs1 :: Maybe (LHsExpr Id) -> [Id] -> Bag SimpleEq
genCaseTmCs1 mb_scr vars = case (mb_scr, vars) of
  (Nothing, _)      -> emptyBag
  (Just scr, [var]) -> unitBag (var, lhsExprToPmExpr scr)
  _                 -> panic "genCaseTmCs1: HsCase"
-- | Is either the overlapping-patterns warning or the incompleteness warning
-- for this kind of match enabled?
isAnyPmCheckEnabled :: DynFlags -> DsMatchContext -> Bool
isAnyPmCheckEnabled dflags (DsMatchContext kind _loc) = or
  [ wopt Opt_WarnOverlappingPatterns dflags
  , exhaustive dflags kind ]
-- | Warn, at the match's location, that guard checking was over-simplified
-- because there were too many guards. The message suggests the flags for
-- suppressing the warning or running the full checker.
warnManyGuards :: DsMatchContext -> DsM ()
warnManyGuards (DsMatchContext kind loc) = putSrcSpanDs loc (warnDs msg)
  where
    msg = vcat [ sep [ ptext (sLit "Too many guards in") <+> pprMatchContext kind
                     , ptext (sLit "Guard checking has been over-simplified") ]
               , parens (ptext (sLit "Use:") <+> (suppress $$ full)) ]
    suppress = hang (ptext (sLit "-Wno-too-many-guards")) 2 $
                 ptext (sLit "to suppress this warning")
    full = hang (ptext (sLit "-ffull-guard-reasoning")) 2 $ vcat
             [ ptext (sLit "to run the full checker (may increase")
             , ptext (sLit "compilation time and memory consumption)") ]
-- | Report pattern-match warnings for one match context.
--
-- The check @mPmResult@ runs only if at least one relevant warning is on:
-- 'Opt_WarnOverlappingPatterns' (redundant clauses and inaccessible
-- right-hand sides) or the incompleteness flag chosen by 'exhaustive' for
-- this kind of match (uncovered patterns). At most 'maximum_output' items
-- are printed per warning.
dsPmWarn :: DynFlags -> DsMatchContext -> DsM PmResult -> DsM ()
dsPmWarn dflags ctx@(DsMatchContext kind loc) mPmResult
  = when (flag_i || flag_u) $ do
      (redundant, inaccessible, uncovered) <- mPmResult
      let exists_r = flag_i && notNull redundant
          exists_i = flag_i && notNull inaccessible
          exists_u = flag_u && notNull uncovered
      when exists_r $ putSrcSpanDs loc (warnDs (pprEqns redundant rmsg))
      when exists_i $ putSrcSpanDs loc (warnDs (pprEqns inaccessible imsg))
      when exists_u $ putSrcSpanDs loc (warnDs (pprEqnsU uncovered))
  where
    flag_i = wopt Opt_WarnOverlappingPatterns dflags
    flag_u = exhaustive dflags kind

    rmsg = "are redundant"
    imsg = "have inaccessible right hand side"

    pprEqns qs text = pp_context ctx (ptext (sLit text)) $ \f ->
      vcat (map (ppr_eqn f kind) (take maximum_output qs)) $$ dots qs

    -- a single empty vector means only the guards are non-exhaustive
    pprEqnsU qs = pp_context ctx (ptext (sLit "are non-exhaustive")) $ \_ ->
      case qs of
        [([],_)] -> ptext (sLit "Guards do not cover entire pattern space")
        _missing -> let us = map ppr_uncovered qs
                    in  hang (ptext (sLit "Patterns not matched:")) 4
                             (vcat (take maximum_output us) $$ dots us)
-- | "..." when the list is longer than 'maximum_output', otherwise nothing.
dots :: [a] -> SDoc
dots qs
  | not (qs `lengthExceeds` maximum_output) = empty
  | otherwise                               = ptext (sLit "...")
-- | Is the incompleteness warning for this kind of match enabled? Contexts
-- that are never checked for exhaustiveness return False.
exhaustive :: DynFlags -> HsMatchContext id -> Bool
exhaustive dflags ctxt = case ctxt of
  FunRhs {}   -> wopt Opt_WarnIncompletePatterns dflags
  CaseAlt     -> wopt Opt_WarnIncompletePatterns dflags
  LambdaExpr  -> wopt Opt_WarnIncompleteUniPatterns dflags
  PatBindRhs  -> wopt Opt_WarnIncompleteUniPatterns dflags
  ProcExpr    -> wopt Opt_WarnIncompleteUniPatterns dflags
  RecUpd      -> wopt Opt_WarnIncompletePatternsRecUpd dflags
  IfAlt       -> False
  ThPatSplice -> False
  PatSyn      -> False
  ThPatQuote  -> False
  StmtCtxt {} -> False
-- | Lay out a pattern-match warning:
-- "Pattern match(es) <msg>", then "In <context>:" followed by the body.
-- The body builder receives a prefixer; for a function right-hand side it
-- puts the function name in front of each equation, otherwise it is identity.
pp_context :: DsMatchContext -> SDoc -> ((SDoc -> SDoc) -> SDoc) -> SDoc
pp_context (DsMatchContext kind _loc) msg rest_of_msg_fun =
  vcat [ ptext (sLit "Pattern match(es)") <+> msg
       , sep [ ptext (sLit "In") <+> pprMatchContext kind <> char ':'
             , nest 4 (rest_of_msg_fun prefix) ] ]
  where
    prefix = case kind of
      FunRhs fun -> \pp -> ppr fun <+> pp
      _          -> id
-- | Print a clause's patterns, followed by the context's separator and "...".
ppr_pats :: HsMatchContext Name -> [Pat Id] -> SDoc
ppr_pats kind pats =
  sep [ sep (map ppr pats), matchSeparator kind, ptext (sLit "...") ]

-- | Print one equation's patterns, passed through the given prefixer.
ppr_eqn :: (SDoc -> SDoc) -> HsMatchContext Name -> [LPat Id] -> SDoc
ppr_eqn prefixF kind = prefixF . ppr_pats kind . map unLoc

-- | "<var> is not one of {l1, l2, ...}".
ppr_constraint :: (SDoc,[PmLit]) -> SDoc
ppr_constraint (var, lits) =
  var <+> ptext (sLit "is not one of") <+> braces (pprWithCommas ppr lits)
-- | Print one uncovered value vector. When the filtered residual equations
-- give "is not one of" constraints, they are listed in a trailing
-- "where" clause.
ppr_uncovered :: ([PmExpr], [ComplexEq]) -> SDoc
ppr_uncovered (expr_vec, complex)
  | null cs   = fsep vec
  | otherwise = hang (fsep vec) 4 $
                  ptext (sLit "where") <+> vcat (map ppr_constraint cs)
  where
    sdoc_vec = mapM pprPmExprWithParens expr_vec
    (vec,cs) = runPmPprM sdoc_vec (filterComplex complex)