module TcTypeNats
( typeNatTyCons
, typeNatCoAxiomRules
, BuiltInSynFamily(..)
, typeNatAddTyCon
, typeNatMulTyCon
, typeNatExpTyCon
, typeNatLeqTyCon
, typeNatSubTyCon
, typeNatCmpTyCon
, typeSymbolCmpTyCon
) where
import Type
import Pair
import TcType ( TcType, tcEqType )
import TyCon ( TyCon, FamTyConFlav(..), mkFamilyTyCon, TyConParent(..) )
import Coercion ( Role(..) )
import TcRnTypes ( Xi )
import CoAxiom ( CoAxiomRule(..), BuiltInSynFamily(..) )
import Name ( Name, BuiltInSyntax(..) )
import TysWiredIn ( typeNatKind, typeSymbolKind
, mkWiredInTyConName
, promotedBoolTyCon
, promotedFalseDataCon, promotedTrueDataCon
, promotedOrderingTyCon
, promotedLTDataCon
, promotedEQDataCon
, promotedGTDataCon
)
import TysPrim ( tyVarList, mkArrowKinds )
import PrelNames ( gHC_TYPELITS
, typeNatAddTyFamNameKey
, typeNatMulTyFamNameKey
, typeNatExpTyFamNameKey
, typeNatLeqTyFamNameKey
, typeNatSubTyFamNameKey
, typeNatCmpTyFamNameKey
, typeSymbolCmpTyFamNameKey
)
import FastString ( FastString, fsLit )
import qualified Data.Map as Map
import Data.Maybe ( isJust )
-- | The list of all built-in type-family 'TyCon's defined in this module.
typeNatTyCons :: [TyCon]
typeNatTyCons =
  [ typeNatAddTyCon, typeNatMulTyCon, typeNatExpTyCon
  , typeNatLeqTyCon, typeNatSubTyCon
  , typeNatCmpTyCon, typeSymbolCmpTyCon
  ]
-- | Built-in binary family named @+@, wired in under 'gHC_TYPELITS'.
-- The name and the 'TyCon' are mutually recursive: the name is built from
-- the very 'TyCon' being defined.
typeNatAddTyCon :: TyCon
typeNatAddTyCon = mkTypeNatFunTyCon2 addName addOps
  where
    addName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "+")
        typeNatAddTyFamNameKey typeNatAddTyCon

    addOps = BuiltInSynFamily
      { sfMatchFam      = matchFamAdd
      , sfInteractTop   = interactTopAdd
      , sfInteractInert = interactInertAdd
      }
-- | Built-in binary family named @-@, wired in under 'gHC_TYPELITS'.
-- The name and the 'TyCon' are mutually recursive.
typeNatSubTyCon :: TyCon
typeNatSubTyCon = mkTypeNatFunTyCon2 subName subOps
  where
    subName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "-")
        typeNatSubTyFamNameKey typeNatSubTyCon

    subOps = BuiltInSynFamily
      { sfMatchFam      = matchFamSub
      , sfInteractTop   = interactTopSub
      , sfInteractInert = interactInertSub
      }
-- | Built-in binary family named @*@, wired in under 'gHC_TYPELITS'.
-- The name and the 'TyCon' are mutually recursive.
typeNatMulTyCon :: TyCon
typeNatMulTyCon = mkTypeNatFunTyCon2 mulName mulOps
  where
    mulName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "*")
        typeNatMulTyFamNameKey typeNatMulTyCon

    mulOps = BuiltInSynFamily
      { sfMatchFam      = matchFamMul
      , sfInteractTop   = interactTopMul
      , sfInteractInert = interactInertMul
      }
-- | Built-in binary family named @^@, wired in under 'gHC_TYPELITS'.
-- The name and the 'TyCon' are mutually recursive.
typeNatExpTyCon :: TyCon
typeNatExpTyCon = mkTypeNatFunTyCon2 expName expOps
  where
    expName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "^")
        typeNatExpTyFamNameKey typeNatExpTyCon

    expOps = BuiltInSynFamily
      { sfMatchFam      = matchFamExp
      , sfInteractTop   = interactTopExp
      , sfInteractInert = interactInertExp
      }
