1+ import typing
12from unittest import mock
23
34import opentelemetry .instrumentation .asgi as otel_asgi
5+ from opentelemetry import trace
6+ from opentelemetry .context import Context
7+ from opentelemetry .propagate import get_global_textmap , set_global_textmap
8+ from opentelemetry .propagators .textmap import (
9+ CarrierT ,
10+ Getter ,
11+ Setter ,
12+ TextMapPropagator ,
13+ default_getter ,
14+ default_setter ,
15+ )
416from opentelemetry .test .asgitestutil import AsgiTestBase
517from opentelemetry .test .test_base import TestBase
618from opentelemetry .trace import SpanKind
1325from .test_asgi_middleware import simple_asgi
1426
1527
28+ class MockTextMapPropagator (TextMapPropagator ):
29+ """Mock propagator for testing purposes using both getter `get` and `all`."""
30+
31+ TRACE_ID_KEY = "mock-traceid"
32+ SPAN_ID_KEY = "mock-spanid"
33+
34+ def extract (
35+ self ,
36+ carrier : CarrierT ,
37+ context : typing .Optional [Context ] = None ,
38+ getter : Getter = default_getter ,
39+ ) -> Context :
40+ if context is None :
41+ context = Context ()
42+
43+ trace_id_list = getter .get (carrier , self .TRACE_ID_KEY )
44+ span_id_list = getter .get (carrier , self .SPAN_ID_KEY )
45+ carrier_keys = getter .keys (carrier )
46+
47+ if not trace_id_list or not span_id_list :
48+ assert not any (key in carrier_keys for key in self .fields )
49+ return context
50+
51+ assert all (key in carrier_keys for key in self .fields )
52+ return trace .set_span_in_context (
53+ trace .NonRecordingSpan (
54+ trace .SpanContext (
55+ trace_id = int (trace_id_list [0 ]),
56+ span_id = int (span_id_list [0 ]),
57+ is_remote = True ,
58+ )
59+ ),
60+ context ,
61+ )
62+
63+ def inject (
64+ self ,
65+ carrier : CarrierT ,
66+ context : typing .Optional [Context ] = None ,
67+ setter : Setter = default_setter ,
68+ ) -> None :
69+ span = trace .get_current_span (context )
70+ setter .set (
71+ carrier , self .TRACE_ID_KEY , str (span .get_span_context ().trace_id )
72+ )
73+ setter .set (
74+ carrier , self .SPAN_ID_KEY , str (span .get_span_context ().span_id )
75+ )
76+
77+ @property
78+ def fields (self ):
79+ return {self .TRACE_ID_KEY , self .SPAN_ID_KEY }
80+
81+
1682async def http_app_with_custom_headers (scope , receive , send ):
1783 message = await receive ()
1884 assert scope ["type" ] == "http"
@@ -34,6 +100,8 @@ async def http_app_with_custom_headers(scope, receive, send):
34100 b"my-custom-regex-value-3,my-custom-regex-value-4" ,
35101 ),
36102 (b"my-secret-header" , b"my-secret-value" ),
103+ (MockTextMapPropagator .TRACE_ID_KEY .encode (), b"1" ),
104+ (MockTextMapPropagator .SPAN_ID_KEY .encode (), b"2" ),
37105 ],
38106 }
39107 )
@@ -60,6 +128,8 @@ async def websocket_app_with_custom_headers(scope, receive, send):
60128 b"my-custom-regex-value-3,my-custom-regex-value-4" ,
61129 ),
62130 (b"my-secret-header" , b"my-secret-value" ),
131+ (MockTextMapPropagator .TRACE_ID_KEY .encode (), b"1" ),
132+ (MockTextMapPropagator .SPAN_ID_KEY .encode (), b"2" ),
63133 ],
64134 }
65135 )
@@ -88,6 +158,11 @@ def setUp(self):
88158 self .app = otel_asgi .OpenTelemetryMiddleware (
89159 simple_asgi , tracer_provider = self .tracer_provider
90160 )
161+ self .previous_propagator = get_global_textmap ()
162+ set_global_textmap (MockTextMapPropagator ())
163+
164+ def tearDown (self ):
165+ set_global_textmap (self .previous_propagator )
91166
92167 def test_http_custom_request_headers_in_span_attributes (self ):
93168 self .scope ["headers" ].extend (
0 commit comments