{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}

-- | 'Async', yet using 'MVar's.
--
-- Adapted from the @async@ library
-- Copyright (c) 2012, Simon Marlow
-- Licensed under BSD-3-Clause
--
-- @since 3.2.0.0
module Distribution.Compat.Async
  ( AsyncM
  , withAsync
  , waitCatch
  , wait
  , asyncThreadId
  , cancel
  , uninterruptibleCancel
  , AsyncCancelled (..)

    -- * Cabal extras
  , withAsyncNF
  ) where

import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, readMVar)
import Control.DeepSeq (NFData, force)
import Control.Exception
  ( BlockedIndefinitelyOnMVar (..)
  , Exception (..)
  , SomeException (..)
  , asyncExceptionFromException
  , asyncExceptionToException
  , catch
  , evaluate
  , mask
  , throwIO
  , throwTo
  , try
  , uninterruptibleMask_
  )
import Control.Monad (void)
import Data.Typeable (Typeable)
import GHC.Exts (inline)

-- | Handle to an action running in its own thread.
--
-- This plays the role of @Async@ from the @async@ package, but the result
-- is stored in an 'MVar' rather than a @TMVar@, because we don't depend
-- on @stm@.
data AsyncM a = Async
  { -- | Returns the 'ThreadId' of the thread running the given 'Async'.
    asyncThreadId :: {-# UNPACK #-} !ThreadId
  , -- | Filled by the child thread with the outcome of its action:
    -- @Left@ with the exception it threw, or @Right@ with its result.
    _asyncMVar :: MVar (Either SomeException a)
  }

-- | Run @action@ in a freshly forked thread and hand its 'AsyncM' handle
-- to @inner@. When @inner@ returns or throws, 'uninterruptibleCancel' is
-- called on the handle, so the child thread does not keep running after
-- this call has finished.
--
-- Conceptually (reading @async@ as "fork and return a handle"):
--
-- > withAsync action inner = mask $ \restore -> do
-- >   a <- async (restore action)
-- >   restore (inner a) `finally` uninterruptibleCancel a
--
-- Note: a reference to the child thread is kept alive until the call
-- to 'withAsync' returns, so nesting many 'withAsync' calls requires
-- linear memory.
withAsync :: IO a -> (AsyncM a -> IO b) -> IO b
withAsync action inner = inline withAsyncUsing forkIO action inner

-- | Like 'withAsync', but the child thread evaluates the action's result
-- to normal form (via 'force' and 'evaluate') before storing it. The
-- evaluation therefore runs in the child thread, and any exception it
-- raises is captured like any other exception thrown by the action.
withAsyncNF :: NFData a => IO a -> (AsyncM a -> IO b) -> IO b
withAsyncNF m = inline withAsyncUsing forkIO forcedAction
  where
    -- Run the user action, then deeply evaluate its result.
    forcedAction = do
      x <- m
      evaluate (force x)

-- | Shared implementation of 'withAsync' and 'withAsyncNF', parameterised
-- over the function used to fork the child thread.
--
-- The bracket version works, but is slow. We can do better by
-- hand-coding it.
withAsyncUsing :: (IO () -> IO ThreadId) -> IO a -> (AsyncM a -> IO b) -> IO b
withAsyncUsing doFork = \action inner -> do
  resultVar <- newEmptyMVar
  -- Fork while masked, so that no asynchronous exception can arrive
  -- between creating the child and installing the cleanup handler below.
  mask $ \unmask -> do
    -- The child runs the action with the caller's original masking state
    -- and always records the outcome, whether success or exception.
    tid <- doFork (try (unmask action) >>= putMVar resultVar)
    let asyncHandle = Async tid resultVar
        cleanup = uninterruptibleCancel asyncHandle
    -- If 'inner' throws, stop the child first, then re-throw.
    result <- unmask (inner asyncHandle) `catchAll` \ex -> do
      cleanup
      throwIO ex
    cleanup
    return result

-- | Block until the action behind the handle has finished, then return
-- its result. If the action threw an exception, 'wait' re-throws that
-- exception in the calling thread.
--
-- Implemented on top of 'waitCatch'.
{-# INLINE wait #-}
wait :: AsyncM a -> IO a
wait a = waitCatch a >>= rethrowOrReturn
  where
    rethrowOrReturn (Left (SomeException e)) = throwIO e
    rethrowOrReturn (Right x) = return x

-- | Block until the action behind the handle has finished. Returns
-- @Left e@ if the action raised the exception @e@, or @Right x@ if it
-- returned the value @x@.
--
-- Implemented as a 'readMVar' on the handle's result variable.
{-# INLINE waitCatch #-}
waitCatch :: AsyncM a -> IO (Either SomeException a)
waitCatch (Async _ resultVar) = readResult `catch` retryOnce
  where
    readResult = readMVar resultVar
    -- If the RTS reports this thread as blocked indefinitely on the MVar,
    -- attempt the read once more instead of propagating the exception.
    -- See: https://github.com/simonmar/async/issues/14
    retryOnce BlockedIndefinitelyOnMVar = readResult

-- | 'catch' specialised to 'SomeException'. The handler therefore sees
-- every exception raised by the action, synchronous or asynchronous.
catchAll :: IO a -> (SomeException -> IO a) -> IO a
catchAll action handler = action `catch` handler

-- | Stop the action behind the handle. This throws 'AsyncCancelled' to
-- its thread and then waits for that thread to finish. It has no effect
-- if the action has already completed.
--
-- > cancel a = throwTo (asyncThreadId a) AsyncCancelled <* waitCatch a
--
-- 'cancel' does not return until the target thread has terminated. It
-- therefore blocks for as long as that thread takes to act on the
-- asynchronous exception. For example, it can block if the thread is:
--
-- * executing a foreign call, and so cannot receive the asynchronous
--   exception; or
-- * running a cleanup handler after receiving the exception, and that
--   handler blocks.
{-# INLINE cancel #-}
cancel :: AsyncM a -> IO ()
cancel a@(Async tid _) = throwTo tid AsyncCancelled *> void (waitCatch a)

-- | The exception that 'cancel' throws to a thread to make it stop.
data AsyncCancelled = AsyncCancelled
  deriving (Eq, Show, Typeable)

instance Exception AsyncCancelled where
  -- Treated as an asynchronous exception: 'toException' wraps it in
  -- SomeAsyncException, and 'fromException' unwraps it again.
  -- See https://github.com/ghc/ghc/commit/756a970eacbb6a19230ee3ba57e24999e4157b09
  toException = asyncExceptionToException
  fromException = asyncExceptionFromException

-- | Variant of 'cancel' that runs under 'uninterruptibleMask_'. The
-- calling thread cannot be interrupted while it delivers
-- 'AsyncCancelled' and waits for the target thread to finish.
{-# INLINE uninterruptibleCancel #-}
uninterruptibleCancel :: AsyncM a -> IO ()
uninterruptibleCancel a = uninterruptibleMask_ (cancel a)