{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module System.Semaphore
  ( -- * System semaphores
    Semaphore(..), SemaphoreName(..)
  , createSemaphore, freshSemaphore, openSemaphore
  , waitOnSemaphore, tryWaitOnSemaphore
  , WaitId(..)
  , forkWaitOnSemaphoreInterruptible
  , interruptWaitOnSemaphore
  , getSemaphoreValue
  , releaseSemaphore
  , destroySemaphore

  -- * Abstract semaphores
  , AbstractSem(..)
  , withAbstractSem
  ) where

-- base
import Control.Concurrent
import Control.Monad
import Data.List.NonEmpty ( NonEmpty(..) )
import GHC.Exts ( Char(..), Int(..), indexCharOffAddr# )

-- exceptions
import qualified Control.Monad.Catch as MC

#if defined(mingw32_HOST_OS)
-- Win32
import qualified System.Win32.Event     as Win32
  ( createEvent, setEvent
  , waitForSingleObject, waitForMultipleObjects
  , wAIT_OBJECT_0 )
import qualified System.Win32.File      as Win32
  ( closeHandle )
import qualified System.Win32.Process   as Win32
  ( iNFINITE )
import qualified System.Win32.Semaphore as Win32
  ( Semaphore(..), sEMAPHORE_ALL_ACCESS
  , createSemaphore, openSemaphore, releaseSemaphore )
import qualified System.Win32.Time      as Win32
  ( FILETIME(..), getSystemTimeAsFileTime )
import qualified System.Win32.Types     as Win32
  ( HANDLE, errorWin )
#else
-- base
import Foreign.C.Types
  ( CClock(..) )

-- unix
import qualified System.Posix.Semaphore as Posix
  ( Semaphore, OpenSemFlags(..)
  , semOpen, semWaitInterruptible, semTryWait, semThreadWait
  , semGetValue, semPost, semUnlink )
import qualified System.Posix.Files     as Posix
  ( stdFileMode )
import qualified System.Posix.Process   as Posix
  ( ProcessTimes(systemTime), getProcessTimes )
#endif

-- System-specific semaphores

-- | The name of a system semaphore: the string handed to the OS when the
-- semaphore is created or opened.
newtype SemaphoreName =
  SemaphoreName { getSemaphoreName :: String }
  deriving Eq

-- | A system semaphore (POSIX or Win32).
--
-- Pairs the OS-level semaphore object with the name it was created or
-- opened under. On POSIX the name is needed by 'destroySemaphore', which
-- unlinks the semaphore by name.
data Semaphore =
  Semaphore
    { semaphoreName :: !SemaphoreName
    , semaphore     ::
#if defined(mingw32_HOST_OS)
      !Win32.Semaphore
#else
      !Posix.Semaphore
#endif
    }

-- | Create a new semaphore with the given name and initial amount of
-- available resources.
--
-- Throws an error if a semaphore by this name already exists.
createSemaphore :: SemaphoreName
                -> Int -- ^ number of tokens on the semaphore
                -> IO Semaphore
createSemaphore (SemaphoreName sem_name) init_toks = do
  mb_sem <- create_sem sem_name init_toks
  case mb_sem of
    -- 'create_sem' reports failure as an action; running it throws.
    Left  err -> err
    Right sem -> return sem

-- | Create a fresh semaphore with the given amount of tokens.
--
-- Its name will start with the given prefix, but will have a random suffix
-- appended to it.
--
-- If creating the semaphore fails, the next candidate suffix is tried.
-- After 32 retries (33 attempts in total) the last failure is rethrown.
freshSemaphore :: String -- ^ prefix
               -> Int    -- ^ number of tokens on the semaphore
               -> IO Semaphore
freshSemaphore prefix init_toks = do
  suffixes <- random_strings
  go 0 suffixes
  where
    -- @i@ counts the retries performed so far.
    go :: Int -> NonEmpty String -> IO Semaphore
    go i (suffix :| suffs) = do
      mb_sem <- create_sem (prefix ++ "_" ++ suffix) init_toks
      case mb_sem of
        Right sem -> return sem
        Left  err
          | next : nexts <- suffs
          , i < 32 -- retry at most 32 times
          -> go (i+1) (next :| nexts)
          | otherwise
          -> err

-- | Try to create a new semaphore with the given name and initial number of
-- tokens, failing if a semaphore with this name already exists.
--
-- A failure is returned as an 'IO' action that throws the error when run.
-- Callers such as 'freshSemaphore' can then either retry with another name
-- (discarding the action) or rethrow it.
create_sem :: String -> Int -> IO (Either (IO Semaphore) Semaphore)
create_sem sem_str init_toks = do
#if defined(mingw32_HOST_OS)
  let toks = fromIntegral init_toks
  mb_sem <- MC.try @_ @MC.SomeException $
    Win32.createSemaphore Nothing toks toks (Just sem_str)
  case mb_sem of
    Right (sem, exists)
      | exists
      -> do
        -- The name is already taken, and Win32 has handed us a handle to the
        -- *existing* semaphore. We will not use it, so close it right away.
        -- The 'Left' action below may never be run (e.g. when
        -- 'freshSemaphore' retries), so closing there would leak the handle.
        Win32.closeHandle (Win32.semaphoreHandle sem)
        return $ Left (Win32.errorWin $ "semaphore-compat: semaphore " ++ sem_str ++ " already exists")
      | otherwise
      -> return $ Right $ mk_sem sem
    Left err
      -> return $ Left $ MC.throwM err
#else
  -- Create the semaphore, failing if one with this name already exists.
  let flags =
        Posix.OpenSemFlags
          { Posix.semCreate    = True
          , Posix.semExclusive = True }
  mb_sem <- MC.try @_ @MC.SomeException $
    Posix.semOpen sem_str flags Posix.stdFileMode init_toks
  return $ case mb_sem of
    Left  err -> Left $ MC.throwM err
    Right sem -> Right $ mk_sem sem
#endif
  where
    sem_nm = SemaphoreName sem_str
    mk_sem sem =
      Semaphore
        { semaphore     = sem
        , semaphoreName = sem_nm }

-- | Open a semaphore with the given name.
--
-- If no such semaphore exists, throws an error.
openSemaphore :: SemaphoreName -> IO Semaphore
openSemaphore nm@(SemaphoreName sem_name) = do
#if defined(mingw32_HOST_OS)
  sem <- Win32.openSemaphore Win32.sEMAPHORE_ALL_ACCESS True sem_name
#else
  -- Only open an existing semaphore; never create one.
  let flags = Posix.OpenSemFlags
        { Posix.semCreate    = False
        , Posix.semExclusive = False }
  -- The initial-value argument is only consulted when creating, so 0 is a
  -- placeholder here.
  sem <- Posix.semOpen sem_name flags Posix.stdFileMode 0
#endif
  return $
    Semaphore
      { semaphore     = sem
      , semaphoreName = nm }

-- | Indefinitely wait on a semaphore.
--
-- If you want to be able to cancel a wait operation, use
-- 'forkWaitOnSemaphoreInterruptible' instead.
waitOnSemaphore :: Semaphore -> IO ()
waitOnSemaphore (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    () <$ Win32.waitForSingleObject (Win32.semaphoreHandle sem) Win32.iNFINITE
#else
  Posix.semThreadWait sem
#endif

-- | Try to obtain a token from the semaphore, without blocking.
--
-- Immediately returns 'False' if no resources are available.
tryWaitOnSemaphore :: Semaphore -> IO Bool
tryWaitOnSemaphore (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    -- A zero timeout makes the wait a non-blocking poll.
    wait_res <- Win32.waitForSingleObject (Win32.semaphoreHandle sem) 0
    return $ wait_res == Win32.wAIT_OBJECT_0
#else
  Posix.semTryWait sem
#endif

-- | Release a semaphore: add @n@ to its internal counter.
--
-- No-op when @n <= 0@.
releaseSemaphore :: Semaphore -> Int -> IO ()
releaseSemaphore (Semaphore { semaphore = sem }) n
  | n <= 0
  = return ()
  | otherwise
  = MC.mask_ $ do
#if defined(mingw32_HOST_OS)
    void $ Win32.releaseSemaphore sem (fromIntegral n)
#else
    -- Post one token at a time, @n@ times. 'MC.mask_' defers asynchronous
    -- exceptions so the loop is not cut short part-way through.
    replicateM_ n (Posix.semPost sem)
#endif

-- | Destroy the given semaphore.
--
-- On Windows this closes the semaphore handle. On POSIX it unlinks the
-- semaphore's name.
destroySemaphore :: Semaphore -> IO ()
destroySemaphore sem =
#if defined(mingw32_HOST_OS)
  Win32.closeHandle (Win32.semaphoreHandle $ semaphore sem)
#else
  Posix.semUnlink (getSemaphoreName $ semaphoreName sem)
#endif

-- | Query the current semaphore value (how many tokens it has available).
--
-- This is mainly for debugging use, as it is easy to introduce race conditions
-- when nontrivial program logic depends on the value returned by this function.
getSemaphoreValue :: Semaphore -> IO Int
getSemaphoreValue (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    -- Probe by trying to take a token without waiting (zero timeout).
    wait_res <- Win32.waitForSingleObject (Win32.semaphoreHandle sem) 0
    if wait_res == Win32.wAIT_OBJECT_0
      -- We were able to acquire a resource from the semaphore without waiting:
      -- release it immediately, thus obtaining the total number of available
      -- resources. 'Win32.releaseSemaphore' yields the count from before the
      -- release, hence the @+1@.
      then (+1) . fromIntegral <$> Win32.releaseSemaphore sem 1
      else return 0
#else
  Posix.semGetValue sem
#endif

-- | 'WaitId' stores the information we need to cancel a thread
-- which is waiting on a semaphore.
--
-- See 'forkWaitOnSemaphoreInterruptible' and 'interruptWaitOnSemaphore'.
data WaitId = WaitId { waitingThreadId :: ThreadId
                       -- ^ the thread performing the wait
#if defined(mingw32_HOST_OS)
                     , cancelHandle    :: Win32.HANDLE
                       -- ^ event that is signalled to cancel the wait
#endif
                     }

-- | Spawn a thread that waits on the given semaphore.
--
-- In this thread, asynchronous exceptions will be masked.
--
-- The waiting operation can be interrupted using the
-- 'interruptWaitOnSemaphore' function.
--
-- The wait result action runs in the spawned thread. It receives
-- @Right True@ if a token was obtained, @Right False@ if the wait ended
-- without obtaining one, or @Left e@ if the wait threw @e@.
forkWaitOnSemaphoreInterruptible
  :: Semaphore
  -> ( Either MC.SomeException Bool -> IO () ) -- ^ wait result action
  -> IO WaitId
forkWaitOnSemaphoreInterruptible
  (Semaphore { semaphore = sem })
  wait_result_action = do
#if defined(mingw32_HOST_OS)
    -- NOTE(review): nothing in the visible module closes this event handle;
    -- confirm whether leaking one handle per wait is acceptable.
    cancelHandle <- Win32.createEvent Nothing True False ""
#endif
    let
      interruptible_wait :: IO Bool
      interruptible_wait =
#if defined(mingw32_HOST_OS)
        -- Windows: wait on both the handle used for cancelling the wait
        -- and on the semaphore.
        do
          wait_res <-
            Win32.waitForMultipleObjects
              [ Win32.semaphoreHandle sem
              , cancelHandle ]
              False -- False <=> WaitAny
              Win32.iNFINITE
          -- Only in the case that the wait result is WAIT_OBJECT_0 will
          -- we have succeeded in obtaining a token from the semaphore.
          return $ wait_res == Win32.wAIT_OBJECT_0
#else
        -- POSIX: use the 'semWaitInterruptible' interruptible FFI call
        -- that can be interrupted when we send a killThread signal.
        Posix.semWaitInterruptible sem
#endif
    waitingThreadId <- forkIO $ MC.mask_ $ do
      wait_res <- MC.try interruptible_wait
      wait_result_action wait_res
    return $ WaitId { .. }

-- | Interrupt a semaphore wait operation initiated by
-- 'forkWaitOnSemaphoreInterruptible'.
interruptWaitOnSemaphore :: WaitId -> IO ()
interruptWaitOnSemaphore ( WaitId { .. } ) = do
#if defined(mingw32_HOST_OS)
  -- On Windows, we signal to stop waiting.
  Win32.setEvent cancelHandle
#endif
  -- On POSIX, killing the thread will cancel the wait on the semaphore
  -- due to the FFI call being interruptible ('semWaitInterruptible').
  killThread waitingThreadId

-- Abstract semaphores

-- | Abstraction over the operations of a semaphore.
data AbstractSem =
  AbstractSem
    { acquireSem :: IO ()
      -- ^ run before the body in 'withAbstractSem'
    , releaseSem :: IO ()
      -- ^ run after the body in 'withAbstractSem', even if the body throws
    }

-- | Run an action while holding the abstract semaphore.
--
-- 'acquireSem' runs first, and 'releaseSem' is guaranteed to run afterwards,
-- even if the action throws.
withAbstractSem :: AbstractSem -> IO b -> IO b
withAbstractSem sem = MC.bracket_ (acquireSem sem) (releaseSem sem)

-- Utility

-- | Render an 'Int' in base 62, with digits @0-9@, then @a-z@, then @A-Z@.
--
-- The sign is discarded: the absolute value is rendered. 'minBound' has no
-- representable absolute value, so it is mapped to 'maxBound'.
iToBase62 :: Int -> String
iToBase62 m = go m' ""
  where
    m' :: Int
    m'
      | m == minBound
      = maxBound
      | otherwise
      = abs m

    -- Peel off base-62 digits from the least significant end, prepending
    -- each one to the accumulator @cs@.
    go :: Int -> String -> String
    go n cs
      | n < 62
      = let !c = chooseChar62 n
        in c : cs
      | otherwise
      = let !(!q, r) = quotRem n 62
            !c       = chooseChar62 r
        in go q (c : cs)

    chooseChar62 :: Int -> Char
    {-# INLINE chooseChar62 #-}
    chooseChar62 (I# n) = C# (indexCharOffAddr# chars62 n)

    -- Primitive string literal: the 62 digit characters, indexed 0..61.
    chars62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"#

random_strings :: IO (NonEmpty String)
random_strings :: IO (NonEmpty String)
random_strings = do
#if defined(mingw32_HOST_OS)
  Win32.FILETIME t <- Win32.getSystemTimeAsFileTime
  CClock Int64
t <- ProcessTimes -> CClock
Posix.systemTime (ProcessTimes -> CClock) -> IO ProcessTimes -> IO CClock
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO ProcessTimes
  NonEmpty String -> IO (NonEmpty String)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (NonEmpty String -> IO (NonEmpty String))
-> NonEmpty String -> IO (NonEmpty String)
forall a b. (a -> b) -> a -> b
$ (Int -> String) -> NonEmpty Int -> NonEmpty String
forall a b. (a -> b) -> NonEmpty a -> NonEmpty b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ( \ Int
i -> Int -> String
iToBase62 (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int64 -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int64
t) ) (Int
0 Int -> [Int] -> NonEmpty Int
forall a. a -> [a] -> NonEmpty a
:| [Int