Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions Control/Monad/Writer/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import qualified Control.Monad.Trans.Accum as Accum
import qualified Control.Monad.Trans.RWS.CPS as CPSRWS
import qualified Control.Monad.Trans.Writer.CPS as CPS
import Control.Monad.Trans.Class (lift)
import Data.Tuple (swap)

-- ---------------------------------------------------------------------------
-- MonadWriter class
Expand All @@ -62,24 +63,38 @@ import Control.Monad.Trans.Class (lift)
-- the written object.

class (Monoid w, Monad m) => MonadWriter w m | m -> w where
{-# MINIMAL (writer | tell), listen, pass #-}
{-# MINIMAL listen, pass | overwrite #-}
-- | @'writer' (a,w)@ embeds a simple writer action.
writer :: (a,w) -> m a
writer ~(a, w) = do
tell w
return a

writer ~(a, w) = pass (pure (a, const w))
{-# INLINE writer #-}
-- | @'tell' w@ is an action that produces the output @w@.
tell :: w -> m ()
tell w = writer ((),w)

-- | @'listen' m@ is an action that executes the action @m@ and adds
-- its output to the value of the computation.
listen :: m a -> m (a, w)
listen = overwrite (\ ~(w, a) -> (w, (a, w)))
{-# INLINE listen #-}

-- | @'pass' m@ is an action that executes the action @m@, which
-- returns a value and a function, and returns the value, applying
-- the function to the output.
pass :: m (a, w -> w) -> m a
pass = overwrite (\ ~(w, ~(a, f)) -> (f w, a))
{-# INLINE pass #-}

-- | @'overwrite'@ uses a function to simultaneously modify the
-- output and the return value of a writer action.
overwrite :: ((w, a) -> (w, b)) -> m a -> m b
overwrite f =
pass . fmap f' . listen
where
f' ~(a, w) = (b, const w')
where
~(w', b) = f (w, a)
{-# INLINE overwrite #-}

-- | @'listens' f m@ is an action that executes the action @m@ and adds
-- the result of applying @f@ to the output to the value of the computation.
Expand All @@ -106,25 +121,29 @@ instance (Monoid w) => MonadWriter w ((,) w) where
tell w = (w, ())
listen ~(w, a) = (w, (a, w))
pass ~(w, (a, f)) = (f w, a)
overwrite f = f

-- | @since 2.3
instance (Monoid w, Monad m) => MonadWriter w (CPS.WriterT w m) where
writer = CPS.writer
tell = CPS.tell
listen = CPS.listen
pass = CPS.pass
overwrite f = CPS.mapWriterT (fmap (swap . f . swap))

instance (Monoid w, Monad m) => MonadWriter w (Lazy.WriterT w m) where
writer = Lazy.writer
tell = Lazy.tell
listen = Lazy.listen
pass = Lazy.pass
overwrite f = Lazy.mapWriterT (fmap (swap . f . swap))

instance (Monoid w, Monad m) => MonadWriter w (Strict.WriterT w m) where
writer = Strict.writer
tell = Strict.tell
listen = Strict.listen
pass = Strict.pass
overwrite f = Strict.mapWriterT (fmap (swap . f . swap))

-- | @since 2.3
instance (Monoid w, Monad m) => MonadWriter w (CPSRWS.RWST r w s m) where
Expand Down Expand Up @@ -163,6 +182,7 @@ instance MonadWriter w m => MonadWriter w (IdentityT m) where
tell = lift . tell
listen = Identity.mapIdentityT listen
pass = Identity.mapIdentityT pass
overwrite f = Identity.mapIdentityT (overwrite f)

instance MonadWriter w m => MonadWriter w (MaybeT m) where
writer = lift . writer
Expand All @@ -175,6 +195,7 @@ instance MonadWriter w m => MonadWriter w (ReaderT r m) where
tell = lift . tell
listen = mapReaderT listen
pass = mapReaderT pass
overwrite f = mapReaderT (overwrite f)

instance MonadWriter w m => MonadWriter w (Lazy.StateT s m) where
writer = lift . writer
Expand Down