{-# LANGUAGE BangPatterns, CPP, FlexibleInstances, KindSignatures,
    ScopedTypeVariables, TypeOperators, TypeSynonymInstances #-}
{-# LANGUAGE Safe #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

#if __GLASGOW_HASKELL__ >= 800
#define HAS_DATA_KIND
#endif

-----------------------------------------------------------------------------
-- |
-- Module      : Data.Binary.Generic
-- Copyright   : Bryan O'Sullivan
-- License     : BSD3-style (see LICENSE)
--
-- Maintainer  : Bryan O'Sullivan <bos@serpentine.com>
-- Stability   : unstable
-- Portability : Only works with GHC 7.2 and newer
--
-- Instances for supporting GHC generics.
--
-----------------------------------------------------------------------------
module Data.Binary.Generic
    (
    ) where

import Control.Applicative
import Data.Binary.Class
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits
import Data.Word
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
#ifdef HAS_DATA_KIND
import Data.Kind
#endif
import GHC.Generics
import Prelude -- Silence AMP warning.

-- Type without constructors
instance GBinaryPut V1 where
    -- V1 has no constructors, so there is nothing to write: the argument
    -- is ignored and the encoder emits no bytes.
    gput _ = pure ()

instance GBinaryGet V1 where
    -- V1 has no constructors, so no real value can be produced. The
    -- decoder consumes no input and yields a bottom placeholder; forcing
    -- the result raises the usual undefined error.
    gget   = return undefined

-- Constructor without arguments
instance GBinaryPut U1 where
    -- A constructor without fields carries no payload, so nothing is
    -- written. Matching on U1 (rather than a wildcard) forces the value.
    gput U1 = pure ()

instance GBinaryGet U1 where
    -- A constructor without fields is rebuilt without reading any input.
    gget    = return U1

-- Product: constructor with parameters
instance (GBinaryPut a, GBinaryPut b) => GBinaryPut (a :*: b) where
    -- Fields are written left to right: first the left component, then
    -- the right one, with no separator between them.
    gput (x :*: y) = gput x <> gput y

instance (GBinaryGet a, GBinaryGet b) => GBinaryGet (a :*: b) where
    -- Fields are read in the same left-to-right order in which the
    -- matching put instance writes them.
    gget = (:*:) <$> gget <*> gget

-- Metadata (constructor name, etc)
instance GBinaryPut a => GBinaryPut (M1 i c a) where
    -- Metadata wrappers add nothing to the encoding: unwrap and encode
    -- the underlying representation.
    gput = gput . unM1

instance GBinaryGet a => GBinaryGet (M1 i c a) where
    -- Decode the underlying representation and rewrap it in the metadata
    -- constructor; the wrapper itself consumes no input.
    gget = M1 <$> gget

-- Constants, additional parameters, and rank-1 recursion
instance Binary a => GBinaryPut (K1 i a) where
    -- A field value is encoded with its own Binary instance; this is the
    -- point where the generic encoder hands over to the put method.
    gput = put . unK1

instance Binary a => GBinaryGet (K1 i a) where
    -- A field value is decoded with its own Binary instance and then
    -- wrapped in K1.
    gget = K1 <$> get

-- Borrowed from the cereal package.

-- The following GBinaryPut and GBinaryGet instances for sums support
-- serializing types with up to 2^64-1 constructors. They use the minimal
-- number of bytes needed to encode the constructor. For example, when
-- a type has 2^8 constructors or fewer, a single byte is used to
-- encode the constructor. If it has 2^16 constructors or fewer, two
-- bytes are used, and so on up to 2^64-1.

-- Condition that holds when the largest constructor index (size - 1)
-- fits in the given word type. Expects a variable named size to be in
-- scope wherever the macro is expanded.
#define GUARD(WORD) (size - 1) <= fromIntegral (maxBound :: WORD)
-- Guard alternative that encodes the sum using tags of the given word
-- type, with tag numbering starting at 0.
#define PUTSUM(WORD) GUARD(WORD) = putSum (0 :: WORD) (fromIntegral size)
-- Guard alternative that reads one tag of the given word type and
-- passes it to checkGetSum together with the constructor count.
#define GETSUM(WORD) GUARD(WORD) = (get :: Get WORD) >>= checkGetSum (fromIntegral size)

instance ( GSumPut  a, GSumPut  b
         , SumSize    a, SumSize    b) => GBinaryPut (a :+: b) where
    -- Choose the narrowest word type whose range covers every
    -- constructor index and use it for the tag; each guard below is a
    -- macro expansion testing one candidate width. If no width fits,
    -- encoding is refused via sizeError.
    gput | PUTSUM(Word8) | PUTSUM(Word16) | PUTSUM(Word32) | PUTSUM(Word64)
         | otherwise = sizeError "encode" size
      where
        -- Number of constructors in the whole sum, obtained from the
        -- type alone (no value is inspected).
        size = unTagged (sumSize :: Tagged (a :+: b) Word64)

instance ( GSumGet  a, GSumGet  b
         , SumSize    a, SumSize    b) => GBinaryGet (a :+: b) where
    -- Mirror of the encoder: read a tag of the narrowest word type that
    -- covers every constructor index, then validate it and dispatch via
    -- checkGetSum. If no width fits, decoding is refused via sizeError.
    gget | GETSUM(Word8) | GETSUM(Word16) | GETSUM(Word32) | GETSUM(Word64)
         | otherwise = sizeError "decode" size
      where
        -- Number of constructors in the whole sum, obtained from the
        -- type alone.
        size = unTagged (sumSize :: Tagged (a :+: b) Word64)

-- | Abort with a message saying that a type with the given number of
-- constructors cannot be handled.
--
-- The first argument is the verb placed in the message (the callers in
-- this module pass "encode" or "decode"); the second is the constructor
-- count, rendered with show. This function never returns: it always
-- calls error, which is why the result type is fully polymorphic.
sizeError :: Show size => String -> size -> error
sizeError s size =
    error $ "Can't " ++ s ++ " a type with " ++ show size ++ " constructors"

------------------------------------------------------------------------

-- | Validate a decoded constructor tag and dispatch on it.
--
-- The first argument is the number of constructors of the sum and the
-- second is the tag that was read from the input. A tag strictly below
-- the constructor count is passed on to getSum (note the swapped
-- argument order: getSum takes the tag first). Any other tag makes the
-- decoder fail with a fixed message instead of selecting a constructor.
checkGetSum :: (Ord word, Num word, Bits word, GSumGet f)
            => word -> word -> Get (f a)
checkGetSum size code | code < size = getSum code size
                      | otherwise   = fail "Unknown encoding for constructor"
{-# INLINE checkGetSum #-}

-- | Decoding of one alternative of a generic sum, selected by a numeric
-- tag.
class GSumGet f where
    -- | In the instances in this module the first argument is the
    -- constructor tag, relative to the part of the sum being decoded,
    -- and the second is the number of constructors in that part.
    getSum :: (Ord word, Num word, Bits word) => word -> word -> Get (f a)

-- | Encoding of a generic sum value together with its numeric
-- constructor tag.
class GSumPut f where
    -- | In the instances in this module the first argument is the tag
    -- offset accumulated so far, the second is the number of
    -- constructors in the part of the sum being encoded, and the third
    -- is the value to write.
    putSum :: (Num w, Bits w, Binary w) => w -> w -> f a -> Put

instance (GSumGet a, GSumGet b) => GSumGet (a :+: b) where
    -- Tags below sizeL select the left branch and are passed on
    -- unchanged; all other tags select the right branch and are rebased
    -- by subtracting sizeL. Both arguments are forced by bang patterns.
    getSum !code !size | code < sizeL = L1 <$> getSum code           sizeL
                       | otherwise    = R1 <$> getSum (code - sizeL) sizeR
        where
          -- The left branch is assigned half of the tags (rounded down),
          -- the right branch the remainder. The encoder must split the
          -- range in exactly the same way.
          sizeL = size `shiftR` 1
          sizeR = size - sizeL

instance (GSumPut a, GSumPut b) => GSumPut (a :+: b) where
    -- Going left keeps the accumulated tag offset; going right adds the
    -- number of tags reserved for the left branch. Both numeric
    -- arguments are forced by bang patterns.
    putSum !code !size s = case s of
                             L1 x -> putSum code           sizeL x
                             R1 x -> putSum (code + sizeL) sizeR x
        where
          -- The left branch is assigned half of the tags (rounded down),
          -- the right branch the remainder. The decoder must split the
          -- range in exactly the same way.
          sizeL = size `shiftR` 1
          sizeR = size - sizeL

instance GBinaryGet a => GSumGet (C1 c a) where
    -- A single constructor has been reached, so the tag and size are no
    -- longer needed: decode the constructor's fields.
    getSum _ _ = gget

instance GBinaryPut a => GSumPut (C1 c a) where
    -- A single constructor has been reached: write the accumulated tag
    -- first, then the constructor's fields. The size argument is unused.
    putSum !code _ x = put code <> gput x

------------------------------------------------------------------------

-- | Constructor count of a generic sum, available from the type alone.
class SumSize f where
    -- | The count, tagged with the representation type it belongs to so
    -- that the instance can be chosen with a type annotation.
    sumSize :: Tagged f Word64

-- A value of type b labelled with a representation type s. The label
-- exists only at the type level: the single field has type b. The two
-- branches below differ only in how the kind of s is spelled.
#ifdef HAS_DATA_KIND
newtype Tagged (s :: Type -> Type) b = Tagged {unTagged :: b}
#else
newtype Tagged (s :: * -> *)       b = Tagged {unTagged :: b}
#endif

instance (SumSize a, SumSize b) => SumSize (a :+: b) where
    -- A sum has as many constructors as its two branches combined. The
    -- annotations pick the instance for each branch; they rely on the
    -- instance head's type variables being in scope in the body.
    sumSize = Tagged $ unTagged (sumSize :: Tagged a Word64) +
                       unTagged (sumSize :: Tagged b Word64)

instance SumSize (C1 c a) where
    -- A single constructor counts as one, regardless of its fields.
    sumSize = Tagged 1