{-# LANGUAGE ForeignFunctionInterface #-}

module System.Semaphore.Internal.DomainSocket
  ( connectDomainSocket
  , listenDomainSocket
  , pollAcceptSocket, AcceptResult(..)
  , fdReadByte, fdWriteByte
  , fdShutdown
  ) where

-- base
import Data.Word ( Word8 )
import Foreign.C.Error
  ( eINTR, getErrno, throwErrno, throwErrnoIfMinus1Retry
  , throwErrnoIfMinus1Retry_ )
import Foreign.C.String ( CString, withCString )
import Foreign.C.Types ( CInt(..), CSize(..) )
import Foreign.Marshal.Alloc ( allocaBytes )
import Foreign.Ptr ( Ptr )
import Foreign.Storable ( peek, poke )
import GHC.IO.Exception ( IOErrorType(EOF), IOException(..) )
import GHC.Stack ( HasCallStack, callStack, prettyCallStack )
import System.IO.Error ( ioeSetFileName, modifyIOError )
import System.Posix.Types ( CSsize(..), Fd(..) )

foreign import ccall safe "hs_connect_domain_socket"
    c_connectDomainSocket :: CString -> IO CInt

-- | Connect to the Unix domain socket at the given path and return the
-- connected file descriptor.
--
-- The work is done by the C helper @hs_connect_domain_socket@; a result of
-- @-1@ is treated as failure and @errno@ is consulted.  The call is retried
-- on @EINTR@; any other failure is thrown as an 'IOError' located at
-- @connectDomainSocket@ whose file name is set to @path@, so the error
-- message says which socket could not be reached.
connectDomainSocket :: FilePath -> IO Fd
connectDomainSocket path =
    modifyIOError (`ioeSetFileName` path) $
        withCString path $ \cpath ->
            Fd <$> throwErrnoIfMinus1Retry "connectDomainSocket"
                       (c_connectDomainSocket cpath)

foreign import ccall safe "hs_listen_domain_socket"
    c_listenDomainSocket :: CString -> IO CInt

-- | Open a listening Unix domain socket at the given path, in non blocking
-- mode (O_NONBLOCK), and return its file descriptor.
--
-- The work is done by the C helper @hs_listen_domain_socket@; a result of
-- @-1@ is treated as failure and @errno@ is consulted.  The call is retried
-- on @EINTR@; any other failure is thrown as an 'IOError' located at
-- @listenDomainSocket@ whose file name is set to @path@.
listenDomainSocket :: FilePath -> IO Fd
listenDomainSocket path =
    modifyIOError (`ioeSetFileName` path) $
        withCString path $ \cpath ->
            Fd <$> throwErrnoIfMinus1Retry "listenDomainSocket"
                       (c_listenDomainSocket cpath)

-- libc read(2).  The C function returns ssize_t, so the result is bound as
-- 'CSsize' (not 'CInt') to match the C ABI on 64-bit platforms.
foreign import ccall safe "read"
    c_read :: CInt -> Ptr Word8 -> CSize -> IO CSsize

-- libc write(2).  Like read(2), it returns ssize_t.
foreign import ccall safe "write"
    c_write :: CInt -> Ptr Word8 -> CSize -> IO CSsize

-- | Read a single byte from a file descriptor.
--
-- The underlying @read(2)@ is retried when it fails with @EINTR@; any other
-- failure is thrown as an 'IOError' located at @fdReadByte(fd=N)@.
--
-- Throws an EOF 'IOError' if the peer has disconnected (i.e. @read@
-- returned 0).  That error carries the caller's call stack as its location
-- and @fd=N@ as its description.
fdReadByte :: HasCallStack => Fd -> IO Word8
fdReadByte (Fd fd) =
    allocaBytes 1 $ \buf -> do
        rc <- throwErrnoIfMinus1Retry ("fdReadByte(fd=" ++ show fd ++ ")") $
            c_read fd buf 1
        if rc == 0
          then ioError $ IOError Nothing EOF
                  (prettyCallStack callStack)
                  ("fd=" ++ show fd)
                  Nothing Nothing
          else peek buf

-- | Write a single byte to a file descriptor.
--
-- The underlying @write(2)@ is retried when it fails with @EINTR@; any other
-- failure is thrown as an 'IOError' located at @fdWriteByte(fd=N)@.  The
-- number of bytes written is not inspected.
fdWriteByte :: HasCallStack => Fd -> Word8 -> IO ()
fdWriteByte (Fd fd) byte =
    allocaBytes 1 $ \buf -> do
        poke buf byte
        throwErrnoIfMinus1Retry_ ("fdWriteByte(fd=" ++ show fd ++ ")") $
            c_write fd buf 1

foreign import ccall safe "shutdown"
    c_shutdown :: CInt -> CInt -> IO CInt

-- | Shut down a socket for both reading and writing (@SHUT_RDWR@).
--
-- Used to cancel threads blocked in @read()@ on the same descriptor.  After
-- the shutdown, a pending @read@ returns 0, so a concurrent 'fdReadByte'
-- throws its EOF 'IOError' instead of blocking.  The descriptor itself is
-- not closed.
--
-- Retried on @EINTR@; any other failure is thrown as an 'IOError' located
-- at @fdShutdown@.
fdShutdown :: Fd -> IO ()
fdShutdown (Fd fd) =
    throwErrnoIfMinus1Retry_ "fdShutdown" $ c_shutdown fd shutRdWr
  where
    -- Value of SHUT_RDWR from <sys/socket.h> (2 on common POSIX platforms).
    shutRdWr :: CInt
    shutRdWr = 2

-- | Result of 'pollAcceptSocket'.
--
-- Derives 'Eq' and 'Show' so results can be compared and logged.
data AcceptResult
  = AcceptedFd !Fd     -- ^ A client connected; here is the fd.
  | AcceptCancelled    -- ^ The cancel pipe was signalled.
  deriving (Eq, Show)

foreign import ccall safe "hs_poll_accept"
    c_pollAccept :: CInt -> CInt -> IO CInt

-- | Block until either a client connects or the cancel fd is written to.
--
-- Relies on cooperative cancellation implemented in hs_poll_accept using
-- @poll(2)@ + @accept(2)@ via safe FFI to avoid GHC #27110 and #27113.
--
-- Interpretation of the @hs_poll_accept@ return code:
--
--   * @-2@: the cancel fd was signalled, giving 'AcceptCancelled';
--   * any other negative value: failure with @errno@ set.  @EINTR@ is
--     retried; anything else is thrown as an 'IOError' located at
--     @pollAcceptSocket@.  A negative value is never wrapped as an fd;
--   * non-negative: the accepted client fd, giving 'AcceptedFd'.
--
-- Must be called from a masked context.  The caller is responsible for
-- installing an exception handler that closes all 3 fds (inputs and outputs).
pollAcceptSocket :: Fd -> Fd -> IO AcceptResult
pollAcceptSocket (Fd listenFd) (Fd cancelFd) = go
  where
    go = c_pollAccept listenFd cancelFd >>= classify

    classify r
      | r == cancelled = pure AcceptCancelled
      | r >= 0         = pure (AcceptedFd (Fd r))
      | otherwise      = do
          errno <- getErrno
          -- Matches the EINTR-retry policy of the other wrappers in this
          -- module; no connection has been consumed when the call failed.
          if errno == eINTR
            then go
            else throwErrno "pollAcceptSocket"

    -- Sentinel returned by hs_poll_accept when the cancel fd fired.
    cancelled :: CInt
    cancelled = -2