diff --git a/Control/Monad/Writer/Class.hs b/Control/Monad/Writer/Class.hs index 11c156a..74e36aa 100644 --- a/Control/Monad/Writer/Class.hs +++ b/Control/Monad/Writer/Class.hs @@ -48,6 +48,7 @@ import qualified Control.Monad.Trans.Accum as Accum import qualified Control.Monad.Trans.RWS.CPS as CPSRWS import qualified Control.Monad.Trans.Writer.CPS as CPS import Control.Monad.Trans.Class (lift) +import Data.Tuple (swap) -- --------------------------------------------------------------------------- -- MonadWriter class @@ -62,13 +63,11 @@ import Control.Monad.Trans.Class (lift) -- the written object. class (Monoid w, Monad m) => MonadWriter w m | m -> w where - {-# MINIMAL (writer | tell), listen, pass #-} + {-# MINIMAL listen, pass | overwrite #-} -- | @'writer' (a,w)@ embeds a simple writer action. writer :: (a,w) -> m a - writer ~(a, w) = do - tell w - return a - + writer ~(a, w) = pass (pure (a, const w)) + {-# INLINE writer #-} -- | @'tell' w@ is an action that produces the output @w@. tell :: w -> m () tell w = writer ((),w) @@ -76,10 +75,26 @@ class (Monoid w, Monad m) => MonadWriter w m | m -> w where -- | @'listen' m@ is an action that executes the action @m@ and adds -- its output to the value of the computation. listen :: m a -> m (a, w) + listen = overwrite (\ ~(w, a) -> (w, (a, w))) + {-# INLINE listen #-} + -- | @'pass' m@ is an action that executes the action @m@, which -- returns a value and a function, and returns the value, applying -- the function to the output. pass :: m (a, w -> w) -> m a + pass = overwrite (\ ~(w, ~(a, f)) -> (f w, a)) + {-# INLINE pass #-} + + -- | @'overwrite'@ uses a function to simultaneously modify the + -- output and the return value of a writer action. + overwrite :: ((w, a) -> (w, b)) -> m a -> m b + overwrite f = + pass . fmap f' . listen + where + f' ~(a, w) = (b, const w') + where + ~(w', b) = f (w, a) + {-# INLINE overwrite #-} -- | @'listens' f m@ is an action that executes the action @m@ and adds -- the result of applying @f@ to the output to the value of the computation. @@ -106,6 +121,7 @@ instance (Monoid w) => MonadWriter w ((,) w) where tell w = (w, ()) listen ~(w, a) = (w, (a, w)) pass ~(w, (a, f)) = (f w, a) + overwrite f = f -- | @since 2.3 instance (Monoid w, Monad m) => MonadWriter w (CPS.WriterT w m) where @@ -113,18 +129,21 @@ instance (Monoid w, Monad m) => MonadWriter w (CPS.WriterT w m) where tell = CPS.tell listen = CPS.listen pass = CPS.pass + overwrite f = CPS.mapWriterT (fmap (swap . f . swap)) instance (Monoid w, Monad m) => MonadWriter w (Lazy.WriterT w m) where writer = Lazy.writer tell = Lazy.tell listen = Lazy.listen pass = Lazy.pass + overwrite f = Lazy.mapWriterT (fmap (swap . f . swap)) instance (Monoid w, Monad m) => MonadWriter w (Strict.WriterT w m) where writer = Strict.writer tell = Strict.tell listen = Strict.listen pass = Strict.pass + overwrite f = Strict.mapWriterT (fmap (swap . f . swap)) -- | @since 2.3 instance (Monoid w, Monad m) => MonadWriter w (CPSRWS.RWST r w s m) where @@ -163,6 +182,7 @@ instance MonadWriter w m => MonadWriter w (IdentityT m) where tell = lift . tell listen = Identity.mapIdentityT listen pass = Identity.mapIdentityT pass + overwrite f = Identity.mapIdentityT (overwrite f) instance MonadWriter w m => MonadWriter w (MaybeT m) where writer = lift . writer @@ -175,6 +195,7 @@ instance MonadWriter w m => MonadWriter w (ReaderT r m) where tell = lift . tell listen = mapReaderT listen pass = mapReaderT pass + overwrite f = mapReaderT (overwrite f) instance MonadWriter w m => MonadWriter w (Lazy.StateT s m) where writer = lift . writer