@@ -128,6 +128,7 @@ class will essentially render the service "deaf" while one request is
128
128
import os
129
129
import sys
130
130
import threading
131
+ import contextlib
131
132
from io import BufferedIOBase
132
133
from time import monotonic as time
133
134
@@ -628,6 +629,55 @@ def server_close(self):
628
629
self .collect_children (blocking = self .block_on_close )
629
630
630
631
632
+ class _Threads (list ):
633
+ """
634
+ Joinable list of all non-daemon threads.
635
+ """
636
+ def __init__ (self ):
637
+ self ._lock = threading .Lock ()
638
+
639
+ def append (self , thread ):
640
+ if thread .daemon :
641
+ return
642
+ with self ._lock :
643
+ super ().append (thread )
644
+
645
+ def remove (self , thread ):
646
+ with self ._lock :
647
+ # should not happen, but safe to ignore
648
+ with contextlib .suppress (ValueError ):
649
+ super ().remove (thread )
650
+
651
+ def remove_current (self ):
652
+ """Remove a current non-daemon thread."""
653
+ thread = threading .current_thread ()
654
+ if not thread .daemon :
655
+ self .remove (thread )
656
+
657
+ def pop_all (self ):
658
+ with self ._lock :
659
+ self [:], result = [], self [:]
660
+ return result
661
+
662
+ def join (self ):
663
+ for thread in self .pop_all ():
664
+ thread .join ()
665
+
666
+
667
+ class _NoThreads :
668
+ """
669
+ Degenerate version of _Threads.
670
+ """
671
+ def append (self , thread ):
672
+ pass
673
+
674
+ def join (self ):
675
+ pass
676
+
677
+ def remove_current (self ):
678
+ pass
679
+
680
+
631
681
class ThreadingMixIn :
632
682
"""Mix-in class to handle each request in a new thread."""
633
683
@@ -636,9 +686,9 @@ class ThreadingMixIn:
636
686
daemon_threads = False
637
687
# If true, server_close() waits until all non-daemonic threads terminate.
638
688
block_on_close = True
639
- # For non-daemonic threads, list of threading.Threading objects
689
+ # Threads object
640
690
# used by server_close() to wait for all threads completion.
641
- _threads = None
691
+ _threads = _NoThreads ()
642
692
643
693
def process_request_thread (self , request , client_address ):
644
694
"""Same as in BaseServer but as a thread.
@@ -651,27 +701,24 @@ def process_request_thread(self, request, client_address):
651
701
except Exception :
652
702
self .handle_error (request , client_address )
653
703
finally :
654
- self .shutdown_request (request )
704
+ try :
705
+ self .shutdown_request (request )
706
+ finally :
707
+ self ._threads .remove_current ()
655
708
656
709
def process_request (self , request , client_address ):
657
710
"""Start a new thread to process the request."""
711
+ if self .block_on_close :
712
+ vars (self ).setdefault ('_threads' , _Threads ())
658
713
t = threading .Thread (target = self .process_request_thread ,
659
714
args = (request , client_address ))
660
715
t .daemon = self .daemon_threads
661
- if not t .daemon and self .block_on_close :
662
- if self ._threads is None :
663
- self ._threads = []
664
- self ._threads .append (t )
716
+ self ._threads .append (t )
665
717
t .start ()
666
718
667
719
def server_close (self ):
668
720
super ().server_close ()
669
- if self .block_on_close :
670
- threads = self ._threads
671
- self ._threads = None
672
- if threads :
673
- for thread in threads :
674
- thread .join ()
721
+ self ._threads .join ()
675
722
676
723
677
724
if hasattr (os , "fork" ):
0 commit comments