-- | Built-in family named @<=?@: two arguments of kind 'typeNatKind',
-- result of kind 'boolKind'.  The name and the 'TyCon' are mutually
-- recursive.
typeNatLeqTyCon :: TyCon
typeNatLeqTyCon =
  mkFamilyTyCon leqName leqKind leqTyVars
    (BuiltInSynFamTyCon leqOps) NoParentTyCon
  where
    leqName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "<=?")
        typeNatLeqTyFamNameKey typeNatLeqTyCon

    leqKind   = mkArrowKinds [ typeNatKind, typeNatKind ] boolKind
    leqTyVars = take 2 (tyVarList typeNatKind)

    leqOps = BuiltInSynFamily
      { sfMatchFam      = matchFamLeq
      , sfInteractTop   = interactTopLeq
      , sfInteractInert = interactInertLeq
      }
-- | Built-in family named @CmpNat@: two arguments of kind 'typeNatKind',
-- result of kind 'orderingKind'.  It contributes no inert interactions.
typeNatCmpTyCon :: TyCon
typeNatCmpTyCon =
  mkFamilyTyCon cmpName cmpKind cmpTyVars
    (BuiltInSynFamTyCon cmpOps) NoParentTyCon
  where
    cmpName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "CmpNat")
        typeNatCmpTyFamNameKey typeNatCmpTyCon

    cmpKind   = mkArrowKinds [ typeNatKind, typeNatKind ] orderingKind
    cmpTyVars = take 2 (tyVarList typeNatKind)

    cmpOps = BuiltInSynFamily
      { sfMatchFam      = matchFamCmpNat
      , sfInteractTop   = interactTopCmpNat
      , sfInteractInert = noInert
      }

    -- Two inert equations never yield anything new for this family.
    noInert _ _ _ _ = []
-- | Built-in family named @CmpSymbol@: two arguments of kind
-- 'typeSymbolKind', result of kind 'orderingKind'.  It contributes no
-- inert interactions.
typeSymbolCmpTyCon :: TyCon
typeSymbolCmpTyCon =
  mkFamilyTyCon cmpName cmpKind cmpTyVars
    (BuiltInSynFamTyCon cmpOps) NoParentTyCon
  where
    cmpName =
      mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "CmpSymbol")
        typeSymbolCmpTyFamNameKey typeSymbolCmpTyCon

    cmpKind   = mkArrowKinds [ typeSymbolKind, typeSymbolKind ] orderingKind
    cmpTyVars = take 2 (tyVarList typeSymbolKind)

    cmpOps = BuiltInSynFamily
      { sfMatchFam      = matchFamCmpSymbol
      , sfInteractTop   = interactTopCmpSymbol
      , sfInteractInert = noInert
      }

    -- Two inert equations never yield anything new for this family.
    noInert _ _ _ _ = []
-- | Build a parentless built-in family 'TyCon' with the given name and
-- operations, taking two arguments of kind 'typeNatKind' and returning a
-- result of kind 'typeNatKind'.
mkTypeNatFunTyCon2 :: Name -> BuiltInSynFamily -> TyCon
mkTypeNatFunTyCon2 op tcb =
  mkFamilyTyCon op binKind binTyVars (BuiltInSynFamTyCon tcb) NoParentTyCon
  where
    binKind   = mkArrowKinds [ typeNatKind, typeNatKind ] typeNatKind
    binTyVars = take 2 (tyVarList typeNatKind)
-- One shared signature for every axiom rule defined in this module.
axAddDef, axMulDef, axExpDef, axLeqDef, axCmpNatDef, axCmpSymbolDef
  , axAdd0L, axAdd0R
  , axMul0L, axMul0R, axMul1L, axMul1R
  , axExp1L, axExp0R, axExp1R
  , axLeqRefl, axCmpNatRefl, axCmpSymbolRefl
  , axLeq0L
  , axSubDef, axSub0R
  :: CoAxiomRule
-- Defining axioms: each evaluates its family on two numeric literals.
axAddDef = mkBinAxiom "AddDef" typeNatAddTyCon (\m n -> Just (num (m + n)))

axMulDef = mkBinAxiom "MulDef" typeNatMulTyCon (\m n -> Just (num (m * n)))

axExpDef = mkBinAxiom "ExpDef" typeNatExpTyCon (\m n -> Just (num (m ^ n)))

axLeqDef = mkBinAxiom "LeqDef" typeNatLeqTyCon (\m n -> Just (bool (m <= n)))

axCmpNatDef =
  mkBinAxiom "CmpNatDef" typeNatCmpTyCon
    (\m n -> Just (ordering (compare m n)))
