{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module System.Semaphore
  ( -- * System semaphores
    Semaphore(..), SemaphoreName(..)
  , createSemaphore, freshSemaphore, openSemaphore
  , waitOnSemaphore, tryWaitOnSemaphore
  , WaitId(..)
  , forkWaitOnSemaphoreInterruptible
  , interruptWaitOnSemaphore
  , getSemaphoreValue
  , releaseSemaphore
  , destroySemaphore

  -- * Abstract semaphores
  , AbstractSem(..)
  , withAbstractSem
  ) where

-- base
import Control.Concurrent
import Control.Monad
import Data.List.NonEmpty ( NonEmpty(..) )
import GHC.Exts ( Char(..), Int(..), indexCharOffAddr# )

-- exceptions
import qualified Control.Monad.Catch as MC

#if defined(mingw32_HOST_OS)
-- Win32
import qualified System.Win32.Event     as Win32
  ( createEvent, setEvent
  , waitForSingleObject, waitForMultipleObjects
  , wAIT_OBJECT_0 )
import qualified System.Win32.File      as Win32
  ( closeHandle )
import qualified System.Win32.Process   as Win32
  ( iNFINITE )
import qualified System.Win32.Semaphore as Win32
  ( Semaphore(..), sEMAPHORE_ALL_ACCESS
  , createSemaphore, openSemaphore, releaseSemaphore )
import qualified System.Win32.Time      as Win32
  ( FILETIME(..), getSystemTimeAsFileTime )
import qualified System.Win32.Types     as Win32
  ( HANDLE, errorWin )
#else
-- base
import Foreign.C.Types
  ( CClock(..) )

-- unix
import qualified System.Posix.Semaphore as Posix
  ( Semaphore, OpenSemFlags(..)
  , semOpen, semWaitInterruptible, semTryWait, semThreadWait
  , semGetValue, semPost, semUnlink )
import qualified System.Posix.Files     as Posix
  ( stdFileMode )
import qualified System.Posix.Process   as Posix
  ( ProcessTimes(systemTime), getProcessTimes )
#endif

---------------------------------------
-- System-specific semaphores

-- | The system-wide name of a semaphore.
newtype SemaphoreName = SemaphoreName
  { getSemaphoreName :: String
    -- ^ The raw name, as passed to the operating system.
  }
  deriving Eq

-- | A system semaphore (POSIX or Win32), together with the name it was
-- created or opened under.
data Semaphore =
  Semaphore
    { semaphoreName :: !SemaphoreName
      -- ^ The name used to create or open the semaphore.
    , semaphore     ::
#if defined(mingw32_HOST_OS)
      !Win32.Semaphore
#else
      !Posix.Semaphore
#endif
      -- ^ The underlying platform-specific semaphore object.
    }

-- | Create a new semaphore with the given name, holding the given initial
-- number of tokens.
--
-- Throws an error if a semaphore by this name already exists (or if creation
-- fails for any other reason).
createSemaphore :: SemaphoreName
                -> Int -- ^ number of tokens on the semaphore
                -> IO Semaphore
createSemaphore (SemaphoreName sem_name) init_toks =
  -- 'create_sem' reports failure as an action that throws; run it here.
  create_sem sem_name init_toks >>= either id return

-- | Create a fresh semaphore with the given amount of tokens.
--
-- Its name is the given prefix, an underscore, and a suffix taken from
-- 'random_strings'. If creating the semaphore fails (for instance because the
-- name is already taken), the next suffix is tried. At most 33 names are
-- attempted (the first one plus up to 32 retries). After that, the error from
-- the last attempt is thrown.
freshSemaphore :: String -- ^ prefix
               -> Int    -- ^ number of tokens on the semaphore
               -> IO Semaphore
freshSemaphore prefix init_toks = random_strings >>= attempt 0
  where
    -- Try to create a semaphore named after the current suffix.
    attempt :: Int -> NonEmpty String -> IO Semaphore
    attempt tries (suffix :| rest) = do
      outcome <- create_sem (concat [prefix, "_", suffix]) init_toks
      either (retryOr tries rest) return outcome

    -- Move on to the next suffix while one remains and fewer than 32 retries
    -- have been made; otherwise run the failure action of the last attempt.
    retryOr :: Int -> [String] -> IO Semaphore -> IO Semaphore
    retryOr tries rest failure = case rest of
      next : more | tries < 32 -> attempt (tries + 1) (next :| more)
      _                        -> failure

-- | Attempt to create a brand-new named semaphore holding @init_toks@ tokens.
--
-- Returns @Right sem@ on success. On failure returns @Left act@, where @act@
-- is an action that throws an error describing the failure. The caller
-- decides whether to run it ('createSemaphore') or to discard it and try
-- another name ('freshSemaphore').
create_sem :: String -> Int -> IO (Either (IO Semaphore) Semaphore)
create_sem sem_str init_toks = do
#if defined(mingw32_HOST_OS)
  -- The same value is passed as both the initial and the maximum count.
  let toks = fromIntegral init_toks
  mb_sem <- MC.try @_ @MC.SomeException $
    Win32.createSemaphore Nothing toks toks (Just sem_str)
  case mb_sem of
    Right (sem, exists)
      | exists
      -> do
        -- A semaphore with this name already existed. The handle we were
        -- given refers to that pre-existing object and is never returned to
        -- the caller, so close it now. Otherwise every name collision, which
        -- 'freshSemaphore' may retry many times, leaks a handle.
        Win32.closeHandle (Win32.semaphoreHandle sem)
        return $ Left $ Win32.errorWin $
          "semaphore-compat: semaphore " ++ sem_str ++ " already exists"
      | otherwise
      -> return $ Right $ mk_sem sem
    Left err
      -> return $ Left $ MC.throwM err
#else
  -- Create the semaphore, and fail if one with this name already exists.
  let flags =
        Posix.OpenSemFlags
          { Posix.semCreate    = True
          , Posix.semExclusive = True }
  mb_sem <- MC.try @_ @MC.SomeException $
    Posix.semOpen sem_str flags Posix.stdFileMode init_toks
  return $ case mb_sem of
    Left  err -> Left $ MC.throwM err
    Right sem -> Right $ mk_sem sem
#endif
  where
    sem_nm = SemaphoreName sem_str
    mk_sem sem =
      Semaphore
        { semaphore     = sem
        , semaphoreName = sem_nm }

-- | Open an already-existing semaphore with the given name.
--
-- If no such semaphore exists, throws an error; nothing is created.
openSemaphore :: SemaphoreName -> IO Semaphore
openSemaphore nm@(SemaphoreName sem_name) = wrap <$> open_existing
  where
    wrap sem = Semaphore { semaphoreName = nm, semaphore = sem }
#if defined(mingw32_HOST_OS)
    open_existing =
      Win32.openSemaphore Win32.sEMAPHORE_ALL_ACCESS True sem_name
#else
    -- Neither create nor demand exclusivity: only open what is already there.
    open_existing =
      Posix.semOpen sem_name flags Posix.stdFileMode 0
    flags = Posix.OpenSemFlags
      { Posix.semCreate    = False
      , Posix.semExclusive = False }
#endif

-- | Wait on a semaphore for as long as it takes to obtain a token.
--
-- This wait cannot be cancelled; for a cancellable wait, use
-- 'forkWaitOnSemaphoreInterruptible' instead.
waitOnSemaphore :: Semaphore -> IO ()
waitOnSemaphore Semaphore { semaphore = raw } =
#if defined(mingw32_HOST_OS)
  -- Masked, presumably so that an asynchronous exception cannot arrive
  -- between obtaining the token and returning.
  MC.mask_ $ void $
    Win32.waitForSingleObject (Win32.semaphoreHandle raw) Win32.iNFINITE
#else
  Posix.semThreadWait raw
#endif

-- | Try to take a token from the semaphore without blocking.
--
-- Returns 'True' if a token was obtained and 'False' immediately if none
-- was available.
tryWaitOnSemaphore :: Semaphore -> IO Bool
tryWaitOnSemaphore Semaphore { semaphore = raw } =
#if defined(mingw32_HOST_OS)
  -- A zero timeout turns the wait into a poll; only WAIT_OBJECT_0 means
  -- that a token was actually taken.
  MC.mask_ $
    (== Win32.wAIT_OBJECT_0) <$>
      Win32.waitForSingleObject (Win32.semaphoreHandle raw) 0
#else
  Posix.semTryWait raw
#endif

-- | Release a semaphore: add @n@ tokens to its internal counter.
--
-- Does nothing when @n <= 0@.
releaseSemaphore :: Semaphore -> Int -> IO ()
releaseSemaphore Semaphore { semaphore = raw } n
  | n > 0     = MC.mask_ post_tokens
  | otherwise = return ()
  where
#if defined(mingw32_HOST_OS)
    -- The previous count returned by the release is not needed.
    post_tokens = void $ Win32.releaseSemaphore raw (fromIntegral n)
#else
    -- Posted one token at a time, n times over.
    post_tokens = replicateM_ n (Posix.semPost raw)
#endif

-- | Destroy the given semaphore.
--
-- On Windows this closes the semaphore's handle; elsewhere it unlinks the
-- semaphore's name.
destroySemaphore :: Semaphore -> IO ()
destroySemaphore sem = destroy
  where
#if defined(mingw32_HOST_OS)
    destroy = Win32.closeHandle (Win32.semaphoreHandle (semaphore sem))
#else
    SemaphoreName sem_name = semaphoreName sem
    destroy = Posix.semUnlink sem_name
#endif

-- | Query how many tokens the semaphore currently has available.
--
-- Intended mainly for debugging: the value can be stale as soon as it is
-- returned, so program logic that depends on it is prone to race conditions.
getSemaphoreValue :: Semaphore -> IO Int
getSemaphoreValue Semaphore { semaphore = raw } =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    res <- Win32.waitForSingleObject (Win32.semaphoreHandle raw) 0
    -- If a token could be taken without waiting, hand it straight back.
    -- The release reports the count before it was applied; adding the token
    -- we held gives the number that was available. Otherwise nothing was
    -- available.
    if res /= Win32.wAIT_OBJECT_0
      then return 0
      else (\prev -> fromIntegral prev + 1) <$> Win32.releaseSemaphore raw 1
#else
  Posix.semGetValue raw
#endif

-- | Identifies a thread started by 'forkWaitOnSemaphoreInterruptible',
-- holding everything 'interruptWaitOnSemaphore' needs to cancel its wait.
data WaitId =
  WaitId
    { waitingThreadId :: ThreadId
      -- ^ The thread performing the wait.
#if defined(mingw32_HOST_OS)
    , cancelHandle    :: Win32.HANDLE
      -- ^ Windows only: an event which, once set, ends the wait.
#endif
    }

-- | Spawn a thread that waits on the given semaphore.
--
-- The spawned thread runs with asynchronous exceptions masked from its very
-- first instruction. It is forked from inside 'MC.mask_', so it inherits the
-- masked state and there is no unmasked window in which a 'killThread' could
-- skip the result action.
--
-- When the wait ends, its outcome is passed to the wait result action, in the
-- style of @forkFinally@:
--
--   * @Right True@: a token was obtained from the semaphore;
--   * @Right False@: the wait ended without obtaining a token;
--   * @Left e@: the wait was ended by the exception @e@.
--
-- The wait can be interrupted with 'interruptWaitOnSemaphore'.
--
-- NOTE(review): on Windows, the cancellation event created here appears never
-- to be closed in this module, so each call seems to leak one event handle.
forkWaitOnSemaphoreInterruptible
  :: Semaphore
  -> ( Either MC.SomeException Bool -> IO () ) -- ^ wait result action
  -> IO WaitId
forkWaitOnSemaphoreInterruptible
  (Semaphore { semaphore = sem })
  wait_result_action = do
#if defined(mingw32_HOST_OS)
    cancelHandle <- Win32.createEvent Nothing True False ""
#endif
    let
      interruptible_wait :: IO Bool
      interruptible_wait =
#if defined(mingw32_HOST_OS)
        -- Windows: wait on both the semaphore and the handle used for
        -- cancelling the wait, whichever becomes signalled first.
          do
            wait_res <-
              Win32.waitForMultipleObjects
                [ Win32.semaphoreHandle sem
                , cancelHandle ]
                False -- False <=> WaitAny
                Win32.iNFINITE
            -- Only WAIT_OBJECT_0 (the semaphore, first in the list) means
            -- that we obtained a token.
            return $ wait_res == Win32.wAIT_OBJECT_0
#else
        -- POSIX: 'semWaitInterruptible' is an interruptible FFI call, so it
        -- can be interrupted by killThread even though we are masked.
          Posix.semWaitInterruptible sem
#endif
    -- Mask *around* forkIO rather than inside the child. The child inherits
    -- the masked state, so an exception thrown at it early is deferred
    -- rather than killing it before the result action has run.
    waitingThreadId <- MC.mask_ $ forkIO $ do
      wait_res <- MC.try interruptible_wait
      wait_result_action wait_res
    return $ WaitId { .. }

-- | Interrupt a semaphore wait operation started by
-- 'forkWaitOnSemaphoreInterruptible'.
interruptWaitOnSemaphore :: WaitId -> IO ()
interruptWaitOnSemaphore wait_id = do
#if defined(mingw32_HOST_OS)
  -- Windows: set the cancellation event so the waiting thread's
  -- multi-object wait returns.
  Win32.setEvent (cancelHandle wait_id)
#endif
  -- POSIX: the wait is an interruptible FFI call ('semWaitInterruptible'),
  -- so killing the thread cancels the wait on the semaphore.
  killThread (waitingThreadId wait_id)

---------------------------------------
-- Abstract semaphores

-- | Abstraction over the operations of a semaphore: the pair of actions that
-- bracket a critical section (see 'withAbstractSem').
data AbstractSem =
  AbstractSem
    { acquireSem :: IO ()
      -- ^ Run before the protected action.
    , releaseSem :: IO ()
      -- ^ Run after the protected action ('withAbstractSem' runs it even
      -- if the action throws).
    }

-- | Run an action between 'acquireSem' and 'releaseSem' of the given
-- abstract semaphore. Because this uses 'MC.bracket_', the release runs even
-- if the action throws.
withAbstractSem :: AbstractSem -> IO b -> IO b
withAbstractSem sem = MC.bracket_ grab giveBack
  where
    grab     = acquireSem sem
    giveBack = releaseSem sem

---------------------------------------
-- Utility

-- | Render an 'Int' in base 62, most significant digit first. Digits are
-- @0-9@, then @a-z@, then @A-Z@, in increasing order of value.
--
-- The sign is discarded: negative numbers are rendered via their absolute
-- value. 'minBound', whose absolute value does not fit in an 'Int', is
-- rendered as 'maxBound'.
iToBase62 :: Int -> String
iToBase62 m = digits (magnitude m) ""
  where
    magnitude :: Int -> Int
    magnitude k
      | k == minBound = maxBound
      | otherwise     = abs k

    -- Peel off the least significant digit and prepend it to the
    -- accumulator, stopping once no higher digits remain.
    digits :: Int -> String -> String
    digits !n acc =
      let !(q, r) = n `quotRem` 62
          !d      = digit62 r
          acc'    = d : acc
      in if q == 0 then acc' else digits q acc'

    -- Look up a digit in the static table; the argument is always in [0, 61].
    digit62 :: Int -> Char
    {-# INLINE digit62 #-}
    digit62 (I# i) = C# (indexCharOffAddr# table i)
    table = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"#

-- | A supply of candidate name suffixes for 'freshSemaphore'.
--
-- Despite the name, the suffixes are not random. A clock reading is used as
-- an offset, and the suffixes are that offset plus 0, 1, 2, …, rendered in
-- base 62 with 'iToBase62'. Name collisions are left to the caller's retry
-- logic.
random_strings :: IO (NonEmpty String)
random_strings = do
#if defined(mingw32_HOST_OS)
  -- Windows: the current system time, as a FILETIME.
  Win32.FILETIME t <- Win32.getSystemTimeAsFileTime
#else
  -- POSIX: the systemTime component of the process times.
  CClock t <- Posix.systemTime <$> Posix.getProcessTimes
#endif
  let offset = fromIntegral t :: Int
      suffix_for i = iToBase62 (offset + i)
  return (suffix_for <$> (0 :| [1 ..]))