{-# LANGUAGE BangPatterns, CPP, FlexibleInstances, KindSignatures,
    ScopedTypeVariables, TypeOperators, TypeSynonymInstances #-}
{-# LANGUAGE Safe #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

#if __GLASGOW_HASKELL__ >= 800
#define HAS_DATA_KIND
#endif

-----------------------------------------------------------------------------
-- |
-- Module      : Data.Binary.Generic
-- Copyright   : Bryan O'Sullivan
-- License     : BSD3-style (see LICENSE)
--
-- Maintainer  : Bryan O'Sullivan <bos@serpentine.com>
-- Stability   : unstable
-- Portability : Only works with GHC 7.2 and newer
--
-- Instances for supporting GHC generics.
--
-----------------------------------------------------------------------------
module Data.Binary.Generic
    (
    ) where

import Control.Applicative
import Data.Binary.Class
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits
import Data.Word
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
#ifdef HAS_DATA_KIND
import Data.Kind
#endif
import GHC.Generics
import Prelude -- Silence AMP warning.

-- Type without constructors
instance GBinaryPut V1 where
    -- 'V1' is the representation of a type without constructors, so
    -- there is nothing to serialise: the argument is ignored and the
    -- resulting 'Put' action performs no writes.
    gput _ = pure ()

instance GBinaryGet V1 where
    -- No value of a constructor-less type can be built, so the decoder
    -- reads nothing and returns 'undefined' as a placeholder.  The
    -- placeholder only blows up if somebody forces it.
    gget = return undefined

-- Constructor without arguments
instance GBinaryPut U1 where
    -- A constructor without arguments has no fields to serialise, so
    -- the resulting 'Put' action performs no writes.
    gput U1 = pure ()

instance GBinaryGet U1 where
    -- 'U1' has no fields: produce it directly without running any
    -- other decoder.
    gget = return U1

-- Product: constructor with parameters
instance (GBinaryPut a, GBinaryPut b) => GBinaryPut (a :*: b) where
    -- Serialise the left component of the product first, then the
    -- right one; the two 'Put' actions are sequenced with '<>'.
    gput (x :*: y) = gput x <> gput y

instance (GBinaryGet a, GBinaryGet b) => GBinaryGet (a :*: b) where
    -- Decode the left component, then the right component, and pair
    -- the two results with the product constructor.
    gget = (:*:) <$> gget <*> gget

-- Metadata (constructor name, etc)
instance GBinaryPut a => GBinaryPut (M1 i c a) where
    -- The metadata wrapper contributes nothing to the encoding: strip
    -- it with 'unM1' and serialise the wrapped representation.
    gput = gput . unM1

instance GBinaryGet a => GBinaryGet (M1 i c a) where
    -- Decode the wrapped representation and put the metadata wrapper
    -- back around the result.
    gget = M1 <$> gget

-- Constants, additional parameters, and rank-1 recursion
instance Binary a => GBinaryPut (K1 i a) where
    -- A field value: unwrap it with 'unK1' and serialise it through
    -- the 'Binary' instance of the field type.
    gput = put . unK1

instance Binary a => GBinaryGet (K1 i a) where
    -- A field value: decode it through the 'Binary' instance of the
    -- field type and wrap the result in 'K1'.
    gget = K1 <$> get

-- Borrowed from the cereal package.

-- The following GBinary instance for sums has support for serializing
-- types with up to 2^64-1 constructors. It will use the minimal
-- number of bytes needed to encode the constructor. For example when
-- a type has fewer than 2^8 constructors it will use a single byte to
-- encode the constructor. If it has fewer than 2^16 constructors it
-- will use two bytes, and so on till 2^64-1.
--
-- The bound is strict because the constructor count itself is
-- converted to the chosen word type by the PUTSUM and GETSUM macros
-- below. A count of exactly 2^8 would wrap around to 0 in a Word8,
-- which would make every constructor encode as tag 0 and make every
-- decode fail, so such a type has to use the next wider word.

#define GUARD(WORD) size <= fromIntegral (maxBound :: WORD)
#define PUTSUM(WORD) GUARD(WORD) = putSum (0 :: WORD) (fromIntegral size)
#define GETSUM(WORD) GUARD(WORD) = (get :: Get WORD) >>= checkGetSum (fromIntegral size)

instance ( GSumPut  a, GSumPut  b
         , SumSize    a, SumSize    b) => GBinaryPut (a :+: b) where
    -- The PUTSUM guards are tried from the narrowest word type to the
    -- widest.  The first one whose size check succeeds starts 'putSum'
    -- at tag 0 in that word type, passing along the constructor count
    -- converted to the same type.  If no guard matches, 'sizeError'
    -- aborts.
    gput | PUTSUM(Word8) | PUTSUM(Word16) | PUTSUM(Word32) | PUTSUM(Word64)
         | otherwise = sizeError "encode" size
      where
        -- Number of constructors in the whole sum, as computed by the
        -- 'SumSize' instances of both branches.
        size = unTagged (sumSize :: Tagged (a :+: b) Word64)

instance ( GSumGet  a, GSumGet  b
         , SumSize    a, SumSize    b) => GBinaryGet (a :+: b) where
    -- The GETSUM guards are tried from the narrowest word type to the
    -- widest.  The first one whose size check succeeds reads a tag of
    -- that word type and hands it, together with the constructor count
    -- converted to the same type, to 'checkGetSum'.  If no guard
    -- matches, 'sizeError' aborts.
    gget | GETSUM(Word8) | GETSUM(Word16) | GETSUM(Word32) | GETSUM(Word64)
         | otherwise = sizeError "decode" size
      where
        -- Number of constructors in the whole sum, as computed by the
        -- 'SumSize' instances of both branches.
        size = unTagged (sumSize :: Tagged (a :+: b) Word64)

-- | Abort with 'error', reporting that a type with @size@ constructors
-- cannot be handled.
--
-- The first argument is the verb inserted into the message; the callers
-- in this module pass @"encode"@ or @"decode"@.  The result type is fully
-- polymorphic because the function never returns normally.
sizeError :: Show size => String -> size -> error
sizeError s size =
    error $ "Can't " ++ s ++ " a type with " ++ show size ++ " constructors"

------------------------------------------------------------------------

-- | Validate a constructor tag before dispatching on it.
--
-- The first argument is the number of constructors and the second is the
-- tag to check.  A tag strictly below the count is passed on to 'getSum'
-- (tag first, count second); any other tag makes the decoder 'fail'.
checkGetSum :: (Ord word, Num word, Bits word, GSumGet f)
            => word -> word -> Get (f a)
checkGetSum size code | code < size = getSum code size
                      | otherwise   = fail "Unknown encoding for constructor"
{-# INLINE checkGetSum #-}

-- | Generic sum representations from which one constructor can be
-- selected by a numeric tag and then decoded.
class GSumGet f where
    -- | @getSum code size@ decodes the constructor selected by @code@
    -- out of a sum that has @size@ constructors.
    getSum
        :: (Ord word, Num word, Bits word)
        => word
        -> word
        -> Get (f a)

-- | Generic sum representations whose values can be serialised as a
-- numeric constructor tag followed by the constructor fields.
class GSumPut f where
    -- | @putSum code size x@ serialises @x@, where @code@ is the tag of
    -- the first constructor of this (sub-)sum and @size@ is the number
    -- of constructors it contains.
    putSum
        :: (Num w, Bits w, Binary w)
        => w
        -> w
        -> f a
        -> Put

instance (GSumGet a, GSumGet b) => GSumGet (a :+: b) where
    -- Tags below sizeL select a constructor in the left branch.  All
    -- other tags select one in the right branch, after being re-based
    -- so that the first constructor of that branch has tag 0.
    getSum !code !size | code < sizeL = L1 <$> getSum code           sizeL
                       | otherwise    = R1 <$> getSum (code - sizeL) sizeR
        where
          -- The left branch is assigned half of the constructors
          -- (rounded down); the right branch gets the remainder.
          sizeL = size `shiftR` 1
          sizeR = size - sizeL

instance (GSumPut a, GSumPut b) => GSumPut (a :+: b) where
    -- A value in the left branch keeps the current base tag.  A value
    -- in the right branch has its base tag advanced past all the
    -- constructors assigned to the left branch.
    putSum !code !size s = case s of
                             L1 x -> putSum code           sizeL x
                             R1 x -> putSum (code + sizeL) sizeR x
        where
          -- The left branch is assigned half of the constructors
          -- (rounded down); the right branch gets the remainder.
          sizeL = size `shiftR` 1
          sizeR = size - sizeL

instance GBinaryGet a => GSumGet (C1 c a) where
    -- A single constructor has been reached, so the tag and the size
    -- are no longer needed: both are ignored and the constructor is
    -- decoded with 'gget'.
    getSum _ _ = gget

instance GBinaryPut a => GSumPut (C1 c a) where
    -- A single constructor has been reached: write the tag first, then
    -- the constructor itself.  The size argument is ignored.
    putSum !code _ x = put code <> gput x

------------------------------------------------------------------------

-- | Generic representations for which a constructor count is available.
class SumSize f where
    -- | The constructor count for @f@, labelled with @f@ through the
    -- phantom parameter of 'Tagged'.
    sumSize
        :: Tagged f Word64

-- | A value of type @b@ labelled with a phantom type constructor @s@.
-- The label does not appear in the payload; 'unTagged' extracts the
-- value.  The two branches differ only in how the kind of @s@ is
-- spelled.
#ifdef HAS_DATA_KIND
newtype Tagged (s :: Type -> Type) b = Tagged {unTagged :: b}
#else
newtype Tagged (s :: * -> *)       b = Tagged {unTagged :: b}
#endif

instance (SumSize a, SumSize b) => SumSize (a :+: b) where
    -- The count for a sum is the count of its left branch plus the
    -- count of its right branch.  The type annotations select which
    -- instance each 'sumSize' refers to.
    sumSize = Tagged $ unTagged (sumSize :: Tagged a Word64) +
                       unTagged (sumSize :: Tagged b Word64)

instance SumSize (C1 c a) where
    -- A single constructor counts as one.
    sumSize = Tagged 1