@@ -128,6 +128,7 @@ class will essentially render the service "deaf" while one request is
128128import os
129129import sys
130130import threading
131+ import contextlib
131132from io import BufferedIOBase
132133from time import monotonic as time
133134
@@ -628,6 +629,55 @@ def server_close(self):
628629 self .collect_children (blocking = self .block_on_close )
629630
630631
632+ class _Threads (list ):
633+ """
634+ Joinable list of all non-daemon threads.
635+ """
636+ def __init__ (self ):
637+ self ._lock = threading .Lock ()
638+
639+ def append (self , thread ):
640+ if thread .daemon :
641+ return
642+ with self ._lock :
643+ super ().append (thread )
644+
645+ def remove (self , thread ):
646+ with self ._lock :
647+ # should not happen, but safe to ignore
648+ with contextlib .suppress (ValueError ):
649+ super ().remove (thread )
650+
651+ def remove_current (self ):
652+ """Remove a current non-daemon thread."""
653+ thread = threading .current_thread ()
654+ if not thread .daemon :
655+ self .remove (thread )
656+
657+ def pop_all (self ):
658+ with self ._lock :
659+ self [:], result = [], self [:]
660+ return result
661+
662+ def join (self ):
663+ for thread in self .pop_all ():
664+ thread .join ()
665+
666+
667+ class _NoThreads :
668+ """
669+ Degenerate version of _Threads.
670+ """
671+ def append (self , thread ):
672+ pass
673+
674+ def join (self ):
675+ pass
676+
677+ def remove_current (self ):
678+ pass
679+
680+
631681class ThreadingMixIn :
632682 """Mix-in class to handle each request in a new thread."""
633683
@@ -636,9 +686,9 @@ class ThreadingMixIn:
636686 daemon_threads = False
637687 # If true, server_close() waits until all non-daemonic threads terminate.
638688 block_on_close = True
639- # For non-daemonic threads, list of threading.Threading objects
689+ # Threads object
640690 # used by server_close() to wait for all threads completion.
641- _threads = None
691+ _threads = _NoThreads ()
642692
643693 def process_request_thread (self , request , client_address ):
644694 """Same as in BaseServer but as a thread.
@@ -651,27 +701,24 @@ def process_request_thread(self, request, client_address):
651701 except Exception :
652702 self .handle_error (request , client_address )
653703 finally :
654- self .shutdown_request (request )
704+ try :
705+ self .shutdown_request (request )
706+ finally :
707+ self ._threads .remove_current ()
655708
656709 def process_request (self , request , client_address ):
657710 """Start a new thread to process the request."""
711+ if self .block_on_close :
712+ vars (self ).setdefault ('_threads' , _Threads ())
658713 t = threading .Thread (target = self .process_request_thread ,
659714 args = (request , client_address ))
660715 t .daemon = self .daemon_threads
661- if not t .daemon and self .block_on_close :
662- if self ._threads is None :
663- self ._threads = []
664- self ._threads .append (t )
716+ self ._threads .append (t )
665717 t .start ()
666718
667719 def server_close (self ):
668720 super ().server_close ()
669- if self .block_on_close :
670- threads = self ._threads
671- self ._threads = None
672- if threads :
673- for thread in threads :
674- thread .join ()
721+ self ._threads .join ()
675722
676723
677724if hasattr (os , "fork" ):
0 commit comments