{-# LANGUAGE CPP #-}

module System.Semaphore
  ( -- * System semaphores

    -- $server-vs-client

    ClientSemaphore, ServerSemaphore, SemaphoreName(..)

    -- | The name of a client semaphore
  , clientSemaphoreName
  , semaphoreIdentifier
    -- | Retrieve the client semaphore corresponding to a server semaphore
  , serverClientSemaphore

    -- ** Creating a semaphore
  , createSemaphore, freshSemaphore

    -- ** Opening a semaphore
  , SemaphoreToken
  , openSemaphore
  , SemaphoreIdentifier, parseSemaphoreIdentifier
  , SemaphoreProtocolVersion(..)
  , semaphoreVersion
  , versionsAreCompatible
  , SemaphoreError(..)

    -- ** Requesting a token
  , waitOnSemaphore, tryWaitOnSemaphore
  , withSemaphoreToken
  , getSemaphoreValue

    -- ** Releasing resources
  , releaseSemaphoreToken

    -- $destroying
  , destroyClientSemaphore, destroyServerSemaphore

  -- * Abstract semaphores
  , AbstractSem(..)
  , withAbstractSem
  ) where

{- $server-vs-client

Since version 2 of @semaphore-compat@, we distinguish between two kinds
of semaphores:

  - When a jobserver creates a semaphore via 'createSemaphore', it obtains
    a 'ServerSemaphore'.
  - Jobclients (which open a pre-existing semaphore via 'openSemaphore')
    obtain a 'ClientSemaphore'.

When the jobserver wants to also act as a jobclient, it can use
'serverClientSemaphore' to obtain the 'ClientSemaphore' corresponding
to its 'ServerSemaphore'.

This architecture allows the jobserver to keep full accounting of
semaphore resources held by all clients.
-}

{- $destroying

Destroying a semaphore releases the implicit token held by the
semaphore.

For example, when @cabal-install@ invokes @ghc@, the @ghc@ process
automatically has one semaphore token at the start (the implicit
token), and can request further tokens with 'waitOnSemaphore'. The
tokens acquired by 'waitOnSemaphore' are released by
'releaseSemaphoreToken', while the implicit token is released by
'destroyClientSemaphore'.
-}

-- base
import Data.List.NonEmpty ( NonEmpty(..) )

-- exceptions
import qualified Control.Monad.Catch as MC

import System.Semaphore.Internal.Common

#if defined(mingw32_HOST_OS)
import System.Semaphore.Internal.Win32
#elif defined(wasm32_HOST_ARCH) || defined(javascript_HOST_ARCH)
import System.Semaphore.Internal.Unsupported
#else
import System.Semaphore.Internal.Posix
#endif

---------------------------------------
-- Version compatibility

-- | Decide whether two semaphore protocol versions can interoperate.
--
-- Compatibility is currently plain equality: two versions are compatible
-- exactly when they are identical.
versionsAreCompatible :: SemaphoreProtocolVersion -> SemaphoreProtocolVersion -> Bool
versionsAreCompatible = (==)

---------------------------------------
-- System-specific semaphores

-- | Create a new semaphore with the given label and initial token count.
--
-- The semaphore's name records 'semaphoreVersion', the protocol version
-- of this build, alongside the caller-supplied label.
--
-- On POSIX, crash recovery is automatic: disconnected clients' tokens
-- are returned to the pool.  On Windows, tokens held by a crashed
-- client are permanently lost.
createSemaphore :: String -- ^ label
                -> Int    -- ^ number of tokens on the semaphore
                -> IO (Either SemaphoreError ServerSemaphore)
createSemaphore label init_toks =
  create_sem versionedName init_toks
  where
    -- Pair the label with the protocol version this build speaks.
    versionedName = SemaphoreName
      { semaphoreProtocolVersion       = semaphoreVersion
      , unversionedSemaphoreNameString = label
      }

-- | Create a fresh semaphore with a unique name and the given token count.
--
-- Candidate names have the form @\<prefix\>_\<suffix\>@. The suffixes come
-- from 'seedStrings', applied to a seed obtained from 'getTimeSeed'; they
-- are consecutive values derived from that seed, not independently random.
--
-- If creating the semaphore fails, the next candidate name is tried. At
-- most 32 attempts are made in total. If all of them fail, the error from
-- the last attempt is returned.
freshSemaphore :: String -- ^ label prefix
               -> Int    -- ^ number of tokens on the semaphore
               -> IO (Either SemaphoreError ServerSemaphore)
freshSemaphore prefix init_toks = do
  seed <- getTimeSeed
  go 1 (seedStrings seed)
  where
    -- Total number of candidate names tried before giving up.
    maxAttempts :: Int
    maxAttempts = 32

    -- 'attempt' is the 1-based index of the attempt about to be made.
    -- The original code counted from 0 and retried while the counter was
    -- below 32, which made 33 attempts instead of the documented 32.
    go :: Int -> NonEmpty String -> IO (Either SemaphoreError ServerSemaphore)
    go attempt (suffix :| suffs) = do
      let sem_nm = SemaphoreName
            { semaphoreProtocolVersion       = semaphoreVersion
            , unversionedSemaphoreNameString = prefix ++ "_" ++ suffix
            }
      mb_sem <- create_sem sem_nm init_toks
      case mb_sem of
        Right sem -> return (Right sem)
        Left err
          -- NOTE(review): any creation error triggers a retry, not only a
          -- name collision; 'SemaphoreError' is opaque from here.
          | next : nexts <- suffs
          , attempt < maxAttempts
          -> go (attempt + 1) (next :| nexts)
          | otherwise
          -> return (Left err)

-- | Open a semaphore from its 'SemaphoreIdentifier'.
--
-- The identifier should normally begin with a version prefix @v\<N\>-@.
-- An identifier that 'parseSemaphoreIdentifier' does not recognise is
-- treated as an unversioned v1 name, for backwards compatibility.
--
-- Returns @Left SemaphoreIncompatibleVersion@ if the identifier's protocol
-- version is not compatible with 'semaphoreVersion', the version of this
-- build of @semaphore-compat@.
openSemaphore :: SemaphoreIdentifier -> IO (Either SemaphoreError ClientSemaphore)
openSemaphore ident
  | versionsAreCompatible requested semaphoreVersion = open_sem_raw resolved
  | otherwise =
      return (Left (SemaphoreIncompatibleVersion requested semaphoreVersion))
  where
    -- The name to open: the parsed one, or the whole identifier as a v1 name.
    resolved = case parseSemaphoreIdentifier ident of
      Just parsed -> parsed
      Nothing     -> SemaphoreName
        { semaphoreProtocolVersion       = legacyVersion
        , unversionedSemaphoreNameString = ident
        }

    -- Protocol version that the identifier asks for.
    requested = semaphoreProtocolVersion resolved

    -- Version assumed for identifiers without a version prefix.
    legacyVersion = SemaphoreProtocolVersion 1

-- | Run an action while holding one token of the semaphore.
--
-- The token is obtained with 'waitOnSemaphore' before the action runs. It
-- is handed back with 'releaseSemaphoreToken' afterwards, even if the
-- action throws, because the pair is wrapped in 'MC.bracket'.
withSemaphoreToken :: ClientSemaphore -> (SemaphoreToken -> IO a) -> IO a
withSemaphoreToken sem useToken =
  MC.bracket (waitOnSemaphore sem) releaseSemaphoreToken useToken

-- | Candidate name suffixes derived from a seed.
--
-- The result is the base-62 rendering ('iToBase62') of @seed@, @seed + 1@,
-- @seed + 2@, and so on.
seedStrings :: Int -> NonEmpty String
seedStrings seed = fmap (iToBase62 . (+ seed)) offsets
  where
    -- Offsets added to the seed, starting at zero.
    offsets = 0 :| [1 ..]

---------------------------------------
-- Abstract semaphores

-- | Abstraction over the operations of a semaphore.
--
-- It is a record of two 'IO' actions: one run to acquire the semaphore and
-- one run to release it.
data AbstractSem = AbstractSem
  { acquireSem :: IO ()  -- ^ action run to acquire the semaphore
  , releaseSem :: IO ()  -- ^ action run to release the semaphore
  }

-- | Run an action between 'acquireSem' and 'releaseSem'.
--
-- Because 'MC.bracket_' is used, the release action runs even if the body
-- throws. It does not run if acquiring itself fails.
withAbstractSem :: AbstractSem -> IO b -> IO b
withAbstractSem sem body = MC.bracket_ grab release body
  where
    grab    = acquireSem sem
    release = releaseSem sem