@@ -48,6 +48,8 @@ module Control.Monad.IOSim.Internal
4848
4949import Prelude hiding (read )
5050
51+ import Data.Deque.Strict (Deque )
52+ import qualified Data.Deque.Strict as Deque
5153import Data.Dynamic
5254import Data.Foldable (foldlM , toList , traverse_ )
5355import qualified Data.List as List
@@ -60,8 +62,6 @@ import qualified Data.OrdPSQ as PSQ
6062import Data.Set (Set )
6163import qualified Data.Set as Set
6264import Data.Time (UTCTime (.. ), fromGregorian )
63- import Data.Deque.Strict (Deque )
64- import qualified Data.Deque.Strict as Deque
6565
6666import Control.Exception (NonTermination (.. ), assert , throw )
6767import Control.Monad (join , when )
@@ -76,13 +76,16 @@ import Control.Monad.Class.MonadSTM hiding (STM)
7676import Control.Monad.Class.MonadSTM.Internal (TMVarDefault (TMVar ))
7777import Control.Monad.Class.MonadThrow hiding (getMaskingState )
7878import Control.Monad.Class.MonadTime
79- import Control.Monad.Class.MonadTimer.SI (TimeoutState (.. ), DiffTime , diffTimeToMicrosecondsAsInt , microsecondsAsIntToDiffTime )
79+ import Control.Monad.Class.MonadTimer.SI (DiffTime , TimeoutState (.. ),
80+ diffTimeToMicrosecondsAsInt , microsecondsAsIntToDiffTime )
8081
8182import Control.Monad.IOSim.InternalTypes
8283import Control.Monad.IOSim.Types hiding (SimEvent (SimPOREvent ),
8384 Trace (SimPORTrace ))
8485import Control.Monad.IOSim.Types (SimEvent )
85- import System.Random (StdGen , randomR , split )
86+ import Data.Bifunctor (first )
87+ import Data.Ord (comparing )
88+ import System.Random (StdGen , randomR , split )
8689
8790--
8891-- Simulation interpreter
@@ -849,31 +852,47 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
849852 timeoutSTMAction TimerTimeout {} = return ()
850853
851854unblockThreads :: Bool -> [IOSimThreadId ] -> SimState s a -> ([IOSimThreadId ], SimState s a )
852- unblockThreads ! onlySTM ! wakeup ! simstate@ SimState {runqueue, threads, stdGen} =
855+ unblockThreads ! onlySTM ! wakeup simstate@ SimState {runqueue, threads, stdGen} =
853856 -- To preserve our invariants (that threadBlocked is correct)
854857 -- we update the runqueue and threads together here
855858 (unblocked, simstate {
856- runqueue = Deque. fromList (shuffledRunqueue ++ rest) ,
859+ runqueue = runqueue <> Deque. fromList unblocked ,
857860 threads = threads',
858861 stdGen = stdGen''
859862 })
860863 where
861- ! (shuffledRunqueue, stdGen'') = fisherYatesShuffle stdGen' toShuffle
862- ! ((toShuffle, rest), stdGen') =
863- let runqueueList = Deque. toList $ runqueue <> Deque. fromList unblocked
864- runqueueListLength = max 1 (length runqueueList)
865- (ix, newGen) = randomR (0 , runqueueListLength `div` 2 ) stdGen
866- in (splitAt ix runqueueList, newGen)
867864 -- can only unblock if the thread exists and is blocked (not running)
868- ! unblocked = [ tid
869- | tid <- wakeup
870- , case Map. lookup tid threads of
871- Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
872- -> True
873- Just Thread { threadStatus = ThreadBlocked _ }
874- -> not onlySTM
875- _ -> False
876- ]
865+ ! blockedOnOther = [ (tid, ix)
866+ | (tid, ix) <- zip wakeup [0 :: Int .. ]
867+ , case Map. lookup tid threads of
868+ Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
869+ -> False
870+ Just Thread { threadStatus = ThreadBlocked _ }
871+ -> not onlySTM
872+ _ -> False
873+ ]
874+
875+ ! blockedOnSTM = [ (tid, ix)
876+ | (tid, ix) <- zip wakeup [0 :: Int .. ]
877+ , case Map. lookup tid threads of
878+ Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
879+ -> True
880+ _ -> False
881+ ]
882+
883+ mergeByIndex :: Ord a => [(b , a )] -> [(b , a )] -> [b ]
884+ mergeByIndex a b = map fst $ List. sortBy (comparing snd ) (a ++ b)
885+
886+ -- Shuffle only 1/5th of the time
887+ (shouldShuffle, ! stdGen') =
888+ first (== 0 ) $ randomR (0 :: Int , 5 ) stdGen
889+
890+ (! shuffledBlockedOnSTM, ! stdGen'')
891+ | shouldShuffle = fisherYatesShuffle stdGen' blockedOnSTM
892+ | otherwise = (blockedOnSTM, stdGen')
893+
894+ ! unblocked = mergeByIndex blockedOnOther shuffledBlockedOnSTM
895+
877896 -- and in which case we mark them as now running
878897 ! threads' = List. foldl'
879898 (flip (Map. adjust (\ t -> t { threadStatus = ThreadRunning })))
@@ -889,8 +908,8 @@ unblockThreads !onlySTM !wakeup !simstate@SimState {runqueue, threads, stdGen} =
889908 where
890909 go 0 lst g = (lst, g)
891910 go n lst g = let (k, newGen) = randomR (0 , n) g
892- (x: xs) = drop k lst
893- swapped = take k lst ++ [lst !! n] ++ drop (k + 1 ) lst
911+ (x: xs) = drop k lst
912+ swapped = take k lst ++ [lst !! n] ++ drop (k + 1 ) lst
894913 in go (n - 1 ) (take n swapped ++ [x] ++ drop n xs) newGen
895914
896915-- | This function receives a list of TimerTimeout values that represent threads
0 commit comments