module Control.Monad.Trans.List (
ListT(..),
mapListT,
liftCallCC,
liftCatch,
) where
import Control.Applicative
import Control.Monad
import qualified Control.Monad.Fail as Fail
import Control.Monad.IO.Class
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse))

import Control.Monad.Trans.Class
-- | The list monad transformer: a computation in the inner monad @m@
-- that produces a whole list of results at once.
newtype ListT m a = ListT
  { runListT :: m [a]
    -- ^ Unwrap to the inner computation returning all results.
  }
-- | Transform the unwrapped computation of a 'ListT', possibly changing
-- both the inner monad and the element type.
mapListT :: (m [a] -> n [b]) -> ListT m a -> ListT n b
mapListT f = ListT . f . runListT
-- | Apply a function to every element of the result list, inside the
-- inner functor.
instance (Functor m) => Functor (ListT m) where
  fmap f (ListT mxs) = ListT (fmap (map f) mxs)
-- | Fold over every element of every result list held by the inner
-- 'Foldable'.
instance Foldable f => Foldable (ListT f) where
  foldMap f = foldMap (foldMap f) . runListT
-- | Traverse every element of every result list, then rewrap the
-- rebuilt structure in 'ListT'.
instance Traversable f => Traversable (ListT f) where
  traverse f = fmap ListT . traverse (traverse f) . runListT
-- | 'pure' yields a single result. '<*>' combines the two inner
-- computations with the inner 'Applicative' and the two result lists with
-- the list 'Applicative' (every function applied to every argument).
instance (Applicative m) => Applicative (ListT m) where
  pure x = ListT (pure [x])
  ListT mfs <*> ListT mxs = ListT (fmap (<*>) mfs <*> mxs)
-- | 'empty' yields no results. '<|>' combines both inner computations
-- with the inner 'Applicative' and appends their result lists (left first).
instance (Applicative m) => Alternative (ListT m) where
  empty = ListT (pure [])
  ListT mxs <|> ListT mys = ListT (fmap (++) mxs <*> mys)
-- | Bind runs the inner computation, passes each result to the
-- continuation in list order (sequenced with 'mapM'), and concatenates
-- the resulting lists.
--
-- NOTE: because the continuation's effects are sequenced once per element,
-- this instance satisfies the monad laws only when @m@ is commutative. The
-- transformers documentation for its 'ListT' states the same caveat.
instance (Monad m) => Monad (ListT m) where
  -- 'return' is left to its default ('pure'). A separate definition is
  -- non-canonical and is flagged by -Wnoncanonical-monad-instances.
  m >>= k = ListT $ do
    xs <- runListT m
    yss <- mapM (runListT . k) xs
    return (concat yss)

-- | A failed pattern match in @do@-notation produces no results.
--
-- 'fail' is no longer a 'Monad' method (MonadFail proposal, GHC >= 8.8),
-- so defining it inside the 'Monad' instance no longer compiles; it lives
-- in its own 'Fail.MonadFail' instance instead.
instance (Monad m) => Fail.MonadFail (ListT m) where
  fail _ = ListT $ return []
-- | 'mzero' has no results. 'mplus' runs the left computation and then the
-- right one in the inner monad, and appends their result lists.
instance (Monad m) => MonadPlus (ListT m) where
  mzero = ListT (return [])
  mplus m n = ListT (liftM2 (++) (runListT m) (runListT n))
-- | Run the inner action and wrap its single result as a one-element list.
instance MonadTrans ListT where
  lift = ListT . liftM (: [])
-- | Lift an 'IO' action into the inner monad with its 'MonadIO' instance,
-- then into 'ListT' with 'lift'.
instance (MonadIO m) => MonadIO (ListT m) where
  liftIO action = lift (liftIO action)
-- | Lift a @callCC@ operation on the inner monad to 'ListT'.
--
-- The escape continuation given to @f@ wraps its argument in a
-- one-element list before invoking the inner continuation.
liftCallCC
  :: ((([a] -> m [b]) -> m [a]) -> m [a])
  -> ((a -> ListT m b) -> ListT m a)
  -> ListT m a
liftCallCC callCC f =
  ListT (callCC (\exit -> runListT (f (ListT . exit . (: [])))))
-- | Lift a catch operation on the inner monad to 'ListT'.
--
-- The inner catch wraps the entire unwrapped computation @runListT m@.
-- The handler's 'ListT' result is unwrapped with 'runListT' before it is
-- handed back to the inner catch.
liftCatch
  :: (m [a] -> (e -> m [a]) -> m [a])
  -> ListT m a -> (e -> ListT m a) -> ListT m a
liftCatch catchE m h = ListT (catchE (runListT m) (runListT . h))