1313# limitations under the License.
1414
1515import threading
16- from concurrent .futures import ThreadPoolExecutor , as_completed
16+ from concurrent .futures import ThreadPoolExecutor
1717from typing import List
1818
1919from opentelemetry import trace
@@ -51,11 +51,13 @@ def run_threading_test(self, thread: threading.Thread):
5151
5252 # check result
5353 self .assertEqual (len (self ._mock_span_contexts ), 1 )
54- self .assert_span_context_equality (
54+ self .assertEqual (
5555 self ._mock_span_contexts [0 ], expected_span_context
5656 )
5757
58- def test_trace_context_propagation_in_thread_pool (self ):
58+ def test_trace_context_propagation_in_thread_pool_with_multiple_workers (
59+ self ,
60+ ):
5961 max_workers = 10
6062 executor = ThreadPoolExecutor (max_workers = max_workers )
6163
@@ -65,38 +67,65 @@ def test_trace_context_propagation_in_thread_pool(self):
6567 with self ._tracer .start_as_current_span (f"trace_{ num } " ) as span :
6668 expected_span_context = span .get_span_context ()
6769 expected_span_contexts .append (expected_span_context )
68- future = executor .submit (self .fake_func )
70+ future = executor .submit (
71+ self .get_current_span_context_for_test
72+ )
6973 futures_list .append (future )
7074
71- for future in as_completed (futures_list ):
72- future .result ()
75+ result_span_contexts = [future .result () for future in futures_list ]
7376
7477 # check result
75- self .assertEqual (len (self . _mock_span_contexts ), max_workers )
78+ self .assertEqual (len (result_span_contexts ), max_workers )
7679 self .assertEqual (
77- len (self . _mock_span_contexts ), len (expected_span_contexts )
80+ len (result_span_contexts ), len (expected_span_contexts )
7881 )
79- for index , mock_span_context in enumerate (self . _mock_span_contexts ):
80- self .assert_span_context_equality (
81- mock_span_context , expected_span_contexts [index ]
82+ for index , result_span_context in enumerate (result_span_contexts ):
83+ self .assertEqual (
84+ result_span_context , expected_span_contexts [index ]
8285 )
8386
84- def fake_func (self ):
85- span_context = trace .get_current_span ().get_span_context ()
87+ def test_trace_context_propagation_in_thread_pool_with_single_worker (self ):
88+ max_workers = 1
89+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
90+ # test propagation of the same trace context across multiple tasks
91+ with self ._tracer .start_as_current_span (f"task" ) as task_span :
92+ expected_task_context = task_span .get_span_context ()
93+ future1 = executor .submit (
94+ self .get_current_span_context_for_test
95+ )
96+ future2 = executor .submit (
97+ self .get_current_span_context_for_test
98+ )
99+
100+ # check result
101+ self .assertEqual (future1 .result (), expected_task_context )
102+ self .assertEqual (future2 .result (), expected_task_context )
103+
104+ # test propagation of different trace contexts across tasks in sequence
105+ with self ._tracer .start_as_current_span (f"task1" ) as task1_span :
106+ expected_task1_context = task1_span .get_span_context ()
107+ future1 = executor .submit (
108+ self .get_current_span_context_for_test
109+ )
110+
111+ # check result
112+ self .assertEqual (future1 .result (), expected_task1_context )
113+
114+ with self ._tracer .start_as_current_span (f"task2" ) as task2_span :
115+ expected_task2_context = task2_span .get_span_context ()
116+ future2 = executor .submit (
117+ self .get_current_span_context_for_test
118+ )
119+
120+ # check result
121+ self .assertEqual (future2 .result (), expected_task2_context )
122+
123+ def fake_func (self ) -> trace .SpanContext :
124+ span_context = self .get_current_span_context_for_test ()
86125 self ._mock_span_contexts .append (span_context )
87126
88- def assert_span_context_equality (
89- self ,
90- result_span_context : trace .SpanContext ,
91- expected_span_context : trace .SpanContext ,
92- ):
93- self .assertEqual (result_span_context , expected_span_context )
94- self .assertEqual (
95- result_span_context .trace_id , expected_span_context .trace_id
96- )
97- self .assertEqual (
98- result_span_context .span_id , expected_span_context .span_id
99- )
127+ def get_current_span_context_for_test (self ) -> trace .SpanContext :
128+ return trace .get_current_span ().get_span_context ()
100129
101130 def print_square (self , num ):
102131 with self ._tracer .start_as_current_span ("square" ):
0 commit comments