{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module System.Semaphore
  ( -- * System semaphores
    Semaphore(..), SemaphoreName(..)
  , createSemaphore, freshSemaphore, openSemaphore
  , waitOnSemaphore, tryWaitOnSemaphore
  , WaitId(..)
  , forkWaitOnSemaphoreInterruptible
  , interruptWaitOnSemaphore
  , getSemaphoreValue
  , releaseSemaphore
  , destroySemaphore

  -- * Abstract semaphores
  , AbstractSem(..)
  , withAbstractSem
  ) where

-- base
import Control.Concurrent
import Control.Monad
import Data.List.NonEmpty ( NonEmpty(..) )
import GHC.Clock ( getMonotonicTimeNSec )
import GHC.Exts ( Char(..), Int(..), indexCharOffAddr# )

-- exceptions
import qualified Control.Monad.Catch as MC

#if defined(mingw32_HOST_OS)
-- Win32
import qualified System.Win32.Event     as Win32
  ( createEvent, setEvent
  , waitForSingleObject, waitForMultipleObjects
  , wAIT_OBJECT_0 )
import qualified System.Win32.File      as Win32
  ( closeHandle )
import qualified System.Win32.Process   as Win32
  ( iNFINITE )
import qualified System.Win32.Semaphore as Win32
  ( Semaphore(..), sEMAPHORE_ALL_ACCESS
  , createSemaphore, openSemaphore, releaseSemaphore )
import qualified System.Win32.Time      as Win32
  ( FILETIME(..), getSystemTimeAsFileTime )
import qualified System.Win32.Types     as Win32
  ( HANDLE, errorWin )
#else
-- base
import Foreign.C.Types
  ( CClock(..) )

-- unix
import qualified System.Posix.Semaphore as Posix
  ( Semaphore, OpenSemFlags(..)
  , semOpen, semWaitInterruptible, semTryWait, semThreadWait
  , semGetValue, semPost, semUnlink )
import qualified System.Posix.Files     as Posix
  ( stdFileMode )
import qualified System.Posix.Process   as Posix
  ( ProcessTimes(systemTime), getProcessTimes )
#endif

---------------------------------------
-- System-specific semaphores

-- | The name of a system semaphore, as passed to the OS when creating or
-- opening it.
newtype SemaphoreName = SemaphoreName
  { getSemaphoreName :: String -- ^ the raw name string
  }
  deriving (Eq)

-- | A system semaphore (POSIX or Win32), paired with the name it was
-- created or opened under.
data Semaphore =
  Semaphore
    { semaphoreName :: !SemaphoreName
      -- ^ Name used when the semaphore was created or opened.
    , semaphore     ::
#if defined(mingw32_HOST_OS)
      !Win32.Semaphore
#else
      !Posix.Semaphore
#endif
      -- ^ The underlying platform semaphore.
    }

-- | Create a new semaphore with the given name and initial amount of
-- available resources.
--
-- Throws an error if a semaphore by this name already exists.
createSemaphore :: SemaphoreName
                -> Int -- ^ number of tokens on the semaphore
                -> IO Semaphore
createSemaphore (SemaphoreName sem_name) init_toks =
  -- 'create_sem' reports failure as a 'Left' holding an action that throws
  -- the error; here we simply run that action.
  create_sem sem_name init_toks >>= either id pure

-- | Create a fresh semaphore with the given amount of tokens.
--
-- Its name will start with the given prefix, but will have a random suffix
-- appended to it (the full name is @prefix ++ "_" ++ suffix@).
--
-- At most 32 candidate names are tried. If every attempt fails, the failure
-- from the last attempt is rethrown.
--
-- NOTE(review): any creation failure triggers a retry, not only a name
-- clash, so e.g. a permission error is retried until attempts run out.
freshSemaphore :: String -- ^ prefix
               -> Int    -- ^ number of tokens on the semaphore
               -> IO Semaphore
freshSemaphore prefix init_toks = do
  suffixes <- random_strings
  go 1 suffixes
  where
    -- Upper bound on the number of candidate names we try.
    max_attempts :: Int
    max_attempts = 32

    -- @go n candidates@ makes attempt number @n@ (counting from 1) with the
    -- first candidate suffix, moving on to the next candidate on failure.
    go :: Int -> NonEmpty String -> IO Semaphore
    go attempt (suffix :| suffs) = do
      mb_sem <- create_sem (prefix ++ "_" ++ suffix) init_toks
      case mb_sem of
        Right sem -> return sem
        Left  err
          | attempt < max_attempts
          , next : nexts <- suffs
          -> go (attempt + 1) (next :| nexts)
          | otherwise
          -> err

-- | Try to create a new semaphore called @sem_str@ holding @init_toks@
-- tokens, failing if a semaphore with that name already exists.
--
-- Failures are not thrown here. They come back as a 'Left' holding an action
-- that throws the error when run. This lets 'freshSemaphore' move on to
-- another name, while 'createSemaphore' simply runs the action.
create_sem :: String -> Int -> IO (Either (IO Semaphore) Semaphore)
create_sem sem_str init_toks = do
#if defined(mingw32_HOST_OS)
  let toks = fromIntegral init_toks
  mb_sem <- MC.try @_ @MC.SomeException $
    Win32.createSemaphore Nothing toks toks (Just sem_str)
  case mb_sem of
    Left err ->
      return $ Left $ MC.throwM err
    Right (sem, exists)
      | exists -> do
          -- The name was taken, but Win32 still hands back a handle to the
          -- pre-existing semaphore. Close it; otherwise every rejected
          -- candidate name (see 'freshSemaphore') leaks a handle.
          Win32.closeHandle (Win32.semaphoreHandle sem)
          return $ Left $ Win32.errorWin $
            "semaphore-compat: semaphore " ++ sem_str ++ " already exists"
      | otherwise ->
          return $ Right $ mk_sem sem
#else
  -- Create + exclusive: we must not attach to an already existing semaphore.
  let flags =
        Posix.OpenSemFlags
          { Posix.semCreate    = True
          , Posix.semExclusive = True }
  mb_sem <- MC.try @_ @MC.SomeException $
    Posix.semOpen sem_str flags Posix.stdFileMode init_toks
  return $ case mb_sem of
    Left  err -> Left $ MC.throwM err
    Right sem -> Right $ mk_sem sem
#endif
  where
    sem_nm = SemaphoreName sem_str
    mk_sem sem =
      Semaphore
        { semaphore     = sem
        , semaphoreName = sem_nm }

-- | Open a semaphore with the given name.
--
-- If no such semaphore exists, throws an error.
openSemaphore :: SemaphoreName -> IO Semaphore
openSemaphore nm@(SemaphoreName sem_name) = do
  raw <- open_existing
  return (Semaphore { semaphoreName = nm, semaphore = raw })
  where
    open_existing =
#if defined(mingw32_HOST_OS)
      Win32.openSemaphore Win32.sEMAPHORE_ALL_ACCESS True sem_name
#else
      -- Neither create nor exclusive: only attach to an existing semaphore.
      Posix.semOpen
        sem_name
        (Posix.OpenSemFlags { Posix.semCreate = False, Posix.semExclusive = False })
        Posix.stdFileMode
        0
#endif

-- | Indefinitely wait on a semaphore, taking one token once available.
--
-- If you want to be able to cancel a wait operation, use
-- 'forkWaitOnSemaphoreInterruptible' instead.
waitOnSemaphore :: Semaphore -> IO ()
waitOnSemaphore (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  -- Block with an infinite timeout; the wait result code is discarded.
  MC.mask_ . void $
    Win32.waitForSingleObject (Win32.semaphoreHandle sem) Win32.iNFINITE
#else
  Posix.semThreadWait sem
#endif

-- | Try to obtain a token from the semaphore, without blocking.
--
-- Returns 'True' if a token was taken. Immediately returns 'False' if no
-- resources are available.
tryWaitOnSemaphore :: Semaphore -> IO Bool
tryWaitOnSemaphore (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  -- A zero timeout turns the wait into a poll; only WAIT_OBJECT_0 means
  -- that a token was actually acquired.
  MC.mask_ $
    (== Win32.wAIT_OBJECT_0)
      <$> Win32.waitForSingleObject (Win32.semaphoreHandle sem) 0
#else
  Posix.semTryWait sem
#endif

-- | Release a semaphore: add @n@ to its internal counter.
--
-- No-op when @n <= 0@. The release runs with asynchronous exceptions masked.
releaseSemaphore :: Semaphore -> Int -> IO ()
releaseSemaphore (Semaphore { semaphore = sem }) n =
  when (n > 0) $ MC.mask_ $
#if defined(mingw32_HOST_OS)
    -- Win32 can add all @n@ tokens in a single call.
    void $ Win32.releaseSemaphore sem (fromIntegral n)
#else
    -- POSIX semaphores are posted one token at a time.
    replicateM_ n (Posix.semPost sem)
#endif

-- | Destroy the given semaphore.
--
-- On Windows this closes our handle to the semaphore. On POSIX it unlinks
-- the semaphore's name via 'Posix.semUnlink'.
destroySemaphore :: Semaphore -> IO ()
destroySemaphore sem =
#if defined(mingw32_HOST_OS)
  Win32.closeHandle . Win32.semaphoreHandle $ semaphore sem
#else
  Posix.semUnlink . getSemaphoreName $ semaphoreName sem
#endif

-- | Query the current semaphore value (how many tokens it has available).
--
-- This is mainly for debugging use, as it is easy to introduce race conditions
-- when nontrivial program logic depends on the value returned by this function.
getSemaphoreValue :: Semaphore -> IO Int
getSemaphoreValue (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  -- Win32 has no direct "get value" call. Instead, try to take a token
  -- without waiting. If that succeeds, give it straight back: the count
  -- reported by 'Win32.releaseSemaphore' excludes our token, so add 1.
  MC.mask_ $ do
    got_token <- (== Win32.wAIT_OBJECT_0)
      <$> Win32.waitForSingleObject (Win32.semaphoreHandle sem) 0
    if not got_token
      then return 0
      else (+ 1) . fromIntegral <$> Win32.releaseSemaphore sem 1
#else
  Posix.semGetValue sem
#endif

-- | 'WaitId' stores the information we need to cancel a thread
-- which is waiting on a semaphore.
--
-- See 'forkWaitOnSemaphoreInterruptible' and 'interruptWaitOnSemaphore'.
data WaitId =
  WaitId
    { waitingThreadId :: ThreadId
      -- ^ The thread blocked on the semaphore.
#if defined(mingw32_HOST_OS)
    , cancelHandle    :: Win32.HANDLE
      -- ^ Event that is signalled to abort the wait (Windows only).
#endif
    }

-- | Spawn a thread that waits on the given semaphore.
--
-- In this thread, asynchronous exceptions will be masked.
--
-- The outcome of the wait is passed to the supplied action, from the spawned
-- thread. It is 'Right' with 'True' when a token was obtained, and 'Left'
-- when the wait threw (for example because the thread was killed).
--
-- The waiting operation can be interrupted using the
-- 'interruptWaitOnSemaphore' function.
--
-- NOTE(review): on Windows the cancellation event created here is never
-- closed by the code visible in this module; confirm whether that handle
-- leak is acceptable.
forkWaitOnSemaphoreInterruptible
  :: Semaphore
  -> ( Either MC.SomeException Bool -> IO () ) -- ^ wait result action
  -> IO WaitId
forkWaitOnSemaphoreInterruptible (Semaphore { semaphore = sem }) on_result = do
#if defined(mingw32_HOST_OS)
  cancel_event <- Win32.createEvent Nothing True False ""
  let
    -- Wait on the semaphore and on the cancellation event at once. Only
    -- WAIT_OBJECT_0 (the semaphore, first in the list) means we got a token.
    blocking_wait :: IO Bool
    blocking_wait =
      (== Win32.wAIT_OBJECT_0) <$>
        Win32.waitForMultipleObjects
          [ Win32.semaphoreHandle sem, cancel_event ]
          False -- False <=> WaitAny
          Win32.iNFINITE
#else
  let
    -- 'Posix.semWaitInterruptible' is an interruptible FFI call, so a
    -- 'killThread' aimed at the waiting thread aborts it.
    blocking_wait :: IO Bool
    blocking_wait = Posix.semWaitInterruptible sem
#endif
  tid <- forkIO . MC.mask_ $ MC.try blocking_wait >>= on_result
  return WaitId
    { waitingThreadId = tid
#if defined(mingw32_HOST_OS)
    , cancelHandle    = cancel_event
#endif
    }

-- | Interrupt a semaphore wait operation initiated by
-- 'forkWaitOnSemaphoreInterruptible'.
interruptWaitOnSemaphore :: WaitId -> IO ()
interruptWaitOnSemaphore wid@WaitId{} = do
#if defined(mingw32_HOST_OS)
  -- Windows: signal the cancellation event so that the waiting thread's
  -- multi-object wait returns.
  Win32.setEvent (cancelHandle wid)
#endif
  -- POSIX: killing the thread cancels the wait, because the FFI call it is
  -- blocked in ('Posix.semWaitInterruptible') is interruptible.
  killThread (waitingThreadId wid)

---------------------------------------
-- Abstract semaphores

-- | Abstraction over the operations of a semaphore: a pair of actions, one
-- to take a token and one to give it back.
data AbstractSem =
  AbstractSem
    { acquireSem :: IO () -- ^ take a token (presumably blocking until one is free)
    , releaseSem :: IO () -- ^ give back a token previously taken with 'acquireSem'
    }

-- | Run an action while holding the semaphore.
--
-- 'acquireSem' runs first and 'releaseSem' runs afterwards, even if the
-- action throws (this is 'MC.bracket_').
withAbstractSem :: AbstractSem -> IO b -> IO b
withAbstractSem sem action = MC.bracket_ acquire release action
  where
    acquire = acquireSem sem
    release = releaseSem sem

---------------------------------------
-- Utility

-- | Render an 'Int' in base 62, using the digits @0-9@, then @a-z@, then
-- @A-Z@.
--
-- The sign is dropped (only the magnitude is rendered). 'minBound' has no
-- representable magnitude, so it is rendered as 'maxBound'.
iToBase62 :: Int -> String
iToBase62 x = digitsOf magnitude ""
  where
    magnitude :: Int
    magnitude
      | x == minBound = maxBound
      | otherwise     = abs x

    -- Peel off digits least-significant first, consing each onto the
    -- accumulator so the finished string reads most-significant first.
    digitsOf :: Int -> String -> String
    digitsOf n acc
      | n < 62    = let !d = digit62 n in d : acc
      | otherwise = case quotRem n 62 of
          (!q, !r) -> let !d = digit62 r in digitsOf q (d : acc)

    digit62 :: Int -> Char
    {-# INLINE digit62 #-}
    digit62 (I# i) = C# (indexCharOffAddr# alphabet i)

    alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"#

-- | An infinite supply of candidate name suffixes for 'freshSemaphore'.
--
-- The candidates are consecutive integers, starting from a time-derived seed
-- and rendered with 'iToBase62'. The seed combines a platform time reading
-- with a nanosecond-resolution monotonic clock. This makes processes started
-- at nearly the same moment unlikely to generate the same candidate sequence.
random_strings :: IO (NonEmpty String)
random_strings = do
#if defined(mingw32_HOST_OS)
  Win32.FILETIME t <- Win32.getSystemTimeAsFileTime
#else
  -- System CPU time used by this process, in clock ticks. This is typically
  -- small and similar across freshly started processes, so on its own it
  -- makes a poor seed; it is combined with the monotonic clock below.
  CClock t <- Posix.systemTime <$> Posix.getProcessTimes
#endif
  now_ns <- getMonotonicTimeNSec
  let seed :: Int
      -- Wrap-around is harmless: we only need varied values, and
      -- 'iToBase62' handles any 'Int', including negative ones.
      seed = fromIntegral t + fromIntegral now_ns
  return $ fmap (\ i -> iToBase62 (i + seed)) (0 :| [1..])