Skip to content

gh-115258: Fix thread queue shutdown test #115898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 106 additions & 78 deletions Lib/test/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,97 +317,125 @@ def test_shutdown_all_methods_in_one_thread(self):
def test_shutdown_immediate_all_methods_in_one_thread(self):
return self._shutdown_all_methods_in_one_thread(True)

def _write_msg_thread(self, q, n, results, delay,
i_when_exec_shutdown,
event_start, event_end):
event_start.wait()
for i in range(1, n+1):
try:
q.put((i, "YDLO"))
results.append(True)
except self.queue.ShutDown:
results.append(False)
# triggers shutdown of queue
if i == i_when_exec_shutdown:
event_end.set()
time.sleep(delay)
# end of all puts
q.join()
def _shutdown_all_methods_in_many_threads(self, immediate):
# Arrange
q = self.type2test()

start_puts = threading.Event()
start_gets = threading.Event()
put = threading.Event()

n_gets_lock = threading.Lock()
n_gets = 0

def _read_msg_thread(self, q, nb, results, delay, event_start):
event_start.wait()
block = True
while nb:
time.sleep(delay)
calls = []
results = []
queue_size_after_join = []

def _record_call(f, *a):
calls.append((f, a))
return f(*a)

def _record_result(f):
try:
# Get at least one message
q.get(block)
block = False
result = f()
except Exception as e:
results.append((f, e))
else:
results.append((f, result))

def put_worker():
start_puts.wait()

for i in range(5):
_record_call(q.put, i)

start_gets.set()

for i in range(5, 25):
put.wait()
_record_call(q.put, i)
put.clear()

_record_call(q.shutdown, immediate)

# Should raise ShutDown
_record_call(q.put, 25)

def get_worker():
nonlocal n_gets

start_gets.wait()

while True:
with n_gets_lock:
if n_gets >= 25:
break
n_gets += 1

put.set()
_record_call(q.get)

q.task_done()
results.append(True)
nb -= 1
except self.queue.ShutDown:
results.append(False)
nb -= 1
except self.queue.Empty:
pass
q.join()

def _shutdown_thread(self, q, event_end, immediate):
event_end.wait()
q.shutdown(immediate)
q.join()
put.set()
_record_call(q.get, False) # should raise ShutDown if immediate

def _join_thread(self, q, delay, event_start):
event_start.wait()
time.sleep(delay)
q.join()
def join_worker():
start_gets.wait()
_record_call(q.join)
queue_size_after_join.append(q.qsize())

def _shutdown_all_methods_in_many_threads(self, immediate):
q = self.type2test()
ps = []
ev_start = threading.Event()
ev_exec_shutdown = threading.Event()
res_puts = []
res_gets = []
delay = 1e-4
read_process = 4
nb_msgs = read_process * 16
nb_msgs_r = nb_msgs // read_process
when_exec_shutdown = nb_msgs // 2
lprocs = (
(self._write_msg_thread, 1, (q, nb_msgs, res_puts, delay,
when_exec_shutdown,
ev_start, ev_exec_shutdown)),
(self._read_msg_thread, read_process, (q, nb_msgs_r,
res_gets, delay*2,
ev_start)),
(self._join_thread, 2, (q, delay*2, ev_start)),
(self._shutdown_thread, 1, (q, ev_exec_shutdown, immediate)),
)
# start all threds
for func, n, args in lprocs:
for i in range(n):
ps.append(threading.Thread(target=func, args=args))
ps[-1].start()
# set event in order to run q.shutdown()
ev_start.set()
def _start_thread(f):
thread = threading.Thread(target=_record_result, args=(f,))
thread.start()
return thread

if not immediate:
assert(len(res_gets) == len(res_puts))
assert(res_gets.count(True) == res_puts.count(True))
else:
assert(len(res_gets) <= len(res_puts))
assert(res_gets.count(True) <= res_puts.count(True))
threads = [
_start_thread(put_worker),
*(_start_thread(get_worker) for _ in range(4)),
*(_start_thread(join_worker) for _ in range(2)),
]

for thread in ps[1:]:
# Act
start_puts.set()
for thread in threads:
thread.join()

@unittest.skip("test times out (gh-115258)")
# Assert
self.assertEqual(q.qsize(), 0)

if immediate:
self.assertTrue(all(qs > 0 for qs in queue_size_after_join))
else:
self.assertTrue(all(qs == 0 for qs in queue_size_after_join))

self.assertListEqual(
[a for f, a in calls if f is q.put], [(i,) for i in range(33)]
)
self.assertListEqual(
[a for f, a in calls if f is q.get], [(False,)] * 36
)
self.assertListEqual([a for f, a in calls if f is q.join], [(), ()])
self.assertListEqual(
[a for f, a in calls if f is q.shutdown], [immediate]
)

put_worker_result = next(r for f, r in results if f is put_worker)
self.assertIs(put_worker_result.__class__, self.queue.ShutDown)

get_worker_results = [r for f, r in results if f is get_worker]
if immediate:
self.assertListEqual(get_worker_results, [self.queue.ShutDown] * 4)
else:
self.assertListEqual(get_worker_results, [None] * 4)

join_worker_results = [r for f, r in results if f is join_worker]
self.assertListEqual(join_worker_results, [None, None])

def test_shutdown_all_methods_in_many_threads(self):
return self._shutdown_all_methods_in_many_threads(False)

@unittest.skip("test times out (gh-115258)")
def test_shutdown_immediate_all_methods_in_many_threads(self):
return self._shutdown_all_methods_in_many_threads(True)

Expand Down