{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns, MagicHash, UnboxedTuples #-}
-- |
-- Module      : Data.ByteString.Builder.RealFloat.TableGenerator
-- Copyright   : (c) Lawrence Wu 2021
-- License     : BSD-style
-- Maintainer  : lawrencejwu@gmail.com
--
-- Constants and overview for compile-time table generation for Ryu internals
--
-- This module uses Haskell's arbitrary-precision `Integer` types to compute
-- the necessary multipliers for efficient conversion to a decimal power base.
--
-- It also exposes constants relevant to the 32- and 64-bit tables (e.g. the
-- maximum number of bits required to store the table values).

module Data.ByteString.Builder.RealFloat.TableGenerator
  ( float_pow5_inv_bitcount
  , float_pow5_bitcount
  , double_pow5_bitcount
  , double_pow5_inv_bitcount
  , float_max_split
  , float_max_inv_split
  , double_max_split
  , double_max_inv_split

  , finv
  , fnorm
  , splitWord128s
  , case64
  , case128
  ) where

import GHC.Float (int2Double)

import Data.Bits
import Data.Word
import Numeric


-- The basic floating point conversion algorithm is as such:
--
-- Given floating point
--
--   f = (-1)^s * m_f * 2^e_f
--
-- which is IEEE encoded by `[s] [.. e ..] [.. m ..]`. `s` is the sign bit, `e`
-- is the biased exponent, and `m` is the mantissa, let
--
--       | e /= 0            | e == 0
--  -----+-------------------+-----------
--   m_f | 2^len(m) + m      | m
--   e_f | e - bias - len(m) | 1 - bias - len(m)
--
-- we compute the halfway points to the next smaller (`f-`) and larger (`f+`)
-- floating point numbers as
--
--  lower halfway point  u * 2^e2, u = 4 * m_f - (if m == 0 && e > 1 then 1 else 2)
--  the value itself     v * 2^e2, v = 4 * m_f
--  upper halfway point  w * 2^e2, w = 4 * m_f + 2
--  where e2 = e_f - 2 (so u, v, w are integers)
--
--
-- Then we compute (a, b, c) * 10^e10 = (u, v, w) * 2^e2 which is split into
-- the case of
--
--   e2 >= 0   ==>    e10 = 0 , (a, b, c) = (u, v, w) * 2^e2
--   e2 <  0   ==>    e10 = e2, (a, b, c) = (u, v, w) * 5^-e2
--
-- And finally we find the shortest representation from integers d0 and e0 such
-- that
--
--  a * 10^e10 < d0 * 10^(e0+e10) < c * 10^e10
--
-- with e0 maximal (we allow equality to the smaller or larger halfway point
-- depending on the rounding mode). This is found by iteratively dividing by
-- 10 while a/10^j < c/10^j and doing some bookkeeping around trailing
-- zeros.
--
--
--
--
-- The ryu algorithm removes the requirement for arbitrary precision arithmetic
-- and improves the runtime significantly by skipping most of the iterative
-- division by carefully selecting a point where certain invariants hold and
-- precomputing a few tables.
--
-- Specifically, define `q` such that the corresponding values satisfy
-- a/10^q < c/10^q - 1. We can prove (not shown) that q can be taken as
-- (approximately)
--
--    if e2 >= 0, q = floor(e2 * log_10(2))
--    if e2 <  0, q = floor(-e2 * log_10(5))
--
-- Then we can compute (a, b, c) / 10^q. Starting from (u, v, w) we have
--
--      (a, b, c) / 10^q                  (a, b, c) / 10^q
--    = (u, v, w) * 2^e2 / 10^q    OR   = (u, v, w) * 5^-e2 / 10^q
--
-- And since q <= e2 (when e2 >= 0) or q <= -e2 (when e2 < 0), splitting
-- 10^q = 2^q * 5^q gives
--    = (u, v, w) * 2^(e2-q) / 5^q   OR   = (u, v, w) * 5^(-e2-q) / 2^q
--
-- While (u, v, w) are n-bit numbers, 5^q and the like are significantly
-- larger, but we only need the top-most n bits of the result, so we can choose
-- a `k` that reduces the number of bits required to ~2n. We then multiply by
-- either
--
--    2^k / 5^q                    OR   5^(-e2-q) / 2^k
--
-- The required `k` is roughly linear in the exponent (we need more of the
-- multiplication to be precise) but the number of bits to store the
-- multiplicands above stays fixed.
--
-- Since the number of bits needed is relatively small for IEEE 32- and 64-bit
-- floating types, we can compute appropriate values for `k` for the
-- floating-point-type-specific bounds instead of each e2.
--
-- Finally, a few manual iterations may still be needed to fix up the state
-- that was skipped over.


-- | Upper bound on the number of bits of each @2^k / 5^q@ multiplier in the
-- 'Float' (32-bit) tables.
float_pow5_inv_bitcount :: Int
float_pow5_inv_bitcount = 59

-- | Upper bound on the number of bits of each @5^-e2-q / 2^k@ multiplier in
-- the 'Float' (32-bit) tables.
float_pow5_bitcount :: Int
float_pow5_bitcount = 61

-- | Upper bound on the number of bits of each @5^-e2-q / 2^k@ multiplier in
-- the 'Double' (64-bit) tables.
double_pow5_bitcount :: Int
double_pow5_bitcount = 125

-- | Upper bound on the number of bits of each @2^k / 5^q@ multiplier in the
-- 'Double' (64-bit) tables.
double_pow5_inv_bitcount :: Int
double_pow5_inv_bitcount = 125

-- NB: these tables are encoded directly into the
-- source code in cbits/aligned-static-hs-data.c

-- | Number of bits in a positive integer, i.e. the index of its highest set
-- bit plus one (@blen 0 == 0@, @blen 1 == 1@, @blen 5 == 3@).
blen :: Integer -> Int
blen = go 0
  where
    -- Count how many halvings it takes to reach zero. 'quot' (not 'shiftR')
    -- is used so that the loop still terminates for negative inputs.
    go :: Int -> Integer -> Int
    go !acc 0 = acc
    go !acc k = go (acc + 1) (k `quot` 2)

-- | Used for table generation of @2^k / 5^q + 1@.
--
-- For index @i@ this is @floor (2^(blen (5^i) - 1 + bitcount) / 5^i) + 1@:
-- a scaled reciprocal of @5^i@ whose power-of-two numerator is sized from
-- the bit length of @5^i@ plus the requested @bitcount@, so the quotient
-- has roughly @bitcount@ significant bits.
finv :: Int -> Int -> Integer
finv bitcount i = scaled `div` pow5 + 1
  where
    pow5 :: Integer
    pow5 = 5 ^ i
    -- power of two chosen relative to the size of 5^i
    scaled :: Integer
    scaled = shiftL 1 (blen pow5 - 1 + bitcount)

-- | Used for table generation of @5^-e2-q / 2^k@.
--
-- Normalises @5^i@ to exactly @bitcount@ significant bits. Powers that are
-- too wide lose their low bits (shift right), and powers that are too narrow
-- are padded with zeros (shift left).
fnorm :: Int -> Int -> Integer
fnorm bitcount i
  | excess < 0 = pow5 `shiftL` negate excess
  | otherwise  = pow5 `shiftR` excess
  where
    pow5 :: Integer
    pow5 = 5 ^ i
    -- how many bits 5^i has beyond the requested width (may be negative)
    excess :: Int
    excess = blen pow5 - bitcount

-- | Breaks each integer into two Word64s, low 64 bits first and then the next
-- 64 bits, producing the flattened list @[lo0, hi0, lo1, hi1, ...]@. Any bits
-- above bit 127 are discarded by the 'Word64' conversion.
splitWord128s :: [Integer] -> [Word64]
splitWord128s = concatMap halves
  where
    halves :: Integer -> [Word64]
    halves x = [fromInteger (x .&. lowMask), fromInteger (x `shiftR` 64)]

    lowMask :: Integer
    lowMask = toInteger (maxBound :: Word64)

-- | Splits an integer into its upper and lower 64-bit halves, returned as
-- @(highBits, lowBits)@. Note that this is the opposite order to the pairs
-- emitted by 'splitWord128s'.
splitWord128 :: Integer -> (Word64, Word64)
splitWord128 x = (high, low)
  where
    high, low :: Word64
    high = fromInteger (x `shiftR` 64)
    low  = fromInteger (x .&. toInteger (maxBound :: Word64))


-- Helpers to generate case alternatives returning either one Word64 (case64) or
-- two Word64s (case128) for the PURE_HASKELL variant of the tables.

-- | Renders one case alternative per index @i@ in @range@, each of the form
-- @i -> 0x\<hex of f i\>@ followed by a newline. Values of @f@ must be
-- non-negative, because 'showHex' rejects negative numbers.
case64 :: (Int -> Integer) -> [Int] -> String
case64 f range = concatMap alternative range
  where
    alternative :: Int -> String
    alternative i = show i ++ " -> 0x" ++ showHex (f i) "\n"

-- | Renders one case alternative per index @i@, each returning the 128-bit
-- value @f i@ as a pair of hex literals, high half first:
-- @i -> (0x\<hi\>, 0x\<lo\>)@ followed by a newline.
case128 :: (Int -> Integer) -> [Int] -> String
case128 f = concatMap alternative
  where
    alternative :: Int -> String
    alternative i =
      let (hi, lo) = splitWord128 (f i)
       in show i ++ " -> (0x" ++ showHex hi "" ++ ", 0x" ++ showHex lo ")\n"

-- | Given a specific floating-point type, determine the table bounds for the
-- @e2 < 0@ and @e2 >= 0@ cases, returned in that order. Only the argument's
-- type matters: 'floatRange' and 'floatDigits' are queried on it, and
-- callers pass a dummy value.
get_range :: forall ff. (RealFloat ff) => ff -> (Int, Int)
get_range f = (negBound, posBound)
  where
    digits, lo, hi :: Int
    digits   = floatDigits f
    (lo, hi) = floatRange f

    -- smallest and largest e2 (= e_f - 2) representable by this type
    minE2, maxE2 :: Int
    minE2 = lo - digits - 2
    maxE2 = hi - digits - 2

    -- e2 < 0: largest -e2 - floor(-e2 * log_10 5), over the most negative e2
    negBound :: Int
    negBound = negate minE2 - floor (int2Double (negate minE2) * logBase 10 5)

    -- e2 >= 0: largest floor(e2 * log_10 2)
    posBound :: Int
    posBound = floor (int2Double maxE2 * logBase 10 2)

-- | 'Float' table bounds from 'get_range': the bound for the @e2 < 0@
-- (@5^-e2-q / 2^k@) case and for the @e2 >= 0@ (@2^k / 5^q@) case,
-- 46 and 30 respectively.
float_max_split, float_max_inv_split :: Int
(float_max_split, float_max_inv_split) = get_range (0 :: Float)

-- | 'Double' table bounds from 'get_range' (325 and 291 respectively). The
-- first bound is one larger than 'get_range' reports because the 'Double'
-- code path is slightly different and needs one extra table entry.
double_max_split, double_max_inv_split :: Int
(double_max_split, double_max_inv_split) =
  case get_range (0 :: Double) of
    (split, invSplit) -> (split + 1, invSplit)