-- Defining axiom for the symbol comparison family: when both type
-- arguments are string literals (and there are no coercion arguments) it
-- relates the family application to the promoted result of 'compare'.
axCmpSymbolDef =
  CoAxiomRule
    { coaxrName      = fsLit "CmpSymbolDef"
    , coaxrTypeArity = 2
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = prove
    }
  where
    prove [s, t] [] =
      case (isStrLitTy s, isStrLitTy t) of
        (Just x, Just y) ->
          Just (mkTyConApp typeSymbolCmpTyCon [s, t] === ordering (compare x y))
        _ -> Nothing
    prove _ _ = Nothing
-- Defining axiom for subtraction; it proves nothing when 'minus' has no
-- result for the two literals.
axSubDef =
  mkBinAxiom "SubDef" typeNatSubTyCon $ \m n ->
    case minus m n of
      Just d  -> Just (num d)
      Nothing -> Nothing
-- Axioms taking a single type argument: identities, absorbing elements
-- and reflexivity for the built-in families.
axAdd0L = mkAxiom1 "Add0L" (\a -> (num 0 .+. a) === a)
axAdd0R = mkAxiom1 "Add0R" (\a -> (a .+. num 0) === a)
axSub0R = mkAxiom1 "Sub0R" (\a -> (a .-. num 0) === a)

axMul0L = mkAxiom1 "Mul0L" (\a -> (num 0 .*. a) === num 0)
axMul0R = mkAxiom1 "Mul0R" (\a -> (a .*. num 0) === num 0)
axMul1L = mkAxiom1 "Mul1L" (\a -> (num 1 .*. a) === a)
axMul1R = mkAxiom1 "Mul1R" (\a -> (a .*. num 1) === a)

axExp1L = mkAxiom1 "Exp1L" (\a -> (num 1 .^. a) === num 1)
axExp0R = mkAxiom1 "Exp0R" (\a -> (a .^. num 0) === num 1)
axExp1R = mkAxiom1 "Exp1R" (\a -> (a .^. num 1) === a)

axLeqRefl       = mkAxiom1 "LeqRefl"       (\a -> (a <== a) === bool True)
axCmpNatRefl    = mkAxiom1 "CmpNatRefl"    (\a -> cmpNat a a === ordering EQ)
axCmpSymbolRefl = mkAxiom1 "CmpSymbolRefl" (\a -> cmpSymbol a a === ordering EQ)
axLeq0L         = mkAxiom1 "Leq0L"         (\a -> (num 0 <== a) === bool True)
-- | Every axiom rule of this module, indexed by its 'coaxrName'.
--
-- The table must list every rule that a @matchFam*@ function can return,
-- otherwise a rule that was used cannot be found again by name.
-- 'axSub0R' (returned by 'matchFamSub') was previously missing.
typeNatCoAxiomRules :: Map.Map FastString CoAxiomRule
typeNatCoAxiomRules = Map.fromList $ map (\x -> (coaxrName x, x))
  [ axAddDef
  , axMulDef
  , axExpDef
  , axLeqDef
  , axCmpNatDef
  , axCmpSymbolDef
  , axAdd0L
  , axAdd0R
  , axMul0L
  , axMul0R
  , axMul1L
  , axMul1R
  , axExp1L
  , axExp0R
  , axExp1R
  , axLeqRefl
  , axCmpNatRefl
  , axCmpSymbolRefl
  , axLeq0L
  , axSubDef
  , axSub0R
  ]
-- | Apply the addition family to two types.
(.+.) :: Type -> Type -> Type
(.+.) a b = mkTyConApp typeNatAddTyCon [a, b]
-- | Apply the subtraction family to two types.
(.-.) :: Type -> Type -> Type
(.-.) a b = mkTyConApp typeNatSubTyCon [a, b]
-- | Apply the multiplication family to two types.
(.*.) :: Type -> Type -> Type
(.*.) a b = mkTyConApp typeNatMulTyCon [a, b]
-- | Apply the exponentiation family to two types.
(.^.) :: Type -> Type -> Type
(.^.) a b = mkTyConApp typeNatExpTyCon [a, b]
-- | Apply the @<=?@ family to two types.
(<==) :: Type -> Type -> Type
(<==) a b = mkTyConApp typeNatLeqTyCon [a, b]
-- | Apply the @CmpNat@ family to two types.
cmpNat :: Type -> Type -> Type
cmpNat a b = typeNatCmpTyCon `mkTyConApp` [a, b]
-- | Apply the @CmpSymbol@ family to two types.
cmpSymbol :: Type -> Type -> Type
cmpSymbol a b = typeSymbolCmpTyCon `mkTyConApp` [a, b]
-- | Pair up the two sides of an equation.
(===) :: Type -> Type -> Pair Type
(===) = Pair
-- | The numeric literal type for the given integer.
num :: Integer -> Type
num n = mkNumLitTy n
-- | The promoted boolean type constructor, used as a kind.
boolKind :: Kind
boolKind = promotedBoolTyCon `mkTyConApp` []
-- | The promoted data constructor type corresponding to a 'Bool' value.
bool :: Bool -> Type
bool b = mkTyConApp con []
  where
    con | b         = promotedTrueDataCon
        | otherwise = promotedFalseDataCon
