-- | Declares the 'MonadFix' class (exported together with its 'mfix'
-- method) and the instances defined in this module, and re-exports 'fix',
-- which this module imports from "Data.Function".
module Control.Monad.Fix (
MonadFix(mfix),
fix
) where
import Data.Either
import Data.Function ( fix )
import Data.Maybe
import Data.Monoid ( Dual(..), Sum(..), Product(..)
, First(..), Last(..), Alt(..), Ap(..) )
import Data.Ord ( Down(..) )
import GHC.Base ( Monad, NonEmpty(..), errorWithoutStackTrace, (.) )
import GHC.Generics
import GHC.List ( head, tail )
import Control.Monad.ST.Imp
import System.IO
-- | Monads that support a monadic fixed-point operator.
class Monad m => MonadFix m where
    -- | Monadic fixed point: takes a function from a value to an action
    -- yielding a value of the same type, and produces a single action.
    -- In the instances of this module the value produced by that action
    -- is what gets fed back (lazily) as the function's argument.
    mfix :: (a -> m a) -> m a
-- | Fixed point for 'Maybe': the contents of the eventual result are fed
-- back to @f@ lazily.
instance MonadFix Maybe where
    mfix f = fix (f . payload)
      where
        -- Lazily projects the value out of the result. The error is only
        -- raised if this projection is actually forced on a 'Nothing'.
        payload = fromMaybe (errorWithoutStackTrace "mfix Maybe: Nothing")
-- | Fixed point for lists, built one element at a time: the first element
-- comes from tying the knot through the head of @f@'s result, and the
-- remaining elements come from recursing on the tail of @f@'s result.
instance MonadFix [] where
    mfix f =
        case knot of
            first : _ -> first : mfix (\x -> tail (f x))
            []        -> []
      where
        -- Fixed point of @f@ applied to the head of its own result.
        knot = fix (\x -> f (head x))
-- | Fixed point for 'NonEmpty'. The result is always built with ':|'
-- without first inspecting @f@'s output: the head is obtained by tying the
-- knot through the first element, and the tail is the list-level 'mfix' of
-- the tail of @f@'s result.
instance MonadFix NonEmpty where
    mfix f = firstElem :| mfix (restOf . f)
      where
        -- Lazily bound, so producing the outer ':|' never forces the knot.
        firstElem = firstOf (fix (f . firstOf))
        -- Both projections use irrefutable patterns so that they do not
        -- force their argument until the projected component is demanded.
        firstOf ~(h :| _) = h
        restOf ~(_ :| t) = t
-- | Fixed point for 'IO', delegated entirely to 'fixIO'.
instance MonadFix IO where
    mfix f = fixIO f
-- | Fixed point for functions from a fixed argument type: for each
-- argument @env@, the result is the ordinary fixed point of
-- @\\result -> f result env@.
instance MonadFix ((->) r) where
    mfix f = \env -> fix (\result -> f result env)
-- | Fixed point for 'Either': the contents of the eventual 'Right' are fed
-- back to @f@ lazily.
instance MonadFix (Either e) where
    mfix f = fix (f . payload)
      where
        -- Lazily projects the value out of the result. The error is only
        -- raised if this projection is actually forced on a 'Left'.
        payload result = case result of
            Right value -> value
            Left _      -> errorWithoutStackTrace "mfix Either: Left"
-- | Fixed point for 'ST', delegated entirely to 'fixST'.
instance MonadFix (ST s) where
    mfix f = fixST f
-- | Fixed point for 'Dual': unwrap @f@'s result, take the ordinary fixed
-- point of the unwrapped function, and wrap it again.
instance MonadFix Dual where
    mfix f = Dual knot
      where
        -- Recursive binding: the unwrapped result is its own input.
        knot = getDual (f knot)
-- | Fixed point for 'Sum': unwrap @f@'s result, take the ordinary fixed
-- point of the unwrapped function, and wrap it again.
instance MonadFix Sum where
    mfix f = Sum knot
      where
        -- Recursive binding: the unwrapped result is its own input.
        knot = getSum (f knot)
-- | Fixed point for 'Product': unwrap @f@'s result, take the ordinary
-- fixed point of the unwrapped function, and wrap it again.
instance MonadFix Product where
    mfix f = Product knot
      where
        -- Recursive binding: the unwrapped result is its own input.
        knot = getProduct (f knot)
-- | Fixed point for 'First': defers to the 'mfix' of the type wrapped by
-- 'First', unwrapping @f@'s result with 'getFirst' and re-wrapping the
-- outcome.
instance MonadFix First where
    mfix f = First (mfix (\x -> getFirst (f x)))
-- | Fixed point for 'Last': defers to the 'mfix' of the type wrapped by
-- 'Last', unwrapping @f@'s result with 'getLast' and re-wrapping the
-- outcome.
instance MonadFix Last where
    mfix f = Last (mfix (\x -> getLast (f x)))
-- | Fixed point for 'Alt': defers to the 'mfix' of the wrapped monad,
-- unwrapping @f@'s result with 'getAlt' and re-wrapping the outcome.
instance MonadFix f => MonadFix (Alt f) where
    mfix f = Alt (mfix (\x -> getAlt (f x)))
-- | Fixed point for 'Ap': defers to the 'mfix' of the wrapped monad,
-- unwrapping @f@'s result with 'getAp' and re-wrapping the outcome.
instance MonadFix f => MonadFix (Ap f) where
    mfix f = Ap (mfix (\x -> getAp (f x)))
-- | Fixed point for 'Par1': unwrap @f@'s result, take the ordinary fixed
-- point of the unwrapped function, and wrap it again.
instance MonadFix Par1 where
    mfix f = Par1 knot
      where
        -- Recursive binding: the unwrapped result is its own input.
        knot = unPar1 (f knot)
-- | Fixed point for 'Rec1': defers to the 'mfix' of the wrapped monad,
-- unwrapping @f@'s result with 'unRec1' and re-wrapping the outcome.
instance MonadFix f => MonadFix (Rec1 f) where
    mfix f = Rec1 (mfix (\x -> unRec1 (f x)))
-- | Fixed point for 'M1': defers to the 'mfix' of the wrapped monad,
-- unwrapping @f@'s result with 'unM1' and re-wrapping the outcome.
instance MonadFix f => MonadFix (M1 i c f) where
    mfix f = M1 (mfix (\x -> unM1 (f x)))
-- | Fixed point for the product of two monads: each component is computed
-- independently by the 'mfix' of its own monad, applied to the matching
-- component of @f@'s result.
instance (MonadFix f, MonadFix g) => MonadFix (f :*: g) where
    mfix f = leftPart :*: rightPart
      where
        -- Each projection pattern-matches on the pair returned by @f@.
        leftPart  = mfix (\x -> case f x of l :*: _ -> l)
        rightPart = mfix (\x -> case f x of _ :*: r -> r)
-- | Fixed point for 'Down': unwrap @f@'s result by pattern matching, take
-- the ordinary fixed point of the unwrapped function, and wrap it again.
instance MonadFix Down where
    mfix f = Down (fix (\x -> case f x of Down inner -> inner))