Skip to content

Commit 8d0912a

Browse files
cootdcoutts
andcommitted
Added MonadMVar
Including a default implementation using 'MonadSTM', which guarantees fairness. Co-authored-by: Duncan Coutts <[email protected]>
1 parent 57e888b commit 8d0912a

File tree

3 files changed

+396
-0
lines changed

3 files changed

+396
-0
lines changed

io-classes/io-classes.cabal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ library
3737
Control.Monad.Class.MonadAsync
3838
Control.Monad.Class.MonadEventlog
3939
Control.Monad.Class.MonadFork
40+
Control.Monad.Class.MonadMVar
4041
Control.Monad.Class.MonadSay
4142
Control.Monad.Class.MonadST
4243
Control.Monad.Class.MonadSTM
@@ -57,6 +58,7 @@ library
5758
build-depends: base >=4.9 && <4.17,
5859
async >=2.1,
5960
bytestring,
61+
deque,
6062
mtl >=2.2 && <2.3,
6163
stm >=2.5 && <2.6,
6264
time >=1.9.1 && <1.11
Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
{-# LANGUAGE DefaultSignatures #-}
2+
{-# LANGUAGE QuantifiedConstraints #-}
3+
{-# LANGUAGE TypeFamilies #-}
4+
{-# LANGUAGE TypeFamilyDependencies #-}
5+
6+
module Control.Monad.Class.MonadMVar
7+
( MonadMVar (..)
8+
, MVarDefault
9+
, newEmptyMVarDefault
10+
, newMVarDefault
11+
, putMVarDefault
12+
, takeMVarDefault
13+
, readMVarDefault
14+
, tryTakeMVarDefault
15+
, tryPutMVarDefault
16+
, isEmptyMVarDefault
17+
) where
18+
19+
import qualified Control.Concurrent.MVar as IO
20+
import Control.Exception (SomeAsyncException (..))
21+
import Control.Monad.Class.MonadSTM
22+
import Control.Monad.Class.MonadThrow
23+
24+
import Control.Monad.Reader (ReaderT (..))
25+
import Control.Monad.Trans (lift)
26+
27+
import Data.Kind (Type)
28+
import Deque.Strict (Deque)
29+
import qualified Deque.Strict as Deque
30+
31+
32+
class ( Monad m
33+
, forall a tvar. tvar ~ TVar m a => Eq tvar
34+
)
35+
=> MonadMVar m where
36+
{-# MINIMAL newEmptyMVar, takeMVar, putMVar, tryTakeMVar, tryPutMVar, isEmptyMVar #-}
37+
38+
type MVar m = (mvar :: Type -> Type) | mvar -> m
39+
40+
newEmptyMVar :: m (MVar m a)
41+
takeMVar :: MVar m a -> m a
42+
putMVar :: MVar m a -> a -> m ()
43+
tryTakeMVar :: MVar m a -> m (Maybe a)
44+
tryPutMVar :: MVar m a -> a -> m Bool
45+
isEmptyMVar :: MVar m a -> m Bool
46+
47+
-- methods with a default implementation
48+
newMVar :: a -> m (MVar m a)
49+
readMVar :: MVar m a -> m a
50+
swapMVar :: MVar m a -> a -> m a
51+
withMVar :: MVar m a -> (a -> m b) -> m b
52+
withMVarMasked :: MVar m a -> (a -> m b) -> m b
53+
modifyMVar_ :: MVar m a -> (a -> m a) -> m ()
54+
modifyMVar :: MVar m a -> (a -> m (a, b)) -> m b
55+
modifyMVarMasked_ :: MVar m a -> (a -> m a) -> m ()
56+
modifyMVarMasked :: MVar m a -> (a -> m (a,b)) -> m b
57+
58+
default newMVar :: a -> m (MVar m a)
59+
newMVar a = do
60+
v <- newEmptyMVar
61+
putMVar v a
62+
return v
63+
{-# INLINE newMVar #-}
64+
65+
default readMVar :: MVar m a -> m a
66+
readMVar v = do
67+
a <- takeMVar v
68+
putMVar v a
69+
return a
70+
{-# INLINE readMVar #-}
71+
72+
default swapMVar :: MonadMask m => MVar m a -> a -> m a
73+
swapMVar mvar new =
74+
mask_ $ do
75+
old <- takeMVar mvar
76+
putMVar mvar new
77+
return old
78+
{-# INLINE swapMVar #-}
79+
80+
default withMVar :: MonadMask m => MVar m a -> (a -> m b) -> m b
81+
withMVar m io =
82+
mask $ \restore -> do
83+
a <- takeMVar m
84+
b <- restore (io a) `onException` putMVar m a
85+
putMVar m a
86+
return b
87+
{-# INLINE withMVar #-}
88+
89+
default withMVarMasked :: MonadMask m => MVar m a -> (a -> m b) -> m b
90+
withMVarMasked m io =
91+
mask_ $ do
92+
a <- takeMVar m
93+
b <- io a `onException` putMVar m a
94+
putMVar m a
95+
return b
96+
{-# INLINE withMVarMasked #-}
97+
98+
default modifyMVar_ :: MonadMask m => MVar m a -> (a -> m a) -> m ()
99+
modifyMVar_ m io =
100+
mask $ \restore -> do
101+
a <- takeMVar m
102+
a' <- restore (io a) `onException` putMVar m a
103+
putMVar m a'
104+
{-# INLINE modifyMVar_ #-}
105+
106+
default modifyMVar :: (MonadMask m, MonadEvaluate m)
107+
=> MVar m a -> (a -> m (a,b)) -> m b
108+
modifyMVar m io =
109+
mask $ \restore -> do
110+
a <- takeMVar m
111+
(a',b) <- restore (io a >>= evaluate) `onException` putMVar m a
112+
putMVar m a'
113+
return b
114+
{-# INLINE modifyMVar #-}
115+
116+
default modifyMVarMasked_ :: MonadMask m => MVar m a -> (a -> m a) -> m ()
117+
modifyMVarMasked_ m io =
118+
mask_ $ do
119+
a <- takeMVar m
120+
a' <- io a `onException` putMVar m a
121+
putMVar m a'
122+
{-# INLINE modifyMVarMasked_ #-}
123+
124+
default modifyMVarMasked :: (MonadMask m, MonadEvaluate m)
125+
=> MVar m a -> (a -> m (a,b)) -> m b
126+
modifyMVarMasked m io =
127+
mask_ $ do
128+
a <- takeMVar m
129+
(a',b) <- (io a >>= evaluate) `onException` putMVar m a
130+
putMVar m a'
131+
return b
132+
{-# INLINE modifyMVarMasked #-}
133+
134+
135+
instance MonadMVar IO where
136+
type MVar IO = IO.MVar
137+
newEmptyMVar = IO.newEmptyMVar
138+
newMVar = IO.newMVar
139+
takeMVar = IO.takeMVar
140+
putMVar = IO.putMVar
141+
readMVar = IO.readMVar
142+
swapMVar = IO.swapMVar
143+
tryTakeMVar = IO.tryTakeMVar
144+
tryPutMVar = IO.tryPutMVar
145+
isEmptyMVar = IO.isEmptyMVar
146+
withMVar = IO.withMVar
147+
withMVarMasked = IO.withMVarMasked
148+
modifyMVar_ = IO.modifyMVar_
149+
modifyMVar = IO.modifyMVar
150+
modifyMVarMasked_ = IO.modifyMVarMasked_
151+
modifyMVarMasked = IO.modifyMVarMasked
152+
153+
154+
data MVarState m a = MVarEmpty !(Deque (TVar m (Maybe a))) -- ^ blocked on take
155+
| MVarFull a !(Deque (a, TVar m Bool)) -- ^ blocked on put
156+
157+
-- | A default 'MVar' implementation based on `TVar`'s. An 'MVar' provides
158+
-- fairness guarantees.
159+
--
160+
newtype MVarDefault m a = MVar (TVar m (MVarState m a))
161+
162+
163+
newEmptyMVarDefault :: MonadSTM m => m (MVarDefault m a)
164+
newEmptyMVarDefault = MVar <$> newTVarIO (MVarEmpty mempty)
165+
166+
167+
newMVarDefault :: MonadSTM m => a -> m (MVarDefault m a)
168+
newMVarDefault a = MVar <$> newTVarIO (MVarFull a mempty)
169+
170+
171+
putMVarDefault :: ( MonadCatch m
172+
, MonadMask m
173+
, MonadSTM m
174+
, forall x tvar. tvar ~ TVar m x => Eq tvar
175+
)
176+
=> MVarDefault m a -> a -> m ()
177+
putMVarDefault (MVar tv) x = mask_ $ do
178+
res <- atomically $ do
179+
s <- readTVar tv
180+
case s of
181+
-- if it's full we add ourselves to the blocked queue
182+
MVarFull x' blockedq -> do
183+
wakevar <- newTVar False
184+
writeTVar tv (MVarFull x (Deque.snoc (x', wakevar) blockedq))
185+
return (Just wakevar)
186+
187+
-- if it's empty we fill in the value, and also complete the action of
188+
-- the next thread blocked in takeMVar
189+
MVarEmpty blockedq ->
190+
case Deque.uncons blockedq of
191+
Nothing -> do
192+
writeTVar tv (MVarFull x mempty)
193+
return Nothing
194+
195+
Just (wakevar, blockedq') -> do
196+
writeTVar wakevar (Just x)
197+
writeTVar tv (MVarEmpty blockedq')
198+
return Nothing
199+
200+
case res of
201+
-- we have to block on our own wakevar until we can complete the put
202+
Just wakevar ->
203+
atomically (readTVar wakevar >>= check)
204+
`catch` \e@SomeAsyncException {} -> do
205+
atomically $ do
206+
s <- readTVar tv
207+
case s of
208+
MVarFull x' blockedq -> do
209+
-- async exception was thrown while we were blocked on wakevar;
210+
-- we need to remove it from the queue, otherwise we will have
211+
-- a space leak.
212+
let blockedq' = Deque.filter ((/= wakevar) . snd) blockedq
213+
writeTVar tv (MVarFull x' blockedq')
214+
-- the exception was thrown when we were blocked on 'waketvar', so
215+
-- the 'MVar' must not be empty.
216+
MVarEmpty {} -> error "putMVarDefault: invariant violation"
217+
throwIO e
218+
219+
-- we managed to do the put synchronously
220+
Nothing -> return ()
221+
222+
223+
takeMVarDefault :: ( MonadMask m
224+
, MonadSTM m
225+
, forall x tvar. tvar ~ TVar m x => Eq tvar
226+
)
227+
=> MVarDefault m a
228+
-> m a
229+
takeMVarDefault (MVar tv) = mask_ $ do
230+
res <- atomically $ do
231+
s <- readTVar tv
232+
case s of
233+
-- if it's empty we add ourselves to the blocked queue
234+
MVarEmpty blockedq -> do
235+
wakevar <- newTVar Nothing
236+
writeTVar tv (MVarEmpty (Deque.snoc wakevar blockedq))
237+
return (Left wakevar)
238+
239+
-- if it's full we grab the value, and also complete the action of the
240+
-- next thread blocked in putMVar, by setting the new MVar value and
241+
-- unblocking them.
242+
MVarFull x blockedq ->
243+
case Deque.uncons blockedq of
244+
Nothing ->
245+
return (Right x)
246+
247+
Just ((x', wakevar), blockedq') -> do
248+
writeTVar wakevar True
249+
writeTVar tv (MVarFull x' blockedq')
250+
return (Right x)
251+
252+
case res of
253+
-- we have to block on our own wakevar until we can complete the read
254+
Left wakevar ->
255+
atomically (readTVar wakevar >>= maybe retry return)
256+
`catch` \e@SomeAsyncException {} -> do
257+
atomically $ do
258+
s <- readTVar tv
259+
case s of
260+
MVarEmpty blockedq -> do
261+
-- async exception was thrown while were were blocked on
262+
-- wakevar; we need to remove it from 'blockedq', otherwise we
263+
-- will have a space leak.
264+
let blockedq' = Deque.filter (/= wakevar) blockedq
265+
writeTVar tv (MVarEmpty blockedq')
266+
-- the exception was thrown while we were blocked on 'wakevar', so
267+
-- the 'MVar' must not be full.
268+
MVarFull {} -> error "takeMVarDefault: invariant violation"
269+
throwIO e
270+
271+
-- we managed to do the take synchronously
272+
Right x -> return x
273+
274+
275+
-- | 'readMVarDefault' when the 'MVar' is empty, guarantees to receive next
276+
-- 'putMVar' value. It will also not block if the 'MVar' is full, even if there
277+
-- are other threads attempting to 'putMVar'.
278+
--
279+
readMVarDefault :: MonadSTM m
280+
=> MVarDefault m a
281+
-> m a
282+
readMVarDefault (MVar tv) = do
283+
atomically $ do
284+
s <- readTVar tv
285+
case s of
286+
-- if it's empty block
287+
MVarEmpty _ -> retry
288+
289+
-- if it's full return the value
290+
MVarFull x _ -> return x
291+
292+
293+
tryTakeMVarDefault :: MonadSTM m
294+
=> MVarDefault m a
295+
-> m (Maybe a)
296+
tryTakeMVarDefault (MVar tv) = do
297+
atomically $ do
298+
s <- readTVar tv
299+
case s of
300+
MVarEmpty _ -> return Nothing
301+
MVarFull x blockedq ->
302+
case Deque.uncons blockedq of
303+
Nothing -> return (Just x)
304+
Just ((x', wakevar), blockedq') -> do
305+
writeTVar wakevar True
306+
writeTVar tv (MVarFull x' blockedq')
307+
return (Just x)
308+
309+
310+
tryPutMVarDefault :: MonadSTM m
311+
=> MVarDefault m a -> a -> m Bool
312+
tryPutMVarDefault (MVar tv) x =
313+
atomically $ do
314+
s <- readTVar tv
315+
case s of
316+
MVarFull {} -> return False
317+
318+
MVarEmpty blockedq ->
319+
case Deque.uncons blockedq of
320+
Nothing -> do
321+
writeTVar tv (MVarFull x mempty)
322+
return True
323+
324+
Just (wakevar, blockedq') -> do
325+
writeTVar wakevar (Just x)
326+
writeTVar tv (MVarEmpty blockedq')
327+
return True
328+
329+
330+
isEmptyMVarDefault :: MonadSTM m
331+
=> MVarDefault m a -> m Bool
332+
isEmptyMVarDefault (MVar tv) =
333+
atomically $ do
334+
s <- readTVar tv
335+
case s of
336+
MVarFull {} -> return False
337+
MVarEmpty blockedq | null blockedq -> return True
338+
| otherwise -> error "isEmptyMVarDefault: invariant violation"
339+
340+
341+
--
342+
-- ReaderT instance
343+
--
344+
345+
newtype WrappedMVar r (m :: Type -> Type) a = WrappedMVar { unwrapMVar :: MVar m a }
346+
347+
instance ( MonadMask m
348+
, MonadMVar m
349+
, MonadEvaluate m
350+
) => MonadMVar (ReaderT r m) where
351+
type MVar (ReaderT r m) = WrappedMVar r m
352+
newEmptyMVar = WrappedMVar <$> lift newEmptyMVar
353+
newMVar = fmap WrappedMVar . lift . newMVar
354+
takeMVar = lift . takeMVar . unwrapMVar
355+
putMVar = lift .: (putMVar . unwrapMVar)
356+
readMVar = lift . readMVar . unwrapMVar
357+
swapMVar = lift .: (swapMVar . unwrapMVar)
358+
tryTakeMVar = lift . tryTakeMVar . unwrapMVar
359+
tryPutMVar = lift .: (tryPutMVar . unwrapMVar)
360+
isEmptyMVar = lift . isEmptyMVar . unwrapMVar
361+
withMVar (WrappedMVar v) f = ReaderT $ \r ->
362+
withMVar v (\a -> runReaderT (f a) r)
363+
withMVarMasked (WrappedMVar v) f = ReaderT $ \r ->
364+
withMVarMasked v (\a -> runReaderT (f a) r)
365+
modifyMVar_ (WrappedMVar v) f = ReaderT $ \r ->
366+
modifyMVar_ v (\a -> runReaderT (f a) r)
367+
modifyMVar (WrappedMVar v) f = ReaderT $ \r ->
368+
modifyMVar v (\a -> runReaderT (f a) r)
369+
modifyMVarMasked_ (WrappedMVar v) f = ReaderT $ \r ->
370+
modifyMVarMasked_ v (\a -> runReaderT (f a) r)
371+
modifyMVarMasked (WrappedMVar v) f = ReaderT $ \r ->
372+
modifyMVarMasked v (\a -> runReaderT (f a) r)
373+
374+
375+
376+
377+
--
378+
-- Utilities
379+
--
380+
381+
(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
382+
(f .: g) x y = f (g x y)

0 commit comments

Comments
 (0)