diff --git a/io-classes/src/Control/Monad/Class/MonadAsync.hs b/io-classes/src/Control/Monad/Class/MonadAsync.hs index 9189925b..0abb5507 100644 --- a/io-classes/src/Control/Monad/Class/MonadAsync.hs +++ b/io-classes/src/Control/Monad/Class/MonadAsync.hs @@ -7,8 +7,6 @@ {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilyDependencies #-} -- MonadAsync's ReaderT instance is undecidable. {-# LANGUAGE UndecidableInstances #-} diff --git a/io-classes/src/Control/Monad/Class/MonadSTM.hs b/io-classes/src/Control/Monad/Class/MonadSTM.hs index 34e89788..6ab900a2 100644 --- a/io-classes/src/Control/Monad/Class/MonadSTM.hs +++ b/io-classes/src/Control/Monad/Class/MonadSTM.hs @@ -1,5 +1,15 @@ -- | This module corresponds to `Control.Monad.STM` in "stm" package -- +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +-- undecidable instances needed for 'WrappedSTM' instances of 'MonadThrow' and +-- 'MonadCatch' type classes. +{-# LANGUAGE UndecidableInstances #-} module Control.Monad.Class.MonadSTM ( MonadSTM (STM, atomically, retry, orElse, check) , throwSTM diff --git a/io-classes/src/Control/Monad/Class/MonadTimer.hs b/io-classes/src/Control/Monad/Class/MonadTimer.hs index 54a6c868..2a9bca44 100644 --- a/io-classes/src/Control/Monad/Class/MonadTimer.hs +++ b/io-classes/src/Control/Monad/Class/MonadTimer.hs @@ -54,6 +54,9 @@ class Monad m => MonadDelay m where threadDelay d = void . atomically . awaitTimeout =<< newTimeout d class (MonadSTM m, MonadDelay m) => MonadTimer m where + -- | The type of the timeout handle, used with 'newTimeout', 'readTimeout', + -- 'updateTimeout' and 'cancelTimeout'. + -- data Timeout m :: Type -- | Create a new timeout which will fire at the given time duration in diff --git a/io-sim/bench/Main.hs b/io-sim/bench/Main.hs index dd6a6948..823c7076 100644 --- a/io-sim/bench/Main.hs +++ b/io-sim/bench/Main.hs @@ -4,7 +4,7 @@ module Main (main) where import Control.Concurrent.Class.MonadSTM -import Control.Monad (replicateM) +import Control.Monad (replicateM, forever) import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadFork import Control.Monad.Class.MonadSay @@ -67,6 +67,9 @@ prop_timeout_fail = timeout 1 (threadDelay 2) prop_timeout_succeed :: forall m. MonadTimer m => m (Maybe ()) prop_timeout_succeed = timeout 2 (threadDelay 1) +prop_timeout_race :: forall m. MonadTimer m => m (Maybe ()) +prop_timeout_race = timeout 1 (threadDelay 1) + -- -- threads, async @@ -88,6 +91,13 @@ prop_async n = do ) traverse_ wait threads +prop_threadDelay_bottleneck :: forall m. (MonadTimer m, MonadSay m) + => m (Maybe ()) +prop_threadDelay_bottleneck = + timeout 1000000 $ do + forever $ do + threadDelay 1 + say "" main :: IO () main = defaultMain @@ -117,6 +127,8 @@ main = defaultMain whnf id (runSimOrThrow prop_timeout_fail) , bench "succeed" $ whnf id (runSimOrThrow prop_timeout_succeed) + , bench "race" $ + whnf id (runSimOrThrow prop_timeout_race) ] ] , @@ -127,6 +139,8 @@ main = defaultMain whnf id (runSimOrThrow (prop_async n)) , bench "forkIO silent" $ whnf id (runSimOrThrow (prop_threads n)) + , bench "threadDelay bottleneck silent" $ + whnf id (runSimOrThrow prop_threadDelay_bottleneck) , bench "async say" $ nf id ( selectTraceEventsSay $ runSimTrace @@ -135,6 +149,10 @@ main = defaultMain nf id ( selectTraceEventsSay $ runSimTrace $ prop_threads n) + , bench "threadDelay bottleneck say" $ + nf id ( selectTraceEventsSay + $ runSimTrace + $ prop_threadDelay_bottleneck) ] , env (pure 250) $ \n -> bgroup "250" diff --git a/io-sim/src/Control/Monad/IOSim/CommonTypes.hs b/io-sim/src/Control/Monad/IOSim/CommonTypes.hs index c65f4aa3..093726d2 100644 --- a/io-sim/src/Control/Monad/IOSim/CommonTypes.hs +++ b/io-sim/src/Control/Monad/IOSim/CommonTypes.hs @@ -1,7 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -- | Common types shared between `IOSim` and `IOSimPOR`. -- diff --git a/io-sim/src/Control/Monad/IOSim/Internal.hs b/io-sim/src/Control/Monad/IOSim/Internal.hs index 9dc62d97..eb14197b 100644 --- a/io-sim/src/Control/Monad/IOSim/Internal.hs +++ b/io-sim/src/Control/Monad/IOSim/Internal.hs @@ -1,16 +1,11 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTSyntax #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} @@ -54,7 +49,9 @@ module Control.Monad.IOSim.Internal import Prelude hiding (read) import Data.Dynamic -import Data.Foldable (toList, traverse_) +import Data.Foldable (toList, traverse_, foldlM) +import Deque.Strict (Deque) +import qualified Deque.Strict as Deque import qualified Data.List as List import qualified Data.List.Trace as Trace import Data.Map.Strict (Map) @@ -65,16 +62,13 @@ import qualified Data.OrdPSQ as PSQ import Data.Set (Set) import qualified Data.Set as Set import Data.Time (UTCTime (..), fromGregorian) -import Deque.Strict (Deque) -import qualified Deque.Strict as Deque import GHC.Exts (fromList) import GHC.Conc (ThreadStatus(..), BlockReason(..)) -import Control.Exception (NonTermination (..), assert, throw) -import Control.Monad (join) - -import Control.Monad (when) +import Control.Exception + (NonTermination (..), assert, throw, AsyncException (..)) +import Control.Monad (join, when) import Control.Monad.ST.Lazy import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST) import Data.STRef.Lazy @@ -119,10 +113,14 @@ labelledThreads threadMap = [] threadMap --- | Timers mutable variables. First one supports 'newTimeout' api, the second --- one 'registerDelay'. +-- | Timers mutable variables. Supports 'newTimeout' api, the second +-- one 'registerDelay', the third one 'threadDelay'. -- -data TimerVars s = TimerVars !(TVar s TimeoutState) !(TVar s Bool) +data TimerCompletionInfo s = + Timer !(TVar s TimeoutState) + | TimerRegisterDelay !(TVar s Bool) + | TimerThreadDelay !ThreadId + | TimerTimeout !ThreadId !TimeoutId !(STRef s IsLocked) -- | Internal state. -- @@ -135,8 +133,8 @@ data SimState s a = SimState { finished :: !(Map ThreadId FinishedReason), -- | current time curTime :: !Time, - -- | ordered list of timers - timers :: !(OrdPSQ TimeoutId Time (TimerVars s)), + -- | ordered list of timers and timeouts + timers :: !(OrdPSQ TimeoutId Time (TimerCompletionInfo s)), -- | list of clocks clocks :: !(Map ClockId UTCTime), nextVid :: !TVarId, -- ^ next unused 'TVarId' @@ -158,22 +156,22 @@ initialState = where epoch1970 = UTCTime (fromGregorian 1970 1 1) 0 -invariant :: Maybe (Thread s a) -> SimState s a -> Bool +invariant :: Maybe (Thread s a) -> SimState s a -> x -> x invariant (Just running) simstate@SimState{runqueue,threads,clocks} = - not (threadBlocked running) - && threadId running `Map.notMember` threads - && threadId running `List.notElem` toList runqueue - && threadClockId running `Map.member` clocks - && invariant Nothing simstate + assert (not (threadBlocked running)) + . assert (threadId running `Map.notMember` threads) + . assert (threadId running `List.notElem` runqueue) + . assert (threadClockId running `Map.member` clocks) + . invariant Nothing simstate invariant Nothing SimState{runqueue,threads,clocks} = - all (`Map.member` threads) runqueue - && and [ threadBlocked t == (threadId t `notElem` runqueue) - | t <- Map.elems threads ] - && toList runqueue == List.nub (toList runqueue) - && and [ threadClockId t `Map.member` clocks - | t <- Map.elems threads ] + assert (all (`Map.member` threads) runqueue) + . assert (and [ threadBlocked t == (threadId t `notElem` runqueue) + | t <- Map.elems threads ]) + . assert (toList runqueue == List.nub (toList runqueue)) + . assert (and [ threadClockId t `Map.member` clocks + | t <- Map.elems threads ]) -- | Interpret the simulation monotonic time as a 'NominalDiffTime' since -- the start. @@ -199,7 +197,7 @@ schedule !thread@Thread{ nextVid, nextTmid, curTime = time } = - assert (invariant (Just thread) simstate) $ + invariant (Just thread) simstate $ case action of Return x -> {-# SCC "schedule.Return" #-} @@ -232,8 +230,53 @@ schedule !thread@Thread{ let thread' = thread { threadControl = ThreadControl (k x) ctl' } schedule thread' simstate + TimeoutFrame tmid isLockedRef k ctl' -> do + -- There is a possible race between timeout action and the timeout expiration. + -- We use a lock to solve the race. + -- + -- The lock starts 'NotLocked' and when the timeout fires the lock is + -- locked and asynchronously an assassin thread is coming to interrupt + -- it. If the lock is locked when the timeout is fired then nothing + -- happens. + -- + -- Knowing this, if we reached this point in the code and the lock is + -- 'Locked', then it means that this thread still hasn't received the + -- 'TimeoutException', so we need to kill the thread that is responsible + -- for doing that (the assassin thread, we need to defend ourselves!) + -- and run our continuation successfully and peacefully. We will do that + -- by uninterruptibly-masking ourselves so we can not receive any + -- exception and kill the assassin thread behind its back. + -- If the lock is 'NotLocked' then it means we can just acquire it and + -- carry on with the success case. + locked <- readSTRef isLockedRef + case locked of + Locked etid -> do + let -- Kill the assassin throwing thread and carry on the + -- continuation + thread' = + thread { threadControl = + ThreadControl (ThrowTo (toException ThreadKilled) + etid + (k (Just x))) + ctl' + , threadMasking = MaskedUninterruptible + } + schedule thread' simstate + + NotLocked -> do + -- Acquire lock + writeSTRef isLockedRef (Locked tid) + + -- Remove the timer from the queue + let timers' = PSQ.delete tmid timers + -- Run the continuation + thread' = thread { threadControl = ThreadControl (k (Just x)) ctl' } + + schedule thread' simstate { timers = timers' + } Throw thrower e -> {-# SCC "schedule.Throw" #-} case unwindControlStack e thread of + -- Found a CatchFrame Right thread'@Thread { threadMasking = maskst' } -> do -- We found a suitable exception handler, continue with that trace <- schedule thread' simstate @@ -346,30 +389,88 @@ schedule !thread@Thread{ NewTimeout d k -> {-# SCC "schedule.NewTimeout.2" #-} do !tvar <- execNewTVar nextVid - (Just $ "<>") - TimeoutPending - !tvar' <- execNewTVar (succ nextVid) - (Just $ "<>") - False + (Just $ "<>") + TimeoutPending let !expiry = d `addTime` time - !t = Timeout tvar tvar' nextTmid - !timers' = PSQ.insert nextTmid expiry (TimerVars tvar tvar') timers + !t = Timeout tvar nextTmid + !timers' = PSQ.insert nextTmid expiry (Timer tvar) timers !thread' = thread { threadControl = ThreadControl (k t) ctl } - !trace <- schedule thread' simstate { timers = timers' - , nextVid = succ (succ nextVid) - , nextTmid = succ nextTmid } + trace <- schedule thread' simstate { timers = timers' + , nextVid = succ nextVid + , nextTmid = succ nextTmid } return (SimTrace time tid tlbl (EventTimerCreated nextTmid nextVid expiry) trace) + -- This case is guarded by checks in 'timeout' itself. + StartTimeout d _ _ | d <= 0 -> + error "schedule: StartTimeout: Impossible happened" + + StartTimeout d action' k -> + {-# SCC "schedule.StartTimeout" #-} do + isLockedRef <- newSTRef NotLocked + let !expiry = d `addTime` time + !timers' = PSQ.insert nextTmid expiry (TimerTimeout tid nextTmid isLockedRef) timers + !thread' = thread { threadControl = + ThreadControl action' + (TimeoutFrame nextTmid isLockedRef k ctl) + } + !trace <- deschedule Yield thread' simstate { timers = timers' + , nextTmid = succ nextTmid } + return (SimTrace time tid tlbl (EventTimeoutCreated nextTmid tid expiry) trace) + + RegisterDelay d k | d < 0 -> + {-# SCC "schedule.NewRegisterDelay.1" #-} do + !tvar <- execNewTVar nextVid + (Just $ "<>") + True + let !expiry = d `addTime` time + !thread' = thread { threadControl = ThreadControl (k tvar) ctl } + trace <- schedule thread' simstate { nextVid = succ nextVid } + return (SimTrace time tid tlbl (EventRegisterDelayCreated nextTmid nextVid expiry) $ + SimTrace time tid tlbl (EventRegisterDelayFired nextTmid) $ + trace) + + RegisterDelay d k -> + {-# SCC "schedule.NewRegisterDelay.2" #-} do + !tvar <- execNewTVar nextVid + (Just $ "<>") + False + let !expiry = d `addTime` time + !timers' = PSQ.insert nextTmid expiry (TimerRegisterDelay tvar) timers + !thread' = thread { threadControl = ThreadControl (k tvar) ctl } + trace <- schedule thread' simstate { timers = timers' + , nextVid = succ nextVid + , nextTmid = succ nextTmid } + return (SimTrace time tid tlbl + (EventRegisterDelayCreated nextTmid nextVid expiry) trace) + + ThreadDelay d k | d < 0 -> + {-# SCC "schedule.NewThreadDelay" #-} do + let !expiry = d `addTime` time + !thread' = thread { threadControl = ThreadControl k ctl } + trace <- schedule thread' simstate + return (SimTrace time tid tlbl (EventThreadDelay expiry) $ + SimTrace time tid tlbl EventThreadDelayFired $ + trace) + + ThreadDelay d k -> + {-# SCC "schedule.NewThreadDelay" #-} do + let !expiry = d `addTime` time + !timers' = PSQ.insert nextTmid expiry (TimerThreadDelay tid) timers + !thread' = thread { threadControl = ThreadControl k ctl } + !trace <- deschedule Blocked thread' simstate { timers = timers' + , nextTmid = succ nextTmid } + return (SimTrace time tid tlbl (EventThreadDelay expiry) trace) + -- we do not follow `GHC.Event` behaviour here; updating a timer to the past -- effectively cancels it. - UpdateTimeout (Timeout _tvar _tvar' tmid) d k | d < 0 -> + UpdateTimeout (Timeout _tvar tmid) d k | d < 0 -> {-# SCC "schedule.UpdateTimeout" #-} do let !timers' = PSQ.delete tmid timers !thread' = thread { threadControl = ThreadControl k ctl } trace <- schedule thread' simstate { timers = timers' } return (SimTrace time tid tlbl (EventTimerCancelled tmid) trace) - UpdateTimeout (Timeout _tvar _tvar' tmid) d k -> + UpdateTimeout (Timeout _tvar tmid) d k -> {-# SCC "schedule.UpdateTimeout" #-} do -- updating an expired timeout is a noop, so it is safe -- to race using a timeout with updating or cancelling it @@ -387,12 +488,12 @@ schedule !thread@Thread{ let thread' = thread { threadControl = ThreadControl k ctl } schedule thread' simstate - CancelTimeout (Timeout tvar _tvar' tmid) k -> + CancelTimeout (Timeout tvar tmid) k -> {-# SCC "schedule.CancelTimeout" #-} do let !timers' = PSQ.delete tmid timers !thread' = thread { threadControl = ThreadControl k ctl } !written <- execAtomically' (runSTM $ writeTVar tvar TimeoutCancelled) - (!wakeup, wokeby) <- threadsUnblockedByWrites written + (wakeup, wokeby) <- threadsUnblockedByWrites written mapM_ (\(SomeTVar var) -> unblockAllThreadsFromTVar var) written let (unblocked, simstate') = unblockThreads wakeup simstate @@ -525,7 +626,7 @@ schedule !thread@Thread{ (runIOSim action') (MaskFrame k maskst ctl) , threadMasking = maskst' } - !trace <- + trace <- case maskst' of -- If we're now unmasked then check for any pending async exceptions Unmasked -> SimTrace time tid tlbl (EventDeschedule Interruptable) @@ -703,7 +804,7 @@ reschedule :: SimState s a -> ST s (SimTrace a) reschedule !simstate@SimState{ runqueue, threads } | Just (!tid, runqueue') <- Deque.uncons runqueue = {-# SCC "reschedule.Just" #-} - assert (invariant Nothing simstate) $ + invariant Nothing simstate $ let thread = threads Map.! tid in schedule thread simstate { runqueue = runqueue' @@ -713,7 +814,7 @@ reschedule !simstate@SimState{ runqueue, threads } -- timer event, or stop. reschedule !simstate@SimState{ threads, timers, curTime = time } = {-# SCC "reschedule.Nothing" #-} - assert (invariant Nothing simstate) $ + invariant Nothing simstate $ -- important to get all events that expire at this time case removeMinimums timers of @@ -723,30 +824,65 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } = -- Reuse the STM functionality here to write all the timer TVars. -- Simplify to a special case that only reads and writes TVars. - !written <- execAtomically' (runSTM $ mapM_ timeoutAction fired) - (wakeup, wokeby) <- threadsUnblockedByWrites written + !written <- execAtomically' (runSTM $ mapM_ timeoutSTMAction fired) + (wakeupSTM, wokeby) <- threadsUnblockedByWrites written !_ <- mapM_ (\(SomeTVar tvar) -> unblockAllThreadsFromTVar tvar) written - let (unblocked, - simstate') = unblockThreads wakeup simstate - !trace <- reschedule simstate' { curTime = time' - , timers = timers' } + -- Check all fired threadDelays + let wakeupThreadDelay = [ tid | TimerThreadDelay tid <- fired ] + wakeup = wakeupThreadDelay ++ wakeupSTM + (_, !simstate') = unblockThreads wakeup simstate + + -- For each 'timeout' action where the timeout has fired, start a + -- new thread to execute throwTo to interrupt the action. + !timeoutExpired = [ (tid, tmid, isLockedRef) + | TimerTimeout tid tmid isLockedRef <- fired ] + + -- Get the isLockedRef values + !timeoutExpired' <- traverse (\(tid, tmid, isLockedRef) -> do + locked <- readSTRef isLockedRef + return (tid, tmid, isLockedRef, locked) + ) + timeoutExpired + + !simstate'' <- forkTimeoutInterruptThreads timeoutExpired' simstate' + + !trace <- reschedule simstate'' { curTime = time' + , timers = timers' } + return $ - traceMany ([ (time', ThreadId [-1], Just "timer", EventTimerExpired tmid) - | tmid <- tmids ] + traceMany ([ ( time', ThreadId [-1], Just "timer" + , EventTimerFired tmid) + | (tmid, Timer _) <- zip tmids fired ] + ++ [ ( time', ThreadId [-1], Just "register delay timer" + , EventRegisterDelayFired tmid) + | (tmid, TimerRegisterDelay _) <- zip tmids fired ] ++ [ (time', tid', tlbl', EventTxWakeup vids) - | tid' <- unblocked + | tid' <- wakeupSTM , let tlbl' = lookupThreadLabel tid' threads - , let Just vids = Set.toList <$> Map.lookup tid' wokeby ]) + , let Just vids = Set.toList <$> Map.lookup tid' wokeby ] + ++ [ ( time', tid, Just "thread delay timer" + , EventThreadDelayFired) + | tid <- wakeupThreadDelay ] + ++ [ ( time', tid, Just "timeout timer" + , EventTimeoutFired tmid) + | (tid, tmid, _, _) <- timeoutExpired' ] + ++ [ ( time', tid, Just "thread forked" + , EventThreadForked tid) + | (tid, _, _, _) <- timeoutExpired' ]) trace where - timeoutAction (TimerVars var bvar) = do + timeoutSTMAction (Timer var) = do x <- readTVar var case x of - TimeoutPending -> writeTVar var TimeoutFired - >> writeTVar bvar True + TimeoutPending -> writeTVar var TimeoutFired TimeoutFired -> error "MonadTimer(Sim): invariant violation" TimeoutCancelled -> return () + timeoutSTMAction (TimerRegisterDelay var) = writeTVar var True + -- Note that 'threadDelay' is not handled via STM style wakeup, but rather + -- it's handled directly above with 'wakeupThreadDelay' and 'unblockThreads' + timeoutSTMAction TimerThreadDelay{} = return () + timeoutSTMAction TimerTimeout{} = return () unblockThreads :: [ThreadId] -> SimState s a -> ([ThreadId], SimState s a) unblockThreads !wakeup !simstate@SimState {runqueue, threads} = @@ -767,7 +903,76 @@ unblockThreads !wakeup !simstate@SimState {runqueue, threads} = -- and in which case we mark them as now running !threads' = List.foldl' (flip (Map.adjust (\t -> t { threadBlocked = False }))) - threads unblocked + threads + unblocked + +-- | This function receives a list of TimerTimeout values that represent threads +-- for which the timeout expired and kills the running thread if needed. +-- +-- This function is responsible for the second part of the race condition issue +-- and relates to the 'schedule's 'TimeoutFrame' locking explanation (here is +-- where the assassin threads are launched. So, as explained previously, at this +-- point in code, the timeout expired so we need to interrupt the running +-- thread. If the running thread finished at the same time the timeout expired +-- we have a race condition. To deal with this race condition what we do is +-- look at the lock value. If it is 'Locked' this means that the running thread +-- already finished (or won the race) so we can safely do nothing. Otherwise, if +-- the lock value is 'NotLocked' we need to acquire the lock and launch an +-- assassin thread that is going to interrupt the running one. Note that we +-- should run this interrupting thread in an unmasked state since it might +-- receive a 'ThreadKilled' exception. +-- +forkTimeoutInterruptThreads :: [(ThreadId, TimeoutId, STRef s IsLocked, IsLocked)] + -> SimState s a + -> ST s (SimState s a) +forkTimeoutInterruptThreads timeoutExpired simState@SimState {threads} = + foldlM (\st@SimState{ runqueue = runqueue, + threads = threads' + } + (t, isLockedRef) + -> do + let tid' = threadId t + threads'' = Map.insert tid' t threads' + runqueue' = Deque.snoc tid' runqueue + + writeSTRef isLockedRef (Locked tid') + + return st { runqueue = runqueue', + threads = threads'' + }) + simState + throwToThread + + where + -- can only throw exception if the thread exists and if the mutually + -- exclusive lock exists and is still 'NotLocked' + toThrow = [ (tid, tmid, ref, t) + | (tid, tmid, ref, locked) <- timeoutExpired + , Just t <- [Map.lookup tid threads] + , NotLocked <- [locked] + ] + -- we launch a thread responsible for throwing an AsyncCancelled exception + -- to the thread which timeout expired + throwToThread = + [ let nextId = threadNextTId t + tid' = childThreadId tid nextId + in ( Thread { threadId = tid', + threadControl = + ThreadControl + (ThrowTo (toException (TimeoutException tmid)) + tid + (Return ())) + ForkFrame, + threadBlocked = False, + threadMasking = Unmasked, + threadThrowTo = [], + threadClockId = threadClockId t, + threadLabel = Just "timeout-forked-thread", + threadNextTId = 1 + } + , ref ) + | (tid, tmid, ref, t) <- toThrow + ] -- | Iterate through the control stack to find an enclosing exception handler @@ -785,7 +990,8 @@ unwindControlStack e thread = ThreadControl _ ctl -> unwind (threadMasking thread) ctl where unwind :: forall s' c. MaskingState - -> ControlStack s' c a -> Either Bool (Thread s' a) + -> ControlStack s' c a + -> Either Bool (Thread s' a) unwind _ MainFrame = Left True unwind _ ForkFrame = Left False unwind _ (MaskFrame _k maskst' ctl) = unwind maskst' ctl @@ -797,12 +1003,28 @@ unwindControlStack e thread = -- Ok! We will be able to continue the thread with the handler -- followed by the continuation after the catch - Just e' -> Right thread { - -- As per async exception rules, the handler is run masked + Just e' -> Right ( thread { + -- As per async exception rules, the handler is run + -- masked threadControl = ThreadControl (handler e') (MaskFrame k maskst ctl), threadMasking = atLeastInterruptibleMask maskst } + ) + + -- Either Timeout fired or the action threw an exception. + -- - If Timeout fired, then it was possibly during this thread's execution + -- so we need to run the continuation with a Nothing value. + -- - If the timeout action threw an exception we need to keep unwinding the + -- control stack looking for a handler to this exception. + unwind maskst (TimeoutFrame tmid _ k ctl) = + case fromException e of + -- Exception came from timeout expiring + Just (TimeoutException tmid') -> + assert (tmid == tmid') + Right thread { threadControl = ThreadControl (k Nothing) ctl } + -- Exception came from a different exception + _ -> unwind maskst ctl atLeastInterruptibleMask :: MaskingState -> MaskingState atLeastInterruptibleMask Unmasked = MaskedInterruptible diff --git a/io-sim/src/Control/Monad/IOSim/InternalTypes.hs b/io-sim/src/Control/Monad/IOSim/InternalTypes.hs index 3020c5d8..554ab2a9 100644 --- a/io-sim/src/Control/Monad/IOSim/InternalTypes.hs +++ b/io-sim/src/Control/Monad/IOSim/InternalTypes.hs @@ -1,17 +1,20 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} -- | Internal types shared between `IOSim` and `IOSimPOR`. -- module Control.Monad.IOSim.InternalTypes ( ThreadControl (..) , ControlStack (..) + , IsLocked (..) ) where +import Data.STRef.Lazy (STRef) import Control.Exception (Exception) import Control.Monad.Class.MonadThrow (MaskingState (..)) -import Control.Monad.IOSim.Types (SimA) +import Control.Monad.IOSim.Types (SimA, ThreadId, TimeoutId) -- We hide the type @b@ here, so it's useful to bundle these two parts together, -- rather than having Thread have an existential type, which makes record @@ -25,29 +28,43 @@ instance Show (ThreadControl s a) where show _ = "..." data ControlStack s b a where - MainFrame :: ControlStack s a a - ForkFrame :: ControlStack s () a - MaskFrame :: (b -> SimA s c) -- subsequent continuation - -> !MaskingState -- thread local state to restore + MainFrame :: ControlStack s a a + ForkFrame :: ControlStack s () a + MaskFrame :: (b -> SimA s c) -- subsequent continuation + -> MaskingState -- thread local state to restore -> !(ControlStack s c a) - -> ControlStack s b a - CatchFrame :: Exception e - => (e -> SimA s b) -- exception continuation - -> (b -> SimA s c) -- subsequent continuation + -> ControlStack s b a + CatchFrame :: Exception e + => (e -> SimA s b) -- exception continuation + -> (b -> SimA s c) -- subsequent continuation -> !(ControlStack s c a) - -> ControlStack s b a + -> ControlStack s b a + TimeoutFrame :: TimeoutId + -> STRef s IsLocked + -> (Maybe b -> SimA s c) + -> !(ControlStack s c a) + -> ControlStack s b a instance Show (ControlStack s b a) where show = show . dash - where dash :: ControlStack s' b' a' -> ControlStackDash - dash MainFrame = MainFrame' - dash ForkFrame = ForkFrame' - dash (MaskFrame _ m s) = MaskFrame' m (dash s) - dash (CatchFrame _ _ s) = CatchFrame' (dash s) + where + dash :: ControlStack s b' a -> ControlStackDash + dash MainFrame = MainFrame' + dash ForkFrame = ForkFrame' + dash (MaskFrame _ m cs) = MaskFrame' m (dash cs) + dash (CatchFrame _ _ cs) = CatchFrame' (dash cs) + dash (TimeoutFrame tmid _ _ cs) = TimeoutFrame' tmid (dash cs) data ControlStackDash = MainFrame' | ForkFrame' | MaskFrame' MaskingState ControlStackDash | CatchFrame' ControlStackDash + -- TODO: Figure out a better way to include IsLocked here + | TimeoutFrame' TimeoutId ControlStackDash + | ThreadDelayFrame' TimeoutId ControlStackDash deriving Show + +data IsLocked = NotLocked | Locked !ThreadId + deriving (Eq, Show) + diff --git a/io-sim/src/Control/Monad/IOSim/STM.hs b/io-sim/src/Control/Monad/IOSim/STM.hs index dd28e758..4d99943c 100644 --- a/io-sim/src/Control/Monad/IOSim/STM.hs +++ b/io-sim/src/Control/Monad/IOSim/STM.hs @@ -23,7 +23,7 @@ newtype TQueueDefault m a = TQueue (TVar m ([a], [a])) labelTQueueDefault :: MonadLabelledSTM m => TQueueDefault m a -> String -> STM m () -labelTQueueDefault (TQueue queue) label = labelTVar queue label +labelTQueueDefault (TQueue queue) label = labelTVar queue label traceTQueueDefault :: MonadTraceSTM m diff --git a/io-sim/src/Control/Monad/IOSim/Types.hs b/io-sim/src/Control/Monad/IOSim/Types.hs index 1d625727..af54e0ad 100644 --- a/io-sim/src/Control/Monad/IOSim/Types.hs +++ b/io-sim/src/Control/Monad/IOSim/Types.hs @@ -1,12 +1,10 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTSyntax #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PatternSynonyms #-} @@ -150,9 +148,15 @@ data SimA s a where SetWallTime :: UTCTime -> SimA s b -> SimA s b UnshareClock :: SimA s b -> SimA s b - NewTimeout :: DiffTime -> (Timeout (IOSim s) -> SimA s b) -> SimA s b - UpdateTimeout:: Timeout (IOSim s) -> DiffTime -> SimA s b -> SimA s b - CancelTimeout:: Timeout (IOSim s) -> SimA s b -> SimA s b + StartTimeout :: DiffTime -> SimA s a -> (Maybe a -> SimA s b) -> SimA s b + + RegisterDelay :: DiffTime -> (TVar s Bool -> SimA s b) -> SimA s b + + ThreadDelay :: DiffTime -> SimA s b -> SimA s b + + NewTimeout :: DiffTime -> (Timeout (IOSim s) -> SimA s b) -> SimA s b + UpdateTimeout :: Timeout (IOSim s) -> DiffTime -> SimA s b -> SimA s b + CancelTimeout :: Timeout (IOSim s) -> SimA s b -> SimA s b Throw :: Thrower -> SomeException -> SimA s a Catch :: Exception e => @@ -542,44 +546,28 @@ unshareClock :: IOSim s () unshareClock = IOSim $ oneShot $ \k -> UnshareClock (k ()) instance MonadDelay (IOSim s) where - -- Use default in terms of MonadTimer + -- Use optimized IOSim primitive + threadDelay d = IOSim $ oneShot $ \k -> ThreadDelay d (k ()) instance MonadTimer (IOSim s) where - data Timeout (IOSim s) = Timeout !(TVar s TimeoutState) !(TVar s Bool) !TimeoutId - -- ^ a timeout; we keep both 'TVar's to support - -- `newTimer` and 'registerTimeout'. + data Timeout (IOSim s) = Timeout !(TVar s TimeoutState) !TimeoutId + -- ^ a timeout | NegativeTimeout !TimeoutId -- ^ a negative timeout - readTimeout (Timeout var _bvar _key) = MonadSTM.readTVar var - readTimeout (NegativeTimeout _key) = pure TimeoutCancelled + readTimeout (Timeout var _key) = MonadSTM.readTVar var + readTimeout (NegativeTimeout _key) = pure TimeoutCancelled newTimeout d = IOSim $ oneShot $ \k -> NewTimeout d k updateTimeout t d = IOSim $ oneShot $ \k -> UpdateTimeout t d (k ()) cancelTimeout t = IOSim $ oneShot $ \k -> CancelTimeout t (k ()) timeout d action - | d < 0 = Just <$> action - | d == 0 = return Nothing - | otherwise = do - pid <- myThreadId - t@(Timeout _ _ tid) <- newTimeout d - handleJust - (\(TimeoutException tid') -> if tid' == tid - then Just () - else Nothing) - (\_ -> return Nothing) $ - bracket - (forkIO $ do - labelThisThread "<>" - fired <- MonadSTM.atomically $ awaitTimeout t - when fired $ throwTo pid (TimeoutException tid)) - (\pid' -> do - cancelTimeout t - throwTo pid' AsyncCancelled) - (\_ -> Just <$> action) - - registerDelay d = IOSim $ oneShot $ \k -> NewTimeout d (\(Timeout _var bvar _) -> k bvar) + | d < 0 = Just <$> action + | d == 0 = return Nothing + | otherwise = IOSim $ oneShot $ \k -> StartTimeout d (runIOSim action) k + + registerDelay d = IOSim $ oneShot $ \k -> RegisterDelay d k newtype TimeoutException = TimeoutException TimeoutId deriving Eq @@ -806,10 +794,19 @@ data SimEventType (Maybe Effect) -- effect performed (only for `IOSimPOR`) | EventTxWakeup [Labelled TVarId] -- changed vars causing retry - | EventTimerCreated TimeoutId TVarId Time - | EventTimerUpdated TimeoutId Time - | EventTimerCancelled TimeoutId - | EventTimerExpired TimeoutId + | EventThreadDelay Time + | EventThreadDelayFired + + | EventTimeoutCreated TimeoutId ThreadId Time + | EventTimeoutFired TimeoutId + + | EventRegisterDelayCreated TimeoutId TVarId Time + | EventRegisterDelayFired TimeoutId + + | EventTimerCreated TimeoutId TVarId Time + | EventTimerUpdated TimeoutId Time + | EventTimerCancelled TimeoutId + | EventTimerFired TimeoutId -- the following events are inserted to mark the difference between -- a failed trace and a similar passing trace of the same action diff --git a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs index 62e55ec7..64020326 100644 --- a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs +++ b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs @@ -1,20 +1,14 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTSyntax #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-orphans #-} @@ -56,7 +50,7 @@ module Control.Monad.IOSimPOR.Internal import Prelude hiding (read) import Data.Dynamic -import Data.Foldable (traverse_) +import Data.Foldable (traverse_, foldlM) import qualified Data.List as List import qualified Data.List.Trace as Trace import Data.Map.Strict (Map) @@ -69,10 +63,9 @@ import Data.Set (Set) import qualified Data.Set as Set import Data.Time (UTCTime (..), fromGregorian) -import Control.Exception (NonTermination (..), assert, throw) -import Control.Monad (join) - -import Control.Monad (when) +import Control.Exception + (NonTermination (..), assert, throw, AsyncException (..)) +import Control.Monad ( join, when ) import Control.Monad.ST.Lazy import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST) import Data.STRef.Lazy @@ -160,9 +153,13 @@ labelledThreads threadMap = -- | Timers mutable variables. First one supports 'newTimeout' api, the second --- one 'registerDelay'. +-- one 'registerDelay', the third one 'threadDelay'. -- -data TimerVars s = TimerVars !(TVar s TimeoutState) !(TVar s Bool) +data TimerCompletionInfo s = + Timer !(TVar s TimeoutState) + | TimerRegisterDelay !(TVar s Bool) + | TimerThreadDelay !ThreadId + | TimerTimeout !ThreadId !TimeoutId !(STRef s IsLocked) type RunQueue = OrdPSQ (Down ThreadId) (Down ThreadId) () @@ -177,9 +174,10 @@ data SimState s a = SimState { finished :: !(Map ThreadId (FinishedReason, VectorClock)), -- | current time curTime :: !Time, - -- | ordered list of timers - timers :: !(OrdPSQ TimeoutId Time (TimerVars s)), - -- | list of clocks + -- | ordered list of timers and timeouts + timers :: !(OrdPSQ TimeoutId Time (TimerCompletionInfo s)), + -- | timeout locks in order to synchronize the timeout handler and the + -- main thread clocks :: !(Map ClockId UTCTime), nextVid :: !TVarId, -- ^ next unused 'TVarId' nextTmid :: !TimeoutId, -- ^ next unused 'TimeoutId' @@ -337,7 +335,53 @@ schedule thread@Thread{ let thread' = thread { threadControl = ThreadControl (k x) ctl' } schedule thread' simstate + TimeoutFrame tmid isLockedRef k ctl' -> do + -- It could happen that the timeout action finished at the same time + -- as the timeout expired, this will be a race condition. That's why + -- we have the locks to solve this. + -- + -- The lock starts 'NotLocked' and when the timeout fires the lock is + -- locked and asynchronously an assassin thread is coming to interrupt + -- this one. If the lock is locked when the timeout is fired then nothing + -- happens. + -- + -- Knowing this, if we reached this point in the code and the lock is + -- 'Locked', then it means that this thread still hasn't received the + -- 'TimeoutException', so we need to kill the thread that is responsible + -- for doing that (the assassin one, we need to defend ourselves!) + -- and run our continuation successfully and peacefully. We will do that + -- by uninterruptibly-masking ourselves so we can not receive any + -- exception and kill the assassin thread behind its back. + -- If the lock is 'NotLocked' then it means we can just acquire it and + -- carry on with the success case. + locked <- readSTRef isLockedRef + case locked of + Locked etid -> do + let -- Kill the exception throwing thread and carry on the + -- continuation + thread' = + thread { threadControl = + ThreadControl (ThrowTo (toException ThreadKilled) + etid + (k (Just x))) + ctl' + , threadMasking = MaskedUninterruptible + } + schedule thread' simstate + + NotLocked -> do + -- Acquire lock + writeSTRef isLockedRef (Locked tid) + + -- Remove the timer from the queue + let timers' = PSQ.delete tmid timers + -- Run the continuation successfully + thread' = thread { threadControl = ThreadControl (k (Just x)) ctl' } + + schedule thread' simstate { timers = timers' + } Throw thrower e -> case unwindControlStack e thread of + -- Found a CatchFrame Right thread0@Thread { threadMasking = maskst' } -> do -- We found a suitable exception handler, continue with that -- We record a step, in case there is no exception handler on replay. @@ -451,28 +495,82 @@ schedule thread@Thread{ (Just $ "<>") TimeoutPending modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock) - tvar' <- execNewTVar (succ nextVid) - (Just $ "<>") - False - modifySTRef (tvarVClock tvar') (leastUpperBoundVClock vClock) let expiry = d `addTime` time - t = Timeout tvar tvar' nextTmid - timers' = PSQ.insert nextTmid expiry (TimerVars tvar tvar') timers + t = Timeout tvar nextTmid + timers' = PSQ.insert nextTmid expiry (Timer tvar) timers thread' = thread { threadControl = ThreadControl (k t) ctl } - !trace <- schedule thread' simstate { timers = timers' + trace <- schedule thread' simstate { timers = timers' , nextVid = succ (succ nextVid) , nextTmid = succ nextTmid } return (SimPORTrace time tid tstep tlbl (EventTimerCreated nextTmid nextVid expiry) trace) + -- This case is guarded by checks in 'timeout' itself. + StartTimeout d _ _ | d <= 0 -> + error "schedule: StartTimeout: Impossible happened" + + StartTimeout d action' k -> do + isLockedRef <- newSTRef NotLocked + let expiry = d `addTime` time + timers' = PSQ.insert nextTmid expiry (TimerTimeout tid nextTmid isLockedRef) timers + thread' = thread { threadControl = + ThreadControl action' + (TimeoutFrame nextTmid isLockedRef k ctl) + } + trace <- deschedule Yield thread' simstate { timers = timers' + , nextTmid = succ nextTmid } + return (SimPORTrace time tid tstep tlbl (EventTimeoutCreated nextTmid tid expiry) trace) + + RegisterDelay d k | d < 0 -> do + tvar <- execNewTVar nextVid + (Just $ "<>") + True + modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock) + let !expiry = d `addTime` time + !thread' = thread { threadControl = ThreadControl (k tvar) ctl } + trace <- schedule thread' simstate { nextVid = succ nextVid } + return (SimPORTrace time tid tstep tlbl (EventRegisterDelayCreated nextTmid nextVid expiry) $ + SimPORTrace time tid tstep tlbl (EventRegisterDelayFired nextTmid) $ + trace) + + RegisterDelay d k -> do + tvar <- execNewTVar nextVid + (Just $ "<>") + False + modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock) + let !expiry = d `addTime` time + !timers' = PSQ.insert nextTmid expiry (TimerRegisterDelay tvar) timers + !thread' = thread { threadControl = ThreadControl (k tvar) ctl } + trace <- schedule thread' simstate { timers = timers' + , nextVid = succ nextVid + , nextTmid = succ nextTmid } + return (SimPORTrace time tid tstep tlbl + (EventRegisterDelayCreated nextTmid nextVid expiry) trace) + + ThreadDelay d k | d < 0 -> do + let expiry = d `addTime` time + thread' = thread { threadControl = ThreadControl k ctl } + trace <- schedule thread' simstate + return (SimPORTrace time tid tstep tlbl (EventThreadDelay expiry) $ + SimPORTrace time tid tstep tlbl EventThreadDelayFired $ + trace) + + ThreadDelay d k -> do + let expiry = d `addTime` time + timers' = PSQ.insert nextTmid expiry (TimerThreadDelay tid) timers + thread' = thread { threadControl = ThreadControl k ctl } + trace <- deschedule Blocked thread' simstate { timers = timers' + , nextTmid = succ nextTmid } + return (SimPORTrace time tid tstep tlbl (EventThreadDelay expiry) trace) + -- we do not follow `GHC.Event` behaviour here; updating a timer to the past -- effectively cancels it. - UpdateTimeout (Timeout _tvar _tvar' tmid) d k | d < 0 -> do + UpdateTimeout (Timeout _tvar tmid) d k | d < 0 -> do let timers' = PSQ.delete tmid timers thread' = thread { threadControl = ThreadControl k ctl } trace <- schedule thread' simstate { timers = timers' } return (SimPORTrace time tid tstep tlbl (EventTimerCancelled tmid) trace) - UpdateTimeout (Timeout _tvar _tvar' tmid) d k -> do + UpdateTimeout (Timeout _tvar tmid) d k -> do -- updating an expired timeout is a noop, so it is safe -- to race using a timeout with updating or cancelling it let updateTimeout_ Nothing = ((), Nothing) @@ -488,7 +586,7 @@ schedule thread@Thread{ let thread' = thread { threadControl = ThreadControl k ctl } schedule thread' simstate - CancelTimeout (Timeout tvar tvar' tmid) k -> do + CancelTimeout (Timeout tvar tmid) k -> do let timers' = PSQ.delete tmid timers written <- execAtomically' (runSTM $ writeTVar tvar TimeoutCancelled) (wakeup, wokeby) <- threadsUnblockedByWrites written @@ -503,7 +601,6 @@ schedule thread@Thread{ (unblocked, simstate') = unblockThreads vClock wakeup simstate modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock) - modifySTRef (tvarVClock tvar') (leastUpperBoundVClock vClock) !trace <- deschedule Yield thread' simstate' { timers = timers' } return $ SimPORTrace time tid tstep tlbl (EventTimerCancelled tmid) $ traceMany @@ -935,34 +1032,65 @@ reschedule simstate@SimState{ threads, timers, curTime = time, races } = -- Reuse the STM functionality here to write all the timer TVars. -- Simplify to a special case that only reads and writes TVars. written <- execAtomically' (runSTM $ mapM_ timeoutAction fired) - (wakeup, wokeby) <- threadsUnblockedByWrites written + (wakeupSTM, wokeby) <- threadsUnblockedByWrites written mapM_ (\(SomeTVar tvar) -> unblockAllThreadsFromTVar tvar) written - -- TODO: the vector clock below cannot be right, can it? - let (unblocked, - simstate') = unblockThreads bottomVClock wakeup simstate - -- all open races will be completed and reported at this time - simstate'' = simstate'{ races = noRaces } + let wakeupThreadDelay = [ tid | TimerThreadDelay tid <- fired ] + wakeup = wakeupThreadDelay ++ wakeupSTM + -- TODO: the vector clock below cannot be right, can it? + (_, !simstate') = unblockThreads bottomVClock wakeup simstate + + -- For each 'timeout' action where the timeout has fired, start a + -- new thread to execute throwTo to interrupt the action. + !timeoutExpired = [ (tid, tmid, isLockedRef) + | TimerTimeout tid tmid isLockedRef <- fired ] + + -- Get the isLockedRef values + !timeoutExpired' <- traverse (\(tid, tmid, isLockedRef) -> do + locked <- readSTRef isLockedRef + return (tid, tmid, isLockedRef, locked) + ) + timeoutExpired + + -- all open races will be completed and reported at this time + !simstate'' <- forkTimeoutInterruptThreads timeoutExpired' + simstate' { races = noRaces } !trace <- reschedule simstate'' { curTime = time' , timers = timers' } let traceEntries = - [ (time', ThreadId [-1], (-1), Just "timer", EventTimerExpired tmid) - | tmid <- tmids ] - ++ [ (time', tid', (-1), tlbl', EventTxWakeup vids) - | tid' <- unblocked + [ ( time', ThreadId [-1], -1, Just "timer" + , EventTimerFired tmid) + | (tmid, Timer _) <- zip tmids fired ] + ++ [ ( time', ThreadId [-1], -1, Just "register delay timer" + , EventRegisterDelayFired tmid) + | (tmid, TimerRegisterDelay _) <- zip tmids fired ] + ++ [ (time', tid', -1, tlbl', EventTxWakeup vids) + | tid' <- wakeupSTM , let tlbl' = lookupThreadLabel tid' threads , let Just vids = Set.toList <$> Map.lookup tid' wokeby ] + ++ [ ( time', tid, -1, Just "thread delay timer" + , EventThreadDelayFired) + | tid <- wakeupThreadDelay ] + ++ [ ( time', tid, -1, Just "timeout timer" + , EventTimeoutFired tmid) + | (tid, tmid, _, _) <- timeoutExpired' ] + ++ [ ( time', tid, -1, Just "forked thread" + , EventThreadForked tid) + | (tid, _, _, _) <- timeoutExpired' ] + return $ traceFinalRacesFound simstate $ traceMany traceEntries trace where - timeoutAction (TimerVars var bvar) = do + timeoutAction (Timer var) = do x <- readTVar var case x of - TimeoutPending -> writeTVar var TimeoutFired - >> writeTVar bvar True + TimeoutPending -> writeTVar var TimeoutFired TimeoutFired -> error "MonadTimer(Sim): invariant violation" TimeoutCancelled -> return () + timeoutAction (TimerRegisterDelay var) = writeTVar var True + timeoutAction (TimerThreadDelay _) = return () + timeoutAction (TimerTimeout _ _ _) = return () unblockThreads :: forall s a. VectorClock @@ -998,6 +1126,78 @@ unblockThreads vClock wakeup simstate@SimState {runqueue, threads} = threadVClock = vClock `leastUpperBoundVClock` threadVClock t }))) threads unblockedIds +-- | This function receives a list of TimerTimeout values that represent threads +-- for which the timeout expired and kills the running thread if needed. +-- +-- This function is responsible for the second part of the race condition issue +-- and relates to the 'schedule's 'TimeoutFrame' locking explanation (here is +-- where the assassin threads are launched. So, as explained previously, at this +-- point in code, the timeout expired so we need to interrupt the running +-- thread. If the running thread finished at the same time the timeout expired +-- we have a race condition. To deal with this race condition what we do is +-- look at the lock value. If it is 'Locked' this means that the running thread +-- already finished (or won the race) so we can safely do nothing. Otherwise, if +-- the lock value is 'NotLocked' we need to acquire the lock and launch an +-- assassin thread that is going to interrupt the running one. Note that we +-- should run this interrupting thread in an unmasked state since it might +-- receive a 'ThreadKilled' exception. +-- +forkTimeoutInterruptThreads :: [(ThreadId, TimeoutId, STRef s IsLocked, IsLocked)] + -> SimState s a + -> ST s (SimState s a) +forkTimeoutInterruptThreads timeoutExpired simState@SimState {threads} = + foldlM (\st@SimState{ runqueue = runqueue, + threads = threads' + } + (t, isLockedRef) + -> do + let tid' = threadId t + threads'' = Map.insert tid' t threads' + runqueue' = insertThread t runqueue + writeSTRef isLockedRef (Locked tid') + + return st { runqueue = runqueue', + threads = threads'' + }) + simState + throwToThread + + where + -- can only throw exception if the thread exists and if the mutually + -- exclusive lock exists and is still 'NotLocked' + toThrow = [ (tid, tmid, ref, t) + | (tid, tmid, ref, locked) <- timeoutExpired + , Just t <- [Map.lookup tid threads] + , NotLocked <- [locked] + ] + -- we launch a thread responsible for throwing an AsyncCancelled exception + -- to the thread which timeout expired + throwToThread = + [ let nextId = threadNextTId t + tid' = childThreadId tid nextId + in ( Thread { threadId = tid', + threadControl = + ThreadControl + (ThrowTo (toException (TimeoutException tmid)) + tid + (Return ())) + ForkFrame, + threadBlocked = False, + threadDone = False, + threadMasking = Unmasked, + threadThrowTo = [], + threadClockId = threadClockId t, + threadLabel = Just "timeout-forked-thread", + threadNextTId = 1, + threadStep = 0, + threadVClock = insertVClock tid' 0 + $ threadVClock t, + threadEffect = mempty, + threadRacy = threadRacy t + } + , ref) + | (tid, tmid, ref, t) <- toThrow + ] -- | Iterate through the control stack to find an enclosing exception handler -- of the right type, or unwind all the way to the top level for the thread. @@ -1014,7 +1214,8 @@ unwindControlStack e thread = ThreadControl _ ctl -> unwind (threadMasking thread) ctl where unwind :: forall s' c. MaskingState - -> ControlStack s' c a -> Either Bool (Thread s' a) + -> ControlStack s' c a + -> Either Bool (Thread s' a) unwind _ MainFrame = Left True unwind _ ForkFrame = Left False unwind _ (MaskFrame _k maskst' ctl) = unwind maskst' ctl @@ -1026,12 +1227,28 @@ unwindControlStack e thread = -- Ok! We will be able to continue the thread with the handler -- followed by the continuation after the catch - Just e' -> Right thread { - -- As per async exception rules, the handler is run masked + Just e' -> Right ( thread { + -- As per async exception rules, the handler is run + -- masked threadControl = ThreadControl (handler e') (MaskFrame k maskst ctl), threadMasking = atLeastInterruptibleMask maskst } + ) + + -- Either Timeout fired or the action threw an exception. + -- - If Timeout fired, then it was possibly during this thread's execution + -- so we need to run the continuation with a Nothing value. + -- - If the timeout action threw an exception we need to keep unwinding the + -- control stack looking for a handler to this exception. + unwind maskst (TimeoutFrame tmid isLockedRef k ctl) = + case fromException e of + -- Exception came from timeout expiring + Just (TimeoutException tmid') -> + assert (tmid == tmid') + Right thread { threadControl = ThreadControl (k Nothing) ctl } + -- Exception came from a different exception + _ -> unwind maskst ctl atLeastInterruptibleMask :: MaskingState -> MaskingState atLeastInterruptibleMask Unmasked = MaskedInterruptible diff --git a/io-sim/test/Test/STM.hs b/io-sim/test/Test/STM.hs index d16ce71b..10c8a2d5 100644 --- a/io-sim/test/Test/STM.hs +++ b/io-sim/test/Test/STM.hs @@ -3,7 +3,6 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-}