-- | Recognise a type that is exactly the promoted @False@ or @True@
-- constructor applied to no arguments.
isBoolLitTy :: Type -> Maybe Bool
isBoolLitTy ty =
  case splitTyConApp_maybe ty of
    Just (con, [])
      | con == promotedFalseDataCon -> Just False
      | con == promotedTrueDataCon  -> Just True
    _ -> Nothing
-- | The promoted ordering type constructor, used as a kind.
orderingKind :: Kind
orderingKind = promotedOrderingTyCon `mkTyConApp` []
-- | The promoted data constructor type corresponding to an 'Ordering'.
ordering :: Ordering -> Type
ordering o = mkTyConApp con []
  where
    con = case o of
      LT -> promotedLTDataCon
      EQ -> promotedEQDataCon
      GT -> promotedGTDataCon
-- | Recognise a type that is exactly one of the promoted @LT@, @EQ@ or
-- @GT@ constructors applied to no arguments.
isOrderingLitTy :: Type -> Maybe Ordering
isOrderingLitTy ty =
  case splitTyConApp_maybe ty of
    Just (con, [])
      | con == promotedLTDataCon -> Just LT
      | con == promotedEQDataCon -> Just EQ
      | con == promotedGTDataCon -> Just GT
    _ -> Nothing
-- | @known p x@ holds when @x@ is a numeric literal whose value
-- satisfies @p@; it is 'False' for anything that is not a numeric literal.
known :: (Integer -> Bool) -> TcType -> Bool
known p x = maybe False p (isNumLitTy x)
-- | Build a nominal, two-type-argument axiom rule for a family 'TyCon'.
--
-- The rule applies only to exactly two type arguments that are both
-- numeric literals (and no coercion arguments).  The supplied function
-- computes the right-hand side from the two literal values; when it
-- returns 'Nothing' the rule proves nothing.
mkBinAxiom :: String -> TyCon ->
              (Integer -> Integer -> Maybe Type) -> CoAxiomRule
mkBinAxiom str tc f =
  CoAxiomRule
    { coaxrName      = fsLit str
    , coaxrTypeArity = 2
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = prove
    }
  where
    prove [s, t] [] =
      isNumLitTy s >>= \x ->
      isNumLitTy t >>= \y ->
      f x y        >>= \z ->
      Just (mkTyConApp tc [s, t] === z)
    prove _ _ = Nothing
-- | Build a nominal, one-type-argument axiom rule.  Given exactly one type
-- argument and no coercion arguments, the proved equation is whatever the
-- supplied function returns for that argument.
mkAxiom1 :: String -> (Type -> Pair Type) -> CoAxiomRule
mkAxiom1 str f =
  CoAxiomRule
    { coaxrName      = fsLit str
    , coaxrTypeArity = 1
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = prove
    }
  where
    prove [s] [] = Just (f s)
    prove _   _  = Nothing
