{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE NoImplicitPrelude, ScopedTypeVariables #-}

--------------------------------------------------------------------------------
-- |
-- Module      :  Foreign.Marshal.Pool
-- Copyright   :  (c) Sven Panne 2002-2004
-- License     :  BSD-style (see the file libraries/base/LICENSE)
--
-- Maintainer  :  sven.panne@aedion.de
-- Stability   :  provisional
-- Portability :  portable
--
-- This module contains support for pooled memory management. Under this scheme,
-- (re-)allocations belong to a given pool, and everything in a pool is
-- deallocated when the pool itself is deallocated. This is useful when
-- 'Foreign.Marshal.Alloc.alloca' with its implicit allocation and deallocation
-- is not flexible enough, but explicit uses of 'Foreign.Marshal.Alloc.malloc'
-- and 'free' are too awkward.
--
--------------------------------------------------------------------------------

module Foreign.Marshal.Pool (
   -- * Pool management
   Pool,
   newPool,
   freePool,
   withPool,

   -- * (Re-)Allocation within a pool
   pooledMalloc,
   pooledMallocBytes,

   pooledRealloc,
   pooledReallocBytes,

   pooledMallocArray,
   pooledMallocArray0,

   pooledReallocArray,
   pooledReallocArray0,

   -- * Combined allocation and marshalling
   pooledNew,
   pooledNewArray,
   pooledNewArray0
) where

import GHC.Base              ( Int, Monad(..), (.), liftM, not )
import GHC.Err               ( undefined )
import GHC.Exception         ( throw )
import GHC.IO                ( IO, mask, catchAny )
import GHC.IORef             ( IORef, newIORef, readIORef, writeIORef )
import GHC.List              ( elem, length )
import GHC.Num               ( Num(..) )

import Data.OldList          ( delete )
import Foreign.Marshal.Alloc ( mallocBytes, reallocBytes, free )
import Foreign.Marshal.Array ( pokeArray, pokeArray0 )
import Foreign.Marshal.Error ( throwIf )
import Foreign.Ptr           ( Ptr, castPtr )
import Foreign.Storable      ( Storable(sizeOf, poke) )

--------------------------------------------------------------------------------

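-- To avoid non-Haskell 2010 features such as existentially quantified data
-- constructors, the pool stores every allocation as a plain @Ptr ()@ and
-- casts to and from the caller's pointer type as needed. Not elegant, but
-- simple and portable.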

-- | A memory pool.
--
-- Internally a pool is a mutable list recording the blocks allocated (or
-- reallocated) through it. Each block is stored as a @Ptr ()@ and cast back
-- to the caller's pointer type when it is handed out.

newtype Pool = Pool (IORef [Ptr ()])

-- | Allocate a fresh memory pool.
--
-- The new pool starts out owning no memory at all.
newPool :: IO Pool
newPool = newIORef [] >>= \ref -> return (Pool ref)

-- | Deallocate a memory pool and everything which has been allocated in it.
--
-- The pool's list of allocations is emptied /before/ the blocks are released.
-- As a result, a second 'freePool' on the same pool is a no-op rather than a
-- double free. The pool may also be reused afterwards: later allocations are
-- tracked from an empty list and released by the next 'freePool'. Pointers
-- obtained from the pool before this call are dangling once it returns.
freePool :: Pool -> IO ()
freePool (Pool pool) = do
   ptrs <- readIORef pool
   -- Forget the blocks first so that no later freePool can free them again.
   writeIORef pool []
   freeAll ptrs
  where
   freeAll :: [Ptr ()] -> IO ()
   freeAll []       = return ()
   freeAll (p : ps) = free p >> freeAll ps

-- | Execute an action with a fresh memory pool. The pool, together with
-- everything allocated in it, is deallocated once the action finishes,
-- whether it returns normally or throws. If it throws, the exception is
-- re-thrown after the pool has been freed.
withPool :: (Pool -> IO b) -> IO b
withPool act =
   -- ATTENTION: this is the 'bracket' pattern, copied by hand from
   -- Control.Exception; keep the two in sync. Asynchronous exceptions stay
   -- masked except while 'act' itself runs, so one cannot arrive between
   -- creating the pool and installing the cleanup handler.
   mask (\restore -> do
      pool   <- newPool
      result <- catchAny (restore (act pool))
                         (\e -> freePool pool >> throw e)
      freePool pool
      return result)

--------------------------------------------------------------------------------

-- | Allocate space for one value of a storable type in the given pool. The
-- number of bytes reserved is the 'sizeOf' reported by the type's 'Storable'
-- instance.
pooledMalloc :: forall a . Storable a => Pool -> IO (Ptr a)
pooledMalloc pool = pooledMallocBytes pool bytes
  where
    -- 'sizeOf' ignores its argument's value; 'undefined' only fixes the type.
    bytes = sizeOf (undefined :: a)

-- | Allocate the given number of bytes of storage in the pool.
--
-- The new block is recorded in the pool, so deallocating the pool (via
-- 'freePool' or the end of 'withPool') also releases it.
pooledMallocBytes :: Pool -> Int -> IO (Ptr a)
pooledMallocBytes (Pool ref) nbytes =
   mallocBytes nbytes >>= \block ->
   readIORef ref >>= \tracked ->
   writeIORef ref (block : tracked) >>
   return (castPtr block)

-- | Resize a block owned by the pool so that it holds exactly one value of
-- its element type, as given by 'sizeOf'. The block may move, so use the
-- returned pointer afterwards. A pointer the pool does not own is rejected
-- with an 'IOError'.
pooledRealloc :: forall a . Storable a => Pool -> Ptr a -> IO (Ptr a)
pooledRealloc pool ptr = pooledReallocBytes pool ptr elemBytes
  where
    elemBytes = sizeOf (undefined :: a)

-- | Resize a block owned by the pool to the given number of bytes.
--
-- Only pointers this pool handed out may be resized. Any other pointer makes
-- the call fail with an 'IOError' (\"pointer not in pool\"). On success, the
-- pool's record of the old address is replaced by the (possibly moved) new
-- one, which is also returned.
pooledReallocBytes :: Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocBytes (Pool ref) ptr size = do
   _ <- throwIf notTracked (\_ -> "pointer not in pool") (readIORef ref)
   resized <- reallocBytes untyped size
   tracked <- readIORef ref
   -- Swap the old address for the new one in the pool's bookkeeping.
   writeIORef ref (resized : delete untyped tracked)
   return (castPtr resized)
  where
    -- The pool stores untyped pointers, so compare and update in Ptr ().
    untyped :: Ptr ()
    untyped = castPtr ptr
    notTracked tracked = not (untyped `elem` tracked)

-- | Allocate storage in the pool for the given number of elements of a
-- storable type, i.e. @count * sizeOf element@ bytes. The multiplication is
-- not checked for 'Int' overflow.
pooledMallocArray :: forall a . Storable a => Pool -> Int -> IO (Ptr a)
pooledMallocArray pool size = pooledMallocBytes pool arrayBytes
  where
    arrayBytes = size * sizeOf (undefined :: a)

-- | Like 'pooledMallocArray', but reserves room for one extra element that
-- can hold an end-of-array marker.
pooledMallocArray0 :: Storable a => Pool -> Int -> IO (Ptr a)
pooledMallocArray0 pool size = pooledMallocArray pool slots
  where
    slots = size + 1  -- the elements themselves plus the terminator

-- | Resize an array owned by the pool so that it holds the given number of
-- elements (@count * sizeOf element@ bytes). The array may move, so use the
-- returned pointer afterwards.
pooledReallocArray :: forall a . Storable a => Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocArray pool ptr size = pooledReallocBytes pool ptr arrayBytes
  where
    arrayBytes = size * sizeOf (undefined :: a)

-- | Like 'pooledReallocArray', but keeps room for one extra element that
-- holds an end-of-array marker.
pooledReallocArray0 :: Storable a => Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocArray0 pool ptr size = pooledReallocArray pool ptr slots
  where
    slots = size + 1  -- the elements themselves plus the terminator

--------------------------------------------------------------------------------

-- | Allocate storage for a single value in the given pool, store the value
-- there with 'poke', and return the pointer to it.
pooledNew :: Storable a => Pool -> a -> IO (Ptr a)
pooledNew pool val =
   pooledMalloc pool >>= \ptr ->
   poke ptr val >>
   return ptr

-- | Allocate consecutive storage in the given pool for a list of values,
-- marshal the values into it with 'pokeArray', and return a pointer to the
-- first element.
pooledNewArray :: Storable a => Pool -> [a] -> IO (Ptr a)
pooledNewArray pool vals =
   pooledMallocArray pool (length vals) >>= \ptr ->
   pokeArray ptr vals >>
   return ptr

-- | Allocate consecutive storage in the given pool for a list of values plus
-- one terminating element. The values are written with 'pokeArray0', which
-- places the given marker after the last one, and a pointer to the first
-- element is returned.
pooledNewArray0 :: Storable a => Pool -> a -> [a] -> IO (Ptr a)
pooledNewArray0 pool marker vals =
   pooledMallocArray0 pool (length vals) >>= \ptr ->
   pokeArray0 marker ptr vals >>
   return ptr