{-# LANGUAGE TypeFamilies #-}
module TcDerivUtils (
DerivM, DerivEnv(..),
DerivSpec(..), pprDerivSpec,
DerivSpecMechanism(..), isDerivSpecStock,
isDerivSpecNewtype, isDerivSpecAnyClass,
DerivContext, DerivStatus(..),
PredOrigin(..), ThetaOrigin(..), mkPredOrigin,
mkThetaOrigin, mkThetaOriginFromPreds, substPredOrigin,
checkSideConditions, hasStockDeriving,
canDeriveAnyClass,
std_class_via_coercible, non_coercible_class,
newDerivClsInst, extendLocalInstEnv
) where
import GhcPrelude
import Bag
import BasicTypes
import Class
import DataCon
import DynFlags
import ErrUtils
import HscTypes (lookupFixity, mi_fix)
import HsSyn
import Inst
import InstEnv
import LoadIface (loadInterfaceForName)
import Module (getModule)
import Name
import Outputable
import PrelNames
import SrcLoc
import TcGenDeriv
import TcGenFunctor
import TcGenGenerics
import TcRnMonad
import TcType
import THNames (liftClassKey)
import TyCon
import Type
import Util
import VarSet
import Control.Monad.Trans.Reader
import qualified GHC.LanguageExtensions as LangExt
import ListSetOps (assocMaybe)
-- | The monad used by the deriving code: a 'ReaderT' over 'TcRn' whose
-- read-only environment is a 'DerivEnv'.
type DerivM = ReaderT DerivEnv TcRn
-- | The read-only environment carried by 'DerivM'.
--
-- NOTE(review): the field descriptions below are inferred from the field
-- names and types only; confirm them against the code that builds a
-- 'DerivEnv'.
--
--   * 'denv_overlap_mode': overlap mode for the instance, if any.
--   * 'denv_tvs': type variables (presumably those bound by the instance).
--   * 'denv_cls' / 'denv_cls_tys': the class and a list of class argument
--     types.
--   * 'denv_tc' / 'denv_tc_args': a type constructor and its arguments.
--   * 'denv_rep_tc' / 'denv_rep_tc_args': a second type constructor and its
--     arguments (presumably the representation tycon -- TODO confirm).
--   * 'denv_mtheta': the 'DerivContext'.
--   * 'denv_strat': the deriving strategy, if any.
data DerivEnv = DerivEnv
  { denv_overlap_mode :: Maybe OverlapMode
  , denv_tvs :: [TyVar]
  , denv_cls :: Class
  , denv_cls_tys :: [Type]
  , denv_tc :: TyCon
  , denv_tc_args :: [Type]
  , denv_rep_tc :: TyCon
  , denv_rep_tc_args :: [Type]
  , denv_mtheta :: DerivContext
  , denv_strat :: Maybe DerivStrategy
  }
-- | Debug rendering of a 'DerivEnv': the heading @DerivEnv@ with one row per
-- record field hanging beneath it, each row being the field's name followed
-- by its printed value.
instance Outputable DerivEnv where
  ppr (DerivEnv { denv_overlap_mode = overlap
                , denv_tvs          = tyvars
                , denv_cls          = clas
                , denv_cls_tys      = clas_tys
                , denv_tc           = tycon
                , denv_tc_args      = tycon_args
                , denv_rep_tc       = rep_tycon
                , denv_rep_tc_args  = rep_tycon_args
                , denv_mtheta       = ctxt
                , denv_strat        = strat })
    = hang (text "DerivEnv") 2 (vcat rows)
    where
      rows =
        [ row "denv_overlap_mode" overlap
        , row "denv_tvs" tyvars
        , row "denv_cls" clas
        , row "denv_cls_tys" clas_tys
        , row "denv_tc" tycon
        , row "denv_tc_args" tycon_args
        , row "denv_rep_tc" rep_tycon
        , row "denv_rep_tc_args" rep_tycon_args
        , row "denv_mtheta" ctxt
        , row "denv_strat" strat
        ]

      -- One output row: the label, a space, then the pretty-printed value.
      row :: Outputable a => String -> a -> SDoc
      row label val = text label <+> ppr val
-- | Description of one instance to be derived. The type parameter @theta@ is
-- the type of the value stored in 'ds_theta'.
--
--   * 'ds_name', 'ds_overlap', 'ds_tvs', 'ds_cls' and 'ds_tys' are what
--     'newDerivClsInst' hands to 'newClsInst' (as the dfun name, overlap
--     mode, type variables, class and class argument types respectively).
--   * 'ds_mechanism' says how the instance is produced; see
--     'DerivSpecMechanism'.
--   * 'ds_loc' and 'ds_tc': a source span and a type constructor
--     (NOTE(review): presumably the location of the deriving request and the
--     tycon being derived for -- confirm against callers).
data DerivSpec theta = DS { ds_loc :: SrcSpan
                          , ds_name :: Name
                          , ds_tvs :: [TyVar]
                          , ds_theta :: theta
                          , ds_cls :: Class
                          , ds_tys :: [Type]
                          , ds_tc :: TyCon
                          , ds_overlap :: Maybe OverlapMode
                          , ds_mechanism :: DerivSpecMechanism }
-- | Debug rendering of a 'DerivSpec': the heading @DerivSpec@ with one
-- @field = value@ row per printed field hanging beneath it. The 'ds_tc' and
-- 'ds_overlap' fields are not printed.
pprDerivSpec :: Outputable theta => DerivSpec theta -> SDoc
pprDerivSpec (DS { ds_loc       = loc
                 , ds_name      = name
                 , ds_tvs       = tyvars
                 , ds_cls       = clas
                 , ds_tys       = inst_tys
                 , ds_theta     = theta
                 , ds_mechanism = mechanism })
  = hang (text "DerivSpec") 2 (vcat rows)
  where
    rows =
      [ row "ds_loc =" loc
      , row "ds_name =" name
      , row "ds_tvs =" tyvars
      , row "ds_cls =" clas
      , row "ds_tys =" inst_tys
      , row "ds_theta =" theta
      , row "ds_mechanism =" mechanism
      ]

    -- One output row: the label (including its @=@), then the printed value.
    row :: Outputable a => String -> a -> SDoc
    row label val = text label <+> ppr val
-- | A 'DerivSpec' is printed with 'pprDerivSpec'.
instance Outputable theta => Outputable (DerivSpec theta) where
  ppr spec = pprDerivSpec spec
-- | How a derived instance is to be produced. 'mechanismToStrategy' maps the
-- three constructors to 'StockStrategy', 'NewtypeStrategy' and
-- 'AnyclassStrategy' respectively.
--
--   * 'DerivSpecStock' carries a code generator with the same function type
--     that 'hasStockDeriving' returns.
--   * 'DerivSpecNewtype' carries a 'Type' (NOTE(review): presumably the
--     newtype's representation type -- confirm against callers).
--   * 'DerivSpecAnyClass' carries no payload.
data DerivSpecMechanism
  = DerivSpecStock
      (SrcSpan -> TyCon
               -> [Type]
               -> TcM (LHsBinds GhcPs, BagDerivStuff, [Name]))
  | DerivSpecNewtype
      Type
  | DerivSpecAnyClass
-- | Constructor tests for 'DerivSpecMechanism': each returns 'True' exactly
-- when the mechanism was built with the constructor named in the predicate.
isDerivSpecStock, isDerivSpecNewtype, isDerivSpecAnyClass
  :: DerivSpecMechanism -> Bool
isDerivSpecStock mechanism = case mechanism of
  DerivSpecStock{} -> True
  _                -> False

isDerivSpecNewtype mechanism = case mechanism of
  DerivSpecNewtype{} -> True
  _                  -> False

isDerivSpecAnyClass mechanism = case mechanism of
  DerivSpecAnyClass{} -> True
  _                   -> False
-- | Forget a mechanism's payload, keeping only the 'DerivStrategy' that the
-- constructor corresponds to.
mechanismToStrategy :: DerivSpecMechanism -> DerivStrategy
mechanismToStrategy mechanism = case mechanism of
  DerivSpecStock{}    -> StockStrategy
  DerivSpecNewtype{}  -> NewtypeStrategy
  DerivSpecAnyClass{} -> AnyclassStrategy
-- | A mechanism is printed as its corresponding 'DerivStrategy'; any payload
-- is not shown.
instance Outputable DerivSpecMechanism where
  ppr mechanism = ppr (mechanismToStrategy mechanism)
-- | An optional instance context. 'cond_stdOK' skips its checks entirely
-- whenever this is a 'Just'.
--
-- NOTE(review): presumably 'Nothing' means the context is to be inferred and
-- 'Just' means it was supplied by the user -- confirm against callers.
type DerivContext = Maybe ThetaType
-- | The verdict returned by 'checkSideConditions':
--
--   * 'CanDerive': the class has side conditions, they hold, and there are
--     no visible extra class arguments.
--   * 'DerivableClassError': the class has side conditions but they failed
--     (or there were visible extra class arguments); carries the message.
--   * 'DerivableViaInstance': the class has no side conditions and
--     'canDeriveAnyClass' succeeded.
--   * 'NonDerivableClass': the class has no side conditions and
--     'canDeriveAnyClass' failed; carries its message.
data DerivStatus = CanDerive
                 | DerivableClassError SDoc
                 | DerivableViaInstance
                 | NonDerivableClass SDoc
-- | A predicate type paired with the 'CtOrigin' and 'TypeOrKind' it was
-- created with (see 'mkPredOrigin'). Only the predicate is shown when
-- printing.
data PredOrigin = PredOrigin PredType CtOrigin TypeOrKind
-- | A bundle of type variables, a context of givens, and a list of wanted
-- predicates with their origins.
--
-- NOTE(review): presumably this models an implication-like constraint
-- (@forall tvs. givens => wanteds@) -- confirm against the code that
-- consumes it.
data ThetaOrigin
  = ThetaOrigin { to_tvs :: [TyVar]
                , to_givens :: ThetaType
                , to_wanted_origins :: [PredOrigin] }
-- | Prints only the predicate type; the origin and type-or-kind tag are
-- omitted.
instance Outputable PredOrigin where
  ppr pred_origin = case pred_origin of
    PredOrigin pred_ty _ _ -> ppr pred_ty
-- | Debug rendering of a 'ThetaOrigin': the heading @ThetaOrigin@ with one
-- @field = value@ row per record field hanging beneath it.
instance Outputable ThetaOrigin where
  ppr (ThetaOrigin { to_tvs            = tyvars
                   , to_givens         = given_theta
                   , to_wanted_origins = wanteds })
    = hang (text "ThetaOrigin") 2 (vcat rows)
    where
      rows =
        [ row "to_tvs =" tyvars
        , row "to_givens =" given_theta
        , row "to_wanted_origins =" wanteds
        ]

      -- One output row: the label (including its @=@), then the printed value.
      row :: Outputable a => String -> a -> SDoc
      row label val = text label <+> ppr val
-- | Build a 'PredOrigin'. The arguments arrive in a different order from the
-- constructor's fields so that the origin and type-or-kind tag can be fixed
-- first and the result mapped over a list of predicates.
mkPredOrigin
  :: CtOrigin
  -> TypeOrKind
  -> PredType
  -> PredOrigin
mkPredOrigin ct_origin level pred_ty = PredOrigin pred_ty ct_origin level
-- | Build a 'ThetaOrigin' from type variables, givens, and a list of wanted
-- predicates. Every wanted predicate is tagged with the same origin and
-- type-or-kind flag via 'mkPredOrigin'.
mkThetaOrigin
  :: CtOrigin      -- ^ origin attached to every wanted predicate
  -> TypeOrKind    -- ^ tag attached to every wanted predicate
  -> [TyVar]       -- ^ becomes 'to_tvs'
  -> ThetaType     -- ^ becomes 'to_givens'
  -> ThetaType     -- ^ wanted predicates; become 'to_wanted_origins'
  -> ThetaOrigin
mkThetaOrigin origin t_or_k tvs givens wanteds
  = ThetaOrigin
      { to_tvs            = tvs
      , to_givens         = givens
      , to_wanted_origins = [ mkPredOrigin origin t_or_k wanted
                            | wanted <- wanteds ]
      }
-- | Wrap already-tagged wanted predicates in a 'ThetaOrigin' that has no type
-- variables and no givens.
mkThetaOriginFromPreds :: [PredOrigin] -> ThetaOrigin
mkThetaOriginFromPreds wanteds
  = ThetaOrigin
      { to_tvs            = []
      , to_givens         = []
      , to_wanted_origins = wanteds
      }
-- | Apply a substitution to the predicate inside a 'PredOrigin' (using
-- 'substTy'), leaving the origin and type-or-kind tag untouched.
substPredOrigin :: HasCallStack => TCvSubst -> PredOrigin -> PredOrigin
substPredOrigin subst pred_origin = case pred_origin of
  PredOrigin pred_ty ct_origin level ->
    PredOrigin (substTy subst pred_ty) ct_origin level
-- | Look up the stock code generator for a class, keyed on the class's
-- 'Unique'. Returns 'Nothing' when the class is not in the table below.
--
-- A generator takes a source span, a type constructor and a list of types,
-- and yields the generated bindings, a bag of auxiliary deriving output, and
-- a list of names. That list is empty for most classes; for Show, Read,
-- Generic and Generic1 it holds the record field selectors of the type
-- constructor's data constructors.
hasStockDeriving
  :: Class -> Maybe (SrcSpan
                     -> TyCon
                     -> [Type]
                     -> TcM (LHsBinds GhcPs, BagDerivStuff, [Name]))
hasStockDeriving clas
  = assocMaybe stock_generators (getUnique clas)
  where
    stock_generators
      :: [(Unique, SrcSpan
                   -> TyCon
                   -> [Type]
                   -> TcM (LHsBinds GhcPs, BagDerivStuff, [Name]))]
    stock_generators =
      [ (eqClassKey,          with_monadic_gen gen_Eq_binds)
      , (ordClassKey,         with_monadic_gen gen_Ord_binds)
      , (enumClassKey,        with_monadic_gen gen_Enum_binds)
      , (boundedClassKey,     with_pure_gen gen_Bounded_binds)
      , (ixClassKey,          with_monadic_gen gen_Ix_binds)
      , (showClassKey,        with_fixities gen_Show_binds)
      , (readClassKey,        with_fixities gen_Read_binds)
      , (dataClassKey,        with_monadic_gen gen_Data_binds)
      , (functorClassKey,     with_pure_gen gen_Functor_binds)
      , (foldableClassKey,    with_pure_gen gen_Foldable_binds)
      , (traversableClassKey, with_pure_gen gen_Traversable_binds)
      , (liftClassKey,        with_pure_gen gen_Lift_binds)
      , (genClassKey,         with_fam_inst (gen_Generic_binds Gen0))
      , (gen1ClassKey,        with_fam_inst (gen_Generic_binds Gen1))
      ]

    -- Adapt a pure generator; the instance types are ignored and no names
    -- are reported.
    with_pure_gen gen loc tycon _inst_tys
      = return (binds, aux_stuff, [])
      where
        (binds, aux_stuff) = gen loc tycon

    -- Adapt a monadic generator; the instance types are ignored and no names
    -- are reported.
    with_monadic_gen gen loc tycon _inst_tys = do
      (binds, aux_stuff) <- gen loc tycon
      return (binds, aux_stuff, [])

    -- Adapt a generator that first needs a fixity lookup function; the
    -- tycon's field selectors are reported.
    with_fixities gen loc tycon _inst_tys = do
      fixity_of <- getDataConFixityFun tycon
      let (binds, aux_stuff) = gen fixity_of loc tycon
      return (binds, aux_stuff, selector_names tycon)

    -- Adapt a generator that yields a family instance; the source span is
    -- ignored and the tycon's field selectors are reported.
    with_fam_inst gen _loc tycon inst_tys = do
      (binds, fam_inst) <- gen tycon inst_tys
      return (binds, unitBag (DerivFamInst fam_inst), selector_names tycon)

    -- The selector names of every field label of every data constructor.
    selector_names tycon =
      [ flSelector label
      | con   <- tyConDataCons tycon
      , label <- dataConFieldLabels con
      ]
-- | Produce a function that maps a name to its fixity, suitable for the
-- names belonging to the given type constructor.
--
-- If the tycon's name is local to (or from) the module being compiled, the
-- lookup goes through the current fixity environment. Otherwise the
-- interface for the tycon's name is loaded and the lookup goes through that
-- interface's fixity function, keyed on the name's 'OccName'.
getDataConFixityFun :: TyCon -> TcM (Name -> Fixity)
getDataConFixityFun tc = do
  this_mod <- getModule
  if nameIsLocalOrFrom this_mod tc_name
    then lookupFixity <$> getFixityEnv
    else do
      iface <- loadInterfaceForName why tc_name
      return (mi_fix iface . nameOccName)
  where
    tc_name = tyConName tc
    -- Description passed to the interface loader.
    why = text "Data con fixities for" <+> ppr tc_name
-- | Decide whether an instance of the class can be derived for the given
-- representation tycon.
--
-- If 'sideConditions' knows the class, its condition is run: a failure is
-- reported as 'DerivableClassError'; a success yields 'CanDerive' provided
-- no visible class arguments remain in @cls_tys@, and 'DerivableClassError'
-- otherwise. If the class is unknown to 'sideConditions', the outcome is
-- governed by 'canDeriveAnyClass': 'NonDerivableClass' when it fails,
-- 'DerivableViaInstance' when it succeeds.
checkSideConditions :: DynFlags -> DerivContext -> Class -> [TcType]
                    -> TyCon
                    -> DerivStatus
checkSideConditions dflags mtheta cls cls_tys rep_tc
  = case sideConditions mtheta cls of
      Just cond ->
        case cond dflags rep_tc of
          NotValid err -> DerivableClassError err
          IsValid
            | null visible_cls_tys -> CanDerive
            | otherwise -> DerivableClassError (classArgsErr cls cls_tys)
      Nothing ->
        case canDeriveAnyClass dflags of
          NotValid err -> NonDerivableClass err
          IsValid      -> DerivableViaInstance
  where
    -- Class arguments with the invisible ones filtered out.
    visible_cls_tys = filterOutInvisibleTypes (classTyCon cls) cls_tys
-- | Error message for a class applied to the given argument types: the
-- quoted class predicate followed by @is not a class@.
classArgsErr :: Class -> [Type] -> SDoc
classArgsErr cls cls_tys = quoted_pred <+> text "is not a class"
  where
    quoted_pred = quotes (ppr (mkClassPred cls cls_tys))
-- | The side condition that must hold before an instance of the given class
-- can be derived, or 'Nothing' if the class is not one of those listed in
-- the table below. The table is searched in order by class 'Unique'.
sideConditions :: DerivContext -> Class -> Maybe Condition
sideConditions mtheta cls = assocMaybe conditions (getUnique cls)
  where
    conditions :: [(Unique, Condition)]
    conditions =
      [ (eqClassKey,      cond_std `andCond` cond_args cls)
      , (ordClassKey,     cond_std `andCond` cond_args cls)
      , (showClassKey,    cond_std `andCond` cond_args cls)
      , (readClassKey,    cond_std `andCond` cond_args cls)
      , (enumClassKey,    cond_std `andCond` cond_isEnumeration)
      , (ixClassKey,      cond_std `andCond` cond_enumOrProduct cls)
      , (boundedClassKey, cond_std `andCond` cond_enumOrProduct cls)
      , (dataClassKey,
           flagged LangExt.DeriveDataTypeable (cond_args cls))
      , (functorClassKey,
           flagged LangExt.DeriveFunctor (cond_functorOK True False))
      , (foldableClassKey,
           flagged LangExt.DeriveFoldable (cond_functorOK False True))
      , (traversableClassKey,
           flagged LangExt.DeriveTraversable (cond_functorOK False False))
      , (genClassKey,
           flagged LangExt.DeriveGeneric cond_RepresentableOk)
      , (gen1ClassKey,
           flagged LangExt.DeriveGeneric cond_Representable1Ok)
      , (liftClassKey,
           flagged LangExt.DeriveLift (cond_args cls))
      ]

    -- Shape shared by the extension-gated classes: the extension must be on,
    -- then the permissive standard check, then the class-specific check.
    flagged ext class_cond
      = checkFlag ext `andCond` cond_vanilla `andCond` class_cond

    -- 'cond_stdOK' in its non-permissive and permissive forms.
    cond_std     = cond_stdOK mtheta False
    cond_vanilla = cond_stdOK mtheta True
-- | Valid exactly when the @DeriveAnyClass@ extension is enabled in the given
-- flags; otherwise the result carries a hint to enable it.
canDeriveAnyClass :: DynFlags -> Validity
canDeriveAnyClass dflags =
  if xopt LangExt.DeriveAnyClass dflags
    then IsValid
    else NotValid (text "Try enabling DeriveAnyClass")
-- | A check on a type constructor, given the session's 'DynFlags', whose
-- result says whether it passed ('IsValid') or why it failed ('NotValid').
-- Conditions are combined with 'orCond' and 'andCond'.
type Condition = DynFlags -> TyCon -> Validity
-- | Disjunction of two conditions. The second is consulted only if the first
-- fails; if both fail, their messages are stacked vertically with an @or@
-- line between them.
orCond :: Condition -> Condition -> Condition
orCond c1 c2 dflags tc =
  case c1 dflags tc of
    IsValid         -> IsValid
    NotValid first_why ->
      case c2 dflags tc of
        IsValid             -> IsValid
        NotValid second_why ->
          NotValid (first_why $$ text " or" $$ second_why)
-- | Conjunction of two conditions: both are applied to the same flags and
-- type constructor and their results are combined with 'andValid'.
andCond :: Condition -> Condition -> Condition
andCond c1 c2 dflags tc = first_result `andValid` second_result
  where
    first_result  = c1 dflags tc
    second_result = c2 dflags tc
-- | The standard check on the data constructors of the representation tycon.
--
-- When a context is supplied ('Just'), nothing is checked. Otherwise:
--
--   * a tycon with no data constructors is accepted in permissive mode; in
--     non-permissive mode it is accepted only when @EmptyDataDeriving@ is on;
--   * every data constructor must be free of GADT equalities, existential
--     type variables and constraints, and (in non-permissive mode only) all
--     of its argument types must satisfy 'isTauTy'.
--
-- The 'Bool' argument selects permissive mode.
cond_stdOK :: DerivContext
           -> Bool
           -> Condition
cond_stdOK mtheta permissive dflags rep_tc = case mtheta of
  Just _ -> IsValid
  Nothing
    | null data_cons, not permissive
    -> checkFlag LangExt.EmptyDataDeriving dflags rep_tc
         `orValid`
         NotValid (no_cons_why rep_tc $$ empty_data_suggestion)
    | not (null con_whys)
    -> NotValid (vcat con_whys $$ standalone_suggestion)
    | otherwise
    -> IsValid
  where
    empty_data_suggestion =
      text "Use EmptyDataDeriving to enable deriving for empty data types"
    standalone_suggestion =
      text "Possible fix: use a standalone deriving declaration instead"

    data_cons = tyConDataCons rep_tc
    -- The failure messages of all constructors that do not pass 'check_con'.
    con_whys  = getInvalids [ check_con con | con <- data_cons ]

    check_con :: DataCon -> Validity
    check_con con
      | not (null eq_spec)
      = reject "is a GADT"
      | not (null ex_tvs)
      = reject "has existential type variables in its type"
      | not (null theta)
      = reject "has constraints in its type"
      | not permissive
      , not (all isTauTy (dataConOrigArgTys con))
      = reject "has a higher-rank type"
      | otherwise
      = IsValid
      where
        (_, ex_tvs, eq_spec, theta, _, _) = dataConFullSig con
        reject why = NotValid (badCon con (text why))
-- | Message saying that the (quoted) type constructor needs at least one
-- data constructor.
no_cons_why :: TyCon -> SDoc
no_cons_why rep_tc
  = quoted_tc <+> text "must have at least one data constructor"
  where
    quoted_tc = quotes (pprSourceTyCon rep_tc)
-- | Condition that defers to 'canDoGenerics' on the type constructor; the
-- flags are not consulted.
cond_RepresentableOk :: Condition
cond_RepresentableOk _dflags rep_tc = canDoGenerics rep_tc
-- | Condition that defers to 'canDoGenerics1' on the type constructor; the
-- flags are not consulted.
cond_Representable1Ok :: Condition
cond_Representable1Ok _dflags rep_tc = canDoGenerics1 rep_tc
-- | Accepts a type constructor that is either an enumeration, or a product
-- type whose constructor arguments also pass 'cond_args' for the class.
cond_enumOrProduct :: Class -> Condition
cond_enumOrProduct cls = cond_isEnumeration `orCond` product_with_ok_args
  where
    product_with_ok_args = cond_isProduct `andCond` cond_args cls
-- | Check the argument types of every data constructor of the tycon. An
-- argument is a problem when its type is unlifted and is not one of the
-- types listed in the table used for the class (the first components of
-- 'ordOpTbl' for Eq and Ord, 'boxConTbl' for Show, 'litConTbl' for Lift; no
-- unlifted type is acceptable for any other class). The first problem
-- argument found, if any, is named in the failure message.
cond_args :: Class -> Condition
cond_args cls _ tc
  = case bad_args of
      []            -> IsValid
      (first_bad:_) ->
        NotValid (hang (text "Don't know how to derive" <+> quotes (ppr cls))
                     2 (text "for type" <+> quotes (ppr first_bad)))
  where
    bad_args = filter is_bad (concatMap dataConOrigArgTys (tyConDataCons tc))

    is_bad arg_ty = isUnliftedType arg_ty && not (ok_ty arg_ty)

    ok_ty arg_ty = any (eqType arg_ty) supported_tys

    cls_key = classKey cls

    -- The unlifted types that this class knows how to handle.
    supported_tys :: [Type]
    supported_tys
      | cls_key == eqClassKey   = map fst ordOpTbl
      | cls_key == ordClassKey  = map fst ordOpTbl
      | cls_key == showClassKey = map fst boxConTbl
      | cls_key == liftClassKey = map fst litConTbl
      | otherwise               = []
-- | Accepts exactly the type constructors for which 'isEnumerationTyCon'
-- holds; the flags are not consulted.
cond_isEnumeration :: Condition
cond_isEnumeration _dflags rep_tc =
  if isEnumerationTyCon rep_tc
    then IsValid
    else NotValid not_enum_msg
  where
    not_enum_msg =
      sep [ quotes (pprSourceTyCon rep_tc) <+>
              text "must be an enumeration type"
          , text "(an enumeration consists of one or more nullary, non-GADT constructors)"
          ]
-- | Accepts exactly the type constructors for which 'isProductTyCon' holds;
-- the flags are not consulted.
cond_isProduct :: Condition
cond_isProduct _dflags rep_tc =
  if isProductTyCon rep_tc
    then IsValid
    else NotValid not_product_msg
  where
    not_product_msg =
      quotes (pprSourceTyCon rep_tc) <+>
        text "must have precisely one constructor"
-- | Condition used for the Functor, Foldable and Traversable classes (see
-- 'sideConditions').
--
-- The tycon must have at least one type variable, its stupid theta must not
-- mention the last type variable, and every data constructor must pass both
-- a universality check on the last type argument and a fold over its
-- argument types.
--
-- The first 'Bool' says whether function types are allowed in constructor
-- arguments. The second, when 'True', skips the universality check.
cond_functorOK :: Bool -> Bool -> Condition
cond_functorOK allowFunctions allowExQuantifiedLastTyVar _ rep_tc
  | null tc_tvs
  = NotValid (data_type_doc <+> text "must have some type parameters")
  | not (null bad_stupid_theta)
  = NotValid (data_type_doc
              <+> text "must not have a class context:"
              <+> pprTheta bad_stupid_theta)
  | otherwise
  = allValid (map check_con (tyConDataCons rep_tc))
  where
    data_type_doc = text "Data type" <+> quotes (ppr rep_tc)

    tc_tvs = tyConTyVars rep_tc
    -- Only demanded after the 'null tc_tvs' guard has failed, so the
    -- pattern cannot fail at run time.
    Just (_, last_tv) = snocView tc_tvs

    -- Stupid-theta predicates that mention the last type variable.
    bad_stupid_theta =
      [ stupid_pred
      | stupid_pred <- tyConStupidTheta rep_tc
      , last_tv `elemVarSet` exactTyCoVarsOfType stupid_pred
      ]

    check_con con
      = allValid (check_universal con : foldDataConArgs (ft_check con) con)

    -- The last argument of the constructor's result type must be a type
    -- variable that is universally quantified and not mentioned in the
    -- constructor's theta, unless this check is switched off.
    check_universal :: DataCon -> Validity
    check_universal con
      | allowExQuantifiedLastTyVar
      = IsValid
      | Just res_tv <- getTyVar_maybe last_res_arg
      , res_tv `elem` dataConUnivTyVars con
      , not (res_tv `elemVarSet` exactTyCoVarsOfTypes (dataConTheta con))
      = IsValid
      | otherwise
      = NotValid (badCon con existential)
      where
        last_res_arg = last (tyConAppArgs (dataConOrigResTy con))

    ft_check :: DataCon -> FFoldType Validity
    ft_check con = FT { ft_triv    = IsValid
                      , ft_var     = IsValid
                      , ft_co_var  = NotValid (badCon con covariant)
                      , ft_fun     = check_fun
                      , ft_tup     = \_ components -> allValid components
                      , ft_ty_app  = \_ arg_validity -> arg_validity
                      , ft_bad_app = NotValid (badCon con wrong_arg)
                      , ft_forall  = \_ body_validity -> body_validity
                      }
      where
        check_fun arg_validity res_validity
          | allowFunctions = arg_validity `andValid` res_validity
          | otherwise      = NotValid (badCon con functions)

    existential = text "must be truly polymorphic in the last argument of the data type"
    covariant   = text "must not use the type variable in a function argument"
    functions   = text "must not contain function types"
    wrong_arg   = text "must use the type variable only as the last argument of a data type"
-- | Condition that holds exactly when the given language extension is
-- enabled; the type constructor is not consulted. The failure message names
-- the extension using the entry of 'xFlags' whose flag equals it; anything
-- other than exactly one such entry is a panic.
checkFlag :: LangExt.Extension -> Condition
checkFlag flag dflags _ =
  if xopt flag dflags
    then IsValid
    else NotValid need_flag_msg
  where
    need_flag_msg = text "You need " <> text flag_str
                    <+> text "to derive an instance for this class"

    matching_names =
      [ flagSpecName spec | spec <- xFlags, flagSpecFlag spec == flag ]

    flag_str = case matching_names of
      [only_name] -> only_name
      unexpected  -> pprPanic "checkFlag" (ppr unexpected)
-- | 'True' exactly for the classes Eq, Ord, Ix and Bounded, identified by
-- class key.
std_class_via_coercible :: Class -> Bool
std_class_via_coercible clas = classKey clas `elem` coercible_keys
  where
    coercible_keys =
      [ eqClassKey
      , ordClassKey
      , ixClassKey
      , boundedClassKey
      ]
-- | 'True' exactly for the classes Read, Show, Data, Generic, Generic1,
-- Typeable, Traversable and Lift, identified by class key.
non_coercible_class :: Class -> Bool
non_coercible_class cls = classKey cls `elem` non_coercible_keys
  where
    non_coercible_keys =
      [ readClassKey
      , showClassKey
      , dataClassKey
      , genClassKey
      , gen1ClassKey
      , typeableClassKey
      , traversableClassKey
      , liftClassKey
      ]
-- | Prefix a complaint with @Constructor@ and the quoted data constructor it
-- is about.
badCon :: DataCon -> SDoc -> SDoc
badCon con complaint = text "Constructor" <+> quoted_con <+> complaint
  where
    quoted_con = quotes (ppr con)
-- | Create the 'ClsInst' for a deriving spec by calling 'newClsInst' with
-- the spec's overlap mode, name, type variables, class and class argument
-- types. The instance context is the separately supplied @theta@; the
-- spec's own 'ds_theta' is not used.
newDerivClsInst :: ThetaType -> DerivSpec theta -> TcM ClsInst
newDerivClsInst theta (DS { ds_name    = dfun_name
                          , ds_overlap = overlap
                          , ds_tvs     = tyvars
                          , ds_cls     = cls
                          , ds_tys     = inst_tys })
  = newClsInst overlap dfun_name tyvars theta cls inst_tys
-- | Run an action with the given class instances added to the instance
-- environment held in the global environment. The extended environment is
-- installed with 'setGblEnv' for the duration of the action only.
extendLocalInstEnv :: [ClsInst] -> TcM a -> TcM a
extendLocalInstEnv dfuns thing_inside = do
  gbl_env <- getGblEnv
  let extended_env =
        gbl_env { tcg_inst_env =
                    extendInstEnvList (tcg_inst_env gbl_env) dfuns }
  setGblEnv extended_env thing_inside