-- | Try to rewrite an application of the addition family.  On success the
-- result is the axiom rule used, the types it is instantiated with, and
-- the rewritten type.  Alternatives are tried top to bottom.
matchFamAdd :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamAdd [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (axAdd0L, [t], t)
    (_, Just 0)      -> Just (axAdd0R, [s], s)
    (Just x, Just y) -> Just (axAddDef, [s,t], num (x + y))
    _                -> Nothing
matchFamAdd _ = Nothing
-- | Try to rewrite an application of the subtraction family: subtracting
-- a literal 0, or two literals for which 'minus' has a result.
matchFamSub :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamSub [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (_, Just 0) -> Just (axSub0R, [s], s)
    (Just x, Just y)
      | Just z <- minus x y -> Just (axSubDef, [s,t], num z)
    _ -> Nothing
matchFamSub _ = Nothing
-- | Try to rewrite an application of the multiplication family.  A
-- literal 0 on either side is checked before a literal 1, and both before
-- plain evaluation of two literals.
matchFamMul :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamMul [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (axMul0L, [t], num 0)
    (_, Just 0)      -> Just (axMul0R, [s], num 0)
    (Just 1, _)      -> Just (axMul1L, [t], t)
    (_, Just 1)      -> Just (axMul1R, [s], s)
    (Just x, Just y) -> Just (axMulDef, [s,t], num (x * y))
    _                -> Nothing
matchFamMul _ = Nothing
-- | Try to rewrite an application of the exponentiation family.  A
-- literal exponent of 0 takes priority over a literal base of 1, then a
-- literal exponent of 1, then plain evaluation of two literals.
matchFamExp :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamExp [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (_, Just 0)      -> Just (axExp0R, [s], num 1)
    (Just 1, _)      -> Just (axExp1L, [t], num 1)
    (_, Just 1)      -> Just (axExp1R, [s], s)
    (Just x, Just y) -> Just (axExpDef, [s,t], num (x ^ y))
    _                -> Nothing
matchFamExp _ = Nothing
-- | Try to rewrite an application of the @<=?@ family: a literal 0 on the
-- left, two literals, or two arguments that are equal under 'tcEqType'.
matchFamLeq :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamLeq [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (axLeq0L, [t], bool True)
    (Just x, Just y) -> Just (axLeqDef, [s,t], bool (x <= y))
    _ | tcEqType s t -> Just (axLeqRefl, [s], bool True)
      | otherwise    -> Nothing
matchFamLeq _ = Nothing
-- | Try to rewrite an application of the @CmpNat@ family: either both
-- arguments are numeric literals, or they are equal under 'tcEqType'.
matchFamCmpNat :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamCmpNat [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just x, Just y) -> Just (axCmpNatDef, [s,t], ordering (compare x y))
    _ | tcEqType s t -> Just (axCmpNatRefl, [s], ordering EQ)
      | otherwise    -> Nothing
matchFamCmpNat _ = Nothing
-- | Try to rewrite an application of the @CmpSymbol@ family: either both
-- arguments are string literals, or they are equal under 'tcEqType'.
matchFamCmpSymbol :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamCmpSymbol [s,t] =
  case (isStrLitTy s, isStrLitTy t) of
    (Just x, Just y) -> Just (axCmpSymbolDef, [s,t], ordering (compare x y))
    _ | tcEqType s t -> Just (axCmpSymbolRefl, [s], ordering EQ)
      | otherwise    -> Nothing
matchFamCmpSymbol _ = Nothing
-- | Equations derivable when the third type @r@ is a numeric literal
-- (presumably from a constraint of the shape @s + t ~ r@):
--
--   * @r@ is 0: both @s@ and @t@ are 0;
--   * one of @s@, @t@ is a literal that can be subtracted from @r@: the
--     other one equals the difference.
interactTopAdd :: [Xi] -> Xi -> [Pair Type]
interactTopAdd [s,t] r =
  case isNumLitTy r of
    Just 0 -> [ s === num 0, t === num 0 ]
    Just z
      | Just x <- isNumLitTy s, Just y <- minus z x -> [ t === num y ]
      | Just y <- isNumLitTy t, Just x <- minus z y -> [ s === num x ]
    _ -> []
interactTopAdd _ _ = []
-- | When the third type @r@ is a numeric literal @z@, derive that @s@
-- equals @z + t@; otherwise derive nothing.
interactTopSub :: [Xi] -> Xi -> [Pair Type]
interactTopSub [s,t] r =
  maybe [] (\z -> [ s === (num z .+. t) ]) (isNumLitTy r)
interactTopSub _ _ = []
-- | Equations derivable when the third type @r@ is a numeric literal
-- (presumably from a constraint of the shape @s * t ~ r@):
--
--   * @r@ is 1: both @s@ and @t@ are 1;
--   * one of @s@, @t@ is a literal that divides @r@ exactly: the other
--     one equals the quotient.
interactTopMul :: [Xi] -> Xi -> [Pair Type]
interactTopMul [s,t] r =
  case isNumLitTy r of
    Just 1 -> [ s === num 1, t === num 1 ]
    Just z
      | Just x <- isNumLitTy s, Just y <- divide z x -> [ t === num y ]
      | Just y <- isNumLitTy t, Just x <- divide z y -> [ s === num x ]
    _ -> []
interactTopMul _ _ = []
-- | Equations derivable when the third type @r@ is a numeric literal
-- (presumably from a constraint of the shape @s ^ t ~ r@):
--
--   * @r@ is 0: @s@ is 0;
--   * @s@ is a literal and 'logExact' succeeds: @t@ is that logarithm;
--   * @t@ is a literal and 'rootExact' succeeds: @s@ is that root.
interactTopExp :: [Xi] -> Xi -> [Pair Type]
interactTopExp [s,t] r =
  case isNumLitTy r of
    Just 0 -> [ s === num 0 ]
    Just z
      | Just x <- isNumLitTy s, Just y <- logExact z x  -> [ t === num y ]
      | Just y <- isNumLitTy t, Just x <- rootExact z y -> [ s === num x ]
    _ -> []
interactTopExp _ _ = []
-- | When @t@ is the literal 0 and @r@ is the promoted @True@, derive that
-- @s@ is 0; otherwise derive nothing.
interactTopLeq :: [Xi] -> Xi -> [Pair Type]
interactTopLeq [s,t] r =
  case (isNumLitTy t, isBoolLitTy r) of
    (Just 0, Just True) -> [ s === num 0 ]
    _                   -> []
interactTopLeq _ _ = []
-- | When @r@ is the promoted @EQ@, derive that the two arguments are
-- equal; otherwise derive nothing.
interactTopCmpNat :: [Xi] -> Xi -> [Pair Type]
interactTopCmpNat [s,t] r =
  case isOrderingLitTy r of
    Just EQ -> [ s === t ]
    _       -> []
interactTopCmpNat _ _ = []
-- | When @r@ is the promoted @EQ@, derive that the two arguments are
-- equal; otherwise derive nothing.
interactTopCmpSymbol :: [Xi] -> Xi -> [Pair Type]
interactTopCmpSymbol [s,t] r =
  case isOrderingLitTy r of
    Just EQ -> [ s === t ]
    _       -> []
interactTopCmpSymbol _ _ = []
-- | Combine two argument-list/result pairs for the addition family.  When
-- the results agree and one pair of corresponding arguments agrees (under
-- 'tcEqType'), the other pair of arguments must be equal as well.
interactInertAdd :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertAdd [x1,y1] z1 [x2,y2] z2
  | not (tcEqType z1 z2) = []
  | tcEqType x1 x2       = [ y1 === y2 ]
  | tcEqType y1 y2       = [ x1 === x2 ]
interactInertAdd _ _ _ _ = []
-- | Combine two argument-list/result pairs for the subtraction family.
-- When the results agree and one pair of corresponding arguments agrees
-- (under 'tcEqType'), the other pair of arguments must be equal as well.
interactInertSub :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertSub [x1,y1] z1 [x2,y2] z2
  | not (tcEqType z1 z2) = []
  | tcEqType x1 x2       = [ y1 === y2 ]
  | tcEqType y1 y2       = [ x1 === x2 ]
interactInertSub _ _ _ _ = []
-- | Combine two argument-list/result pairs for the multiplication family.
-- Cancellation is only performed through an argument that is a known
-- non-zero literal and is shared (under 'tcEqType') by both pairs.
interactInertMul :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertMul [x1,y1] z1 [x2,y2] z2
  | not (tcEqType z1 z2)         = []
  | nonZero x1 && tcEqType x1 x2 = [ y1 === y2 ]
  | nonZero y1 && tcEqType y1 y2 = [ x1 === x2 ]
  where
    nonZero = known (/= 0)
interactInertMul _ _ _ _ = []
-- | Combine two argument-list/result pairs for the exponentiation family.
-- Exponents are equated only through a shared literal base greater than
-- 1; bases only through a shared literal exponent greater than 0.
interactInertExp :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertExp [x1,y1] z1 [x2,y2] z2
  | not (tcEqType z1 z2)              = []
  | known (> 1) x1 && tcEqType x1 x2  = [ y1 === y2 ]
  | known (> 0) y1 && tcEqType y1 y2  = [ x1 === x2 ]
interactInertExp _ _ _ _ = []
-- | Combine two argument-list/result pairs for the @<=?@ family.  Nothing
-- is derived unless both results are the promoted @True@.  Then:
--
--   * if each pair is the other one swapped, the two arguments are equal;
--   * if the second argument of one pair is the first argument of the
--     other, the outer two arguments are related by @<=?@ as well.
interactInertLeq :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertLeq [x1,y1] z1 [x2,y2] z2
  | not bothTrue                     = []
  | tcEqType x1 y2 && tcEqType y1 x2 = [ x1 === y1 ]
  | tcEqType y1 x2                   = [ (x1 <== y2) === bool True ]
  | tcEqType y2 x1                   = [ (x2 <== y1) === bool True ]
  where
    bothTrue = isJust (trueLit z1 >> trueLit z2)

    -- Succeeds only for the promoted @True@.
    trueLit z = case isBoolLitTy z of
      Just True -> Just ()
      _         -> Nothing
interactInertLeq _ _ _ _ = []
-- | Subtraction that never goes below zero: @minus x y@ is @Just (x - y)@
-- when @x >= y@ and 'Nothing' otherwise.
minus :: Integer -> Integer -> Maybe Integer
minus x y
  | x >= y    = Just (x - y)   -- the operator was missing: @x y@ is not an application
  | otherwise = Nothing
-- | @logExact x y@ is the result of @genLog x y@ when 'genLog' reports it
-- as exact, and 'Nothing' in every other case.
logExact :: Integer -> Integer -> Maybe Integer
logExact x y =
  case genLog x y of
    Just (z, True) -> Just z
    _              -> Nothing
-- | Exact division: @divide x y@ is @Just q@ when @y@ is non-zero and
-- divides @x@ with remainder 0, and 'Nothing' otherwise.
divide :: Integer -> Integer -> Maybe Integer
divide x y
  | y == 0    = Nothing
  | r == 0    = Just q
  | otherwise = Nothing
  where
    -- Only demanded once the divisor is known to be non-zero.
    (q, r) = x `divMod` y
-- | @rootExact x y@ is the result of @genRoot x y@ when 'genRoot' reports
-- it as exact, and 'Nothing' in every other case.
rootExact :: Integer -> Integer -> Maybe Integer
rootExact x y =
  case genRoot x y of
    Just (z, True) -> Just z
    _              -> Nothing
-- | @genRoot x0 root@ computes an integer @root@-th root of @x0@ by
-- bisection.
--
-- The result is 'Nothing' for @root == 0@.  Otherwise it is
-- @Just (r, exact)@: @exact@ is 'True' when @r ^ root == x0@; when it is
-- 'False', @r@ is the lower end of the final search interval (the root
-- rounded down, assuming @x0@ is non-negative).
genRoot :: Integer -> Integer -> Maybe (Integer, Bool)
genRoot _  0    = Nothing
genRoot x0 1    = Just (x0, True)
genRoot x0 root = Just (search 0 (x0 + 1))
  where
    -- Invariant: from ^ root <= x0 < to ^ root.
    search from to =
      let x = from + div (to - from) 2   -- midpoint; the @-@ was missing
          a = x ^ root
      in case compare a x0 of
           EQ -> (x, True)
           LT | x /= from -> search x to
              | otherwise -> (from, False)
           GT | x /= to   -> search from x
              | otherwise -> (from, False)
-- | @genLog x base@ computes an integer logarithm of @x@ to the given
-- base by repeated division.
--
-- Base 0 only has a result for @x == 1@ (namely @(0, True)@); base 1 and
-- @x == 0@ have none.  Otherwise the result is @Just (l, exact)@, where
-- @exact@ says whether @base ^ l == x@; when it is 'False', @l@ counts how
-- many times @x@ could be divided by @base@ before dropping below it.
genLog :: Integer -> Integer -> Maybe (Integer, Bool)
genLog x base
  | base == 0 = if x == 1 then Just (0, True) else Nothing
  | base == 1 = Nothing
  | x == 0    = Nothing
  | otherwise = Just (exact 0 x)
  where
    -- Counts divisions while they all come out even.
    exact acc n
      | n == 1    = (acc, True)
      | n < base  = (acc, False)
      | rest == 0 = acc' `seq` exact acc' quotient
      | otherwise = acc' `seq` (floorOnly acc' quotient, False)
      where
        acc'             = acc + 1
        (quotient, rest) = n `divMod` base

    -- After the first non-zero remainder only the count matters.
    floorOnly acc n
      | n < base  = acc
      | otherwise =
          let acc' = acc + 1
          in acc' `seq` floorOnly acc' (n `div` base)