{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module System.Semaphore.Internal.Posix
  ( ClientSemaphore(..), ServerSemaphore(..)
  , SemaphoreToken(..)
  , create_sem, open_sem_raw
  , waitOnSemaphore, tryWaitOnSemaphore
  , releaseSemaphoreToken
  , destroyClientSemaphore, destroyServerSemaphore
  , getSemaphoreValue
  , getTimeSeed
  ) where

-- base
import Control.Concurrent
  ( ThreadId, forkIOWithUnmask, killThread )
import Control.Concurrent.MVar
  ( MVar, mkWeakMVar, newEmptyMVar, newMVar, putMVar
  , readMVar, takeMVar, tryTakeMVar )
import Control.Exception ( IOException )
import Control.Monad
import Data.Word ( Word8 )
import Data.Bits ( xor )
import Foreign.C.Error ( Errno(Errno), eCONNABORTED )
import GHC.Clock ( getMonotonicTimeNSec )
import GHC.IO.Exception ( ioe_errno )
import GHC.Stack ( HasCallStack )
import System.IO.Error ( isFullError )

-- exceptions
import qualified Control.Monad.Catch as MC

-- stm
import Control.Concurrent.STM
  ( TVar, atomically, newTVarIO, readTVar, readTVarIO
  , modifyTVar', writeTVar, retry )

-- directory
import System.Directory ( doesPathExist )

-- unix
import System.Posix.IO ( closeFd, createPipe )
import System.Posix.Files ( removeLink )
import System.Posix.Types ( Fd )
import System.Posix.Process ( getProcessID )

import System.Semaphore.Internal.Common
import System.Semaphore.Internal.DomainSocket
  ( connectDomainSocket, listenDomainSocket
  , pollAcceptSocket, AcceptResult(..)
  , fdReadByte, fdWriteByte
  , fdShutdown )

-- | Identifies a semaphore from the client side: its name together with the
-- path of the Unix domain socket its server listens on.
--
-- No connection is stored here; every operation that needs to talk to the
-- server opens its own connection.
data ClientSemaphore = ClientSemaphore
  { clientSemaphoreName :: !SemaphoreName
    -- ^ The name the semaphore was created or opened with.
  , semSocketPath       :: !FilePath
    -- ^ Filesystem path of the server's listening socket.
  }

-- | A held semaphore token, bound to one acquired resource.
--
-- Every token owns its own connection to the server.  If all references to
-- a 'SemaphoreToken' are dropped without it being released, a finalizer
-- closes that connection and the server puts the token back into the pool.
-- Prefer 'releaseSemaphoreToken' or 'System.Semaphore.withSemaphoreToken'
-- for prompt release over depending on when the garbage collector runs.
--
-- The connection's 'Fd' lives in an 'MVar': whoever empties it owns the
-- descriptor, so releasing the same token a second time is a safe no-op.
newtype SemaphoreToken = SemaphoreToken
  { tokenFdLock :: MVar Fd
    -- ^ Full while the token's connection is open and not yet released.
  }

-- | A server-side semaphore.  It owns the server thread, the listening
-- socket and the token pool.
data ServerSemaphore = ServerSemaphore
  { serverClientSemaphore :: !ClientSemaphore
    -- ^ Name and socket path that clients use to reach this server.
  , serverThreadId        :: !ThreadId
    -- ^ Thread running the accept loop.
  , serverPool            :: !(TVar Int)
    -- ^ Number of tokens currently available.
  , serverState           :: !(MVar ServerState)
    -- ^ Emptied when the server's fds are released, which prevents them
    -- from being closed twice.
  }

-- | File descriptors owned by a running server.  Held in
-- 'serverState' until 'destroyServerSemaphore' takes and closes them.
data ServerState = ServerState
  { serverListenFd :: !Fd
    -- ^ Listening Unix domain socket.
  , serverCancelFd :: !Fd
    -- ^ Write end of the cancel pipe.  Writing a byte to it signals the
    -- server loop to leave its poll-accept.
  }

-- | Create a semaphore server named @sem_nm@ that starts with @init_toks@
-- tokens.
--
-- Listens on the name's Unix domain socket path and forks a thread running
-- 'serverLoop'.  Any 'IOException' raised during setup is returned as
-- 'SemaphoreOtherError'.  If setup fails after the socket was created, the
-- listening fd is closed and the socket file unlinked.
create_sem :: SemaphoreName -> Int -> IO (Either SemaphoreError ServerSemaphore)
create_sem sem_nm init_toks =
    either (Left . SemaphoreOtherError) Right
      <$> MC.try @_ @IOException (MC.mask_ setup)
  where
    setup :: IO ServerSemaphore
    setup = do
      socketPath <- getSemaphoreSocketPath sem_nm
      -- listenDomainSocket creates the socket file, so undo both on failure.
      listenFd <- listenDomainSocket socketPath
      let discardListener = do
            void $ MC.try @_ @IOException $ closeFd listenFd
            void $ MC.try @_ @IOException $ removeLink socketPath
      startServer socketPath listenFd `MC.onException` discardListener

    startServer :: FilePath -> Fd -> IO ServerSemaphore
    startServer socketPath listenFd = do
      pool <- newTVarIO init_toks
      (cancelRd, cancelWr) <- createPipe
      -- The loop itself runs unmasked; the server thread closes the read
      -- end of the cancel pipe once the loop is over.
      tid <- forkIOWithUnmask $ \unmask ->
        unmask (serverLoop pool listenFd cancelRd)
          `MC.finally` closeFd cancelRd
      stateVar <- newMVar ServerState
        { serverListenFd = listenFd
        , serverCancelFd = cancelWr
        }
      pure ServerSemaphore
        { serverClientSemaphore = ClientSemaphore
            { clientSemaphoreName = sem_nm
            , semSocketPath       = socketPath
            }
        , serverThreadId = tid
        , serverPool     = pool
        , serverState    = stateVar
        }

-- | Open an existing semaphore by name.
--
-- No connection is made here; the function only checks that the server's
-- socket file exists.  Returns 'SemaphoreDoesNotExist' when nothing is at
-- the socket path, and 'SemaphoreOtherError' for any 'IOException' raised
-- while resolving or checking it.
open_sem_raw :: SemaphoreName -> IO (Either SemaphoreError ClientSemaphore)
open_sem_raw nm = classify <$> MC.try @_ @IOException locate
  where
    -- Resolve the socket path and check whether anything exists there.
    locate :: IO (FilePath, Bool)
    locate = do
      path <- getSemaphoreSocketPath nm
      present <- doesPathExist path
      pure (path, present)

    classify :: Either IOException (FilePath, Bool)
             -> Either SemaphoreError ClientSemaphore
    classify (Left ioErr) = Left (SemaphoreOtherError ioErr)
    classify (Right (path, present))
      | present   = Right (ClientSemaphore { clientSemaphoreName = nm
                                           , semSocketPath       = path })
      | otherwise = Left (SemaphoreDoesNotExist (semaphoreIdentifier nm))

-- | Acquire a token from the semaphore, blocking until one is available.
--
-- The wait can be cancelled with 'Control.Concurrent.throwTo',
-- 'Control.Concurrent.killThread' and similar.  An interrupted call closes
-- its connection, so any token it may already have been granted goes back
-- to the pool.
--
-- For prompt and predictable release, use
-- 'System.Semaphore.withSemaphoreToken' or 'releaseSemaphoreToken'.
waitOnSemaphore :: HasCallStack => ClientSemaphore -> IO SemaphoreToken
waitOnSemaphore sem = do
  reply <- newEmptyMVar
  -- Stay masked everywhere except the interruptible takeMVar below.
  MC.mask_ $ do
    conn <- connectDomainSocket (semSocketPath sem)
    -- The blocking request/response exchange runs on a helper thread, which
    -- reports either the response byte or the exception it hit.
    worker <- forkIOWithUnmask $ \_ ->
      MC.try @_ @MC.SomeException (fdWriteByte conn CmdWait >> fdReadByte conn)
        >>= putMVar reply
    let closeConn = void $ MC.try @_ @IOException $ closeFd conn
        -- Uninterruptible: killThread is itself interruptible, and a second
        -- async exception between it and the close would leak the fd.
        abandon = MC.uninterruptibleMask_ $ do
          -- shutdown(SHUT_RDWR) makes the worker's read() return EOF at once.
          void $ MC.try @_ @IOException $ fdShutdown conn
          -- Reap the worker before closing, so the fd is never closed twice.
          killThread worker
          closeConn
    -- takeMVar is the single interruption point.  If an async exception
    -- lands here, 'abandon' unblocks the worker and releases the fd.
    outcome <- takeMVar reply `MC.onException` abandon
    case outcome of
      Left err -> do
        closeConn
        MC.throwM err
      Right RspOk -> mkToken conn
      Right other -> do
        closeConn
        fail $ "semaphore-compat: unexpected response in waitOnSemaphore: " ++ show other

-- | Try to acquire a token from the semaphore without blocking.
--
-- Returns @Just token@ if the server granted a token, or @Nothing@ if it
-- replied that none was available.  Like 'waitOnSemaphore', fails with an
-- 'IOError' if the server answers with a byte outside the protocol.  An
-- empty pool is not reported for a protocol error.
--
-- Not interruptible, but this shouldn't block for long as the server is
-- supposed to respond immediately.
tryWaitOnSemaphore :: HasCallStack => ClientSemaphore -> IO (Maybe SemaphoreToken)
tryWaitOnSemaphore sem = MC.mask_ $ do
    fd <- connectDomainSocket (semSocketPath sem)
    -- Close quietly, as everywhere else in this module.  An error from
    -- close() must not replace the exception that aborted the exchange.
    let closeQuietly = void $ MC.try @_ @IOException $ closeFd fd
    resp <- flip MC.onException closeQuietly $ do
      fdWriteByte fd CmdTryWait
      fdReadByte fd
    case resp of
      RspOk   -> Just <$> mkToken fd
      RspFail -> do
        closeQuietly
        return Nothing
      _       -> do
        closeQuietly
        fail $ "semaphore-compat: unexpected response in tryWaitOnSemaphore: " ++ show resp

-- | Wrap a connected fd that holds a token as a 'SemaphoreToken'.
--
-- Attaches a finalizer to the fd's 'MVar'.  If the token becomes
-- unreachable while the MVar is still full, the finalizer takes the fd and
-- closes it; the server treats the disconnect as returning the token.
mkToken :: Fd -> IO SemaphoreToken
mkToken fd = do
  lock <- newMVar fd
  -- An empty MVar means the token was already released; do nothing then.
  let finalize = tryTakeMVar lock >>= mapM_ closeQuietly
  _ <- mkWeakMVar lock finalize
  pure (SemaphoreToken lock)
  where
    closeQuietly :: Fd -> IO ()
    closeQuietly = void . MC.try @_ @IOException . closeFd

-- | Release a semaphore token, returning it to the pool.
--
-- Sends a release command on the token's own connection, then closes that
-- connection whether or not the exchange succeeded.  Idempotent: calling it
-- again on the same token does nothing.
--
-- Not interruptible; only returns when the release has succeeded.
releaseSemaphoreToken :: HasCallStack => SemaphoreToken -> IO ()
releaseSemaphoreToken (SemaphoreToken fdVar) = MC.mask_ $ do
    -- Emptying the MVar claims the fd; Nothing means already released.
    held <- tryTakeMVar fdVar
    forM_ held $ \fd -> do
      answer <- (fdWriteByte fd CmdRelease >> fdReadByte fd)
                  `MC.finally` (void $ MC.try @_ @IOException $ closeFd fd)
      -- RspFail: the server sees myCount <= 0 on this connection, i.e. the
      -- token is effectively already released; stay idempotent.
      unless (answer == RspOk || answer == RspFail) $
        fail "semaphore-compat: unexpected response in releaseSemaphoreToken"

-- | Destroy a client-side semaphore.
--
-- Nothing to do on POSIX: a 'ClientSemaphore' holds no live connection.
destroyClientSemaphore :: ClientSemaphore -> IO ()
destroyClientSemaphore = const (pure ())

-- | Destroy a server-side semaphore.
--
-- Idempotent: only the first call does any work; later calls are no-ops.
--
-- Runs under 'MC.uninterruptibleMask_', in this order:
--
-- 1. Signals the server loop through the cancel pipe.
-- 2. 'killThread's the server thread, which returns as soon as that thread
--    has received the exception (or immediately if it already finished).
-- 3. Closes the listening socket and unlinks its socket file.
--
-- 'killThread' does not wait for the server thread's own cleanup handlers,
-- such as shutting down and killing per-connection threads.  That cleanup
-- may therefore still be completing when this function returns.
destroyServerSemaphore :: ServerSemaphore -> IO ()
destroyServerSemaphore server = MC.uninterruptibleMask_ $ do
  -- We justify the uninterruptibleMask_ by analysis of the server, which
  -- must leave its accept loop in a bounded amount of time once it sees the
  -- cancel signal on serverCancelFd.  It has a bounded number of child
  -- threads to clean up, but otherwise exits promptly.
  --
  -- Without uninterruptibleMask_, an exception arriving while we are in
  -- killThread would leak serverListenFd and the socket path.
  mbState <- tryTakeMVar (serverState server)
  case mbState of
    Nothing -> return ()  -- already destroyed
    Just ServerState{..} -> do
      let path :: FilePath
          path = semSocketPath (serverClientSemaphore server)
          quietly :: IO () -> IO ()
          quietly = void . MC.try @_ @IOException
      -- Wake the server loop out of pollAcceptSocket.
      quietly $ fdWriteByte serverCancelFd 0
      quietly $ closeFd serverCancelFd
      -- Returns once ThreadKilled has been raised in the server thread; the
      -- thread's finally-handlers may still be running afterwards.
      killThread (serverThreadId server)
      quietly $ closeFd serverListenFd
      quietly $ removeLink path

-- | Query the current semaphore value (how many tokens it has available).
--
-- Meant mainly for debugging: the count can change right after it is read,
-- so program logic that depends on it easily introduces race conditions.
getSemaphoreValue :: ServerSemaphore -> IO Int
getSemaphoreValue = readTVarIO . serverPool

-- | Build a seed by XOR-ing the monotonic clock reading (in nanoseconds)
-- with the current process ID.
getTimeSeed :: IO Int
getTimeSeed = mix <$> getMonotonicTimeNSec <*> getProcessID
  where
    mix nanos pid = fromIntegral nanos `xor` fromIntegral pid

---------------------------------------
-- Server (Unix domain socket)
--
-- The server manages a shared token pool (TVar Int) and accepts multiple
-- client connections, each served on its own thread.
--
-- Protocol (SOCK_STREAM, one byte per command):
--
--   "-"  Wait (blocking acquire).  Decrements semaphore; replies ".".
--   "?"  Try-wait.  Decrements if positive and replies "."; otherwise replies "!".
--   "+"  Release.  Increments pool; replies ".".  Rejected with "!" if
--        myCount <= 0 (client has not acquired any tokens on this connection).
--
-- Unrecognised bytes are rejected with RspFail.
--
-- Per-connection token tracking: the server counts tokens held by each
-- connection (myCount).  Whenever a connection's thread exits (EOF or
-- another I/O error on disconnect, or being killed at shutdown), its held
-- tokens are returned to the pool, so a crashing client cannot leak tokens.
--
-- Connections are tied to SemaphoreToken: each token (obtained from
-- waitOnSemaphore or tryWaitOnSemaphore) has its own connection.
--
---------------------------------------
-- Protocol byte constants

-- Client-to-server command bytes (ASCII '-', '?' and '+').
pattern CmdWait, CmdTryWait, CmdRelease :: Word8
pattern CmdWait    = 45 -- 0x2D '-'
pattern CmdTryWait = 63 -- 0x3F '?'
pattern CmdRelease = 43 -- 0x2B '+'

-- Server-to-client response bytes (ASCII '.' and '!').
pattern RspOk, RspFail :: Word8
pattern RspOk   = 46 -- 0x2E '.'
pattern RspFail = 33 -- 0x21 '!'

-- | Bookkeeping for one per-connection server thread, kept so that the
-- server can clean it up when it shuts down.
data Child = Child
  { childThread :: !(MVar ThreadId)
    -- ^ Filled with the child's thread id right after it is forked.
  , childFdLock :: !(MVar Fd)
    -- ^ Full while the connection fd is still known to be valid and may be
    -- shut down or closed.  Whoever empties it owns the fd for that step.
  }

-- | Accept loop run by the server thread.
--
-- Accepts connections on @listenFd@ and serves each one on its own thread
-- (see 'serve').  The loop ends when 'pollAcceptSocket' reports a byte on
-- @cancelFd@, or when accepting fails with a non-transient error.  On the
-- way out, every per-connection thread is shut down and killed.
serverLoop :: TVar Int -> Fd -> Fd -> IO ()
serverLoop pool listenFd cancelFd = do
    children <- newTVarIO ([] :: [Child])
    -- Why killChildren runs uninterruptibly:
    --
    -- 'destroyServerSemaphore' writes the cancel byte and then kills this
    -- thread.  The kill can therefore arrive after 'loop' has already
    -- returned, while killChildren is blocked in killThread (an
    -- interruptible operation).  Under the ordinary mask of 'finally', that
    -- would abort killChildren part-way and leave the remaining children
    -- running with open connections.
    --
    -- Every step is bounded: fdShutdown makes a child's blocked read()
    -- return, and children waiting for a token in 'serve' do so unmasked.
    loop children `MC.finally` MC.uninterruptibleMask_ (killChildren children)
  where
    loop :: TVar [Child] -> IO ()
    loop children = do
      -- Masked, so an accepted fd is always registered as a child before an
      -- async exception can be delivered to this thread.
      continueLoop <- MC.mask_ $ do
        r <- acceptWithRetry
        case r of
          AcceptCancelled -> return False
          AcceptedFd clientFd -> do
            forkServeChild children clientFd
            return True
      when continueLoop $ loop children

    -- pollAcceptSocket only returns if the accept succeeded, or cancellation
    -- was signalled via the cancel pipe.
    acceptWithRetry :: IO AcceptResult
    acceptWithRetry = pollAcceptSocket listenFd cancelFd `MC.catch` handleIOError

    -- Retry accept on transient errors.
    handleIOError :: IOException -> IO AcceptResult
    handleIOError e
      -- 'isFullError' catches ResourceExhausted (EMFILE/ENFILE/ENOBUFS/ENOMEM,
      -- plus others accept doesn't produce but that are harmless to retry).
      | isFullError e = acceptWithRetry
      -- ECONNABORTED is also transient but is categorised as OtherError, so
      -- we additionally match it explicitly.
      | Just err <- ioe_errno e
      , Errno err == eCONNABORTED = acceptWithRetry
      -- EINTR is absorbed in hs_poll_accept.
      -- Everything else (EBADF, EINVAL, ENOTSOCK, ...) is rethrown.
      | otherwise = MC.throwM e

    -- Register a child for an accepted connection and fork its thread.
    -- Called masked (from 'loop'), so childThread is filled before any
    -- async exception can reach this thread.
    forkServeChild :: TVar [Child] -> Fd -> IO ()
    forkServeChild children clientFd = do
      fdLock <- newMVar clientFd
      tidVar <- newEmptyMVar
      let child = Child tidVar fdLock
      atomically $ modifyTVar' children (child :)
      childTid <- forkIOWithUnmask $ \unmask ->
        serve unmask pool children clientFd child
      putMVar tidVar childTid

    -- Interrupt all children blocked on a read from the FD, then kill them.
    killChildren :: TVar [Child] -> IO ()
    killChildren children = do
      kids <- readTVarIO children
      forM_ kids $ \child -> do
        -- If the child is in a read(), interrupt it by calling 'fdShutdown'.
        mb <- tryTakeMVar (childFdLock child)
        case mb of
          Just cfd -> do
            void $ MC.try @_ @IOException $ fdShutdown cfd
            putMVar (childFdLock child) cfd
          Nothing -> return ()  -- serve thread already closing
      -- No children blocked on read any more: terminate them all.
      -- childThread was filled by the parent before mask exit; readMVar of
      -- a definitely-full MVar is non-interruptible.
      forM_ kids $ \child -> do
        tid <- readMVar (childThread child)
        killThread tid

-- | Per-connection server loop: answers one-byte commands on @fd@ until the
-- client disconnects or the thread is killed.
--
-- @restore@ is the unmask function handed over by 'forkIOWithUnmask'.
-- Tokens acquired over this connection are counted in @myCount@.  When the
-- thread leaves for any reason, those tokens go back to @pool@, the child
-- is removed from @children@, and the fd is closed.
serve :: (forall a. IO a -> IO a)
      -> TVar Int -> TVar [Child] -> Fd -> Child
      -> IO ()
serve restore pool children fd (Child _ fdLock) = do
    myCount <- newTVarIO (0 :: Int)
    let -- Move one token from the pool (currently holding n) to this
        -- connection.
        acquireFrom n = do
          writeTVar pool (n - 1)
          modifyTVar' myCount (+ 1)

        -- Answer a single request.  Runs masked; 'restoreInner' re-enables
        -- async exceptions only while waiting for a token.
        handleRequest :: (forall a. IO a -> IO a) -> Word8 -> IO ()
        handleRequest restoreInner cmd = case cmd of
          CmdWait -> do
            -- Block until a token is available; the retry stays
            -- interruptible thanks to restoreInner.
            restoreInner $ atomically $ do
              n <- readTVar pool
              when (n <= 0) retry
              acquireFrom n
            fdWriteByte fd RspOk
          CmdTryWait -> do
            granted <- atomically $ do
              n <- readTVar pool
              if n > 0
                then acquireFrom n >> return True
                else return False
            fdWriteByte fd (if granted then RspOk else RspFail)
          CmdRelease -> do
            returned <- atomically $ do
              mc <- readTVar myCount
              if mc > 0
                then do
                  modifyTVar' pool (+ 1)
                  modifyTVar' myCount (subtract 1)
                  return True
                else return False
            fdWriteByte fd (if returned then RspOk else RspFail)
          -- Unknown command: reply so the client doesn't hang in read().
          _ -> fdWriteByte fd RspFail

        -- One request per iteration, each handled masked.  fdReadByte is a
        -- safe-FFI read(2); it is interrupted by fdShutdown from
        -- killChildren, not by throwTo.
        session :: IO ()
        session = forever $ MC.mask $ \restoreInner ->
          fdReadByte fd >>= handleRequest restoreInner

        -- Return held tokens and deregister, then take ownership of the fd
        -- (so killChildren can no longer touch it) and close it.
        finish :: IO ()
        finish = do
          atomically $ do
            stillHeld <- readTVar myCount
            when (stillHeld > 0) $ modifyTVar' pool (+ stillHeld)
            modifyTVar' children (filter ((/= fdLock) . childFdLock))
          void $ takeMVar fdLock
          void $ MC.try @_ @IOException $ closeFd fd

    -- restore, so the thread can be killed between iterations.  An
    -- IOException (EOF/disconnect) ends the session silently.
    (restore session `MC.catch` \(_ :: IOException) -> return ())
      `MC.finally` finish