32
32
from ddtrace import tracer , patch , Span
33
33
from ddtrace import __version__ as ddtrace_version
34
34
from ddtrace .propagation .http import HTTPPropagator
35
+ from ddtrace .context import Context
35
36
from datadog_lambda import __version__ as datadog_lambda_version
36
37
from datadog_lambda .trigger import (
37
38
_EventSource ,
53
54
54
55
logger = logging .getLogger (__name__ )
55
56
56
- dd_trace_context = {}
57
+ dd_trace_context = None
57
58
dd_tracing_enabled = os .environ .get ("DD_TRACE_ENABLED" , "false" ).lower () == "true"
58
59
if dd_tracing_enabled :
59
60
# Enable the telemetry client if the user has opted in
@@ -102,11 +103,11 @@ def _get_xray_trace_context():
102
103
)
103
104
if xray_trace_entity is None :
104
105
return None
105
- trace_context = {
106
- "trace-id" : _convert_xray_trace_id (xray_trace_entity .get ("trace_id" )),
107
- "parent-id" : _convert_xray_entity_id (xray_trace_entity .get ("parent_id" )),
108
- "sampling-priority" : _convert_xray_sampling (xray_trace_entity .get ("sampled" )),
109
- }
106
+ trace_context = Context (
107
+ trace_id = _convert_xray_trace_id (xray_trace_entity .get ("trace_id" )),
108
+ span_id = _convert_xray_entity_id (xray_trace_entity .get ("parent_id" )),
109
+ sampling_priority = _convert_xray_sampling (xray_trace_entity .get ("sampled" )),
110
+ )
110
111
logger .debug (
111
112
"Converted trace context %s from X-Ray segment %s" ,
112
113
trace_context ,
@@ -124,26 +125,19 @@ def _get_dd_trace_py_context():
124
125
if not span :
125
126
return None
126
127
127
- parent_id = span .context .span_id
128
- trace_id = span .context .trace_id
129
- sampling_priority = span .context .sampling_priority
130
128
logger .debug (
131
129
"found dd trace context: %s" , (span .context .trace_id , span .context .span_id )
132
130
)
133
- return {
134
- "parent-id" : str (parent_id ),
135
- "trace-id" : str (trace_id ),
136
- "sampling-priority" : str (sampling_priority ),
137
- "source" : TraceContextSource .DDTRACE ,
138
- }
131
+ return span .context
139
132
140
133
141
- def _context_obj_to_headers (obj ):
142
- return {
143
- TraceHeader .TRACE_ID : str (obj .get ("trace-id" )),
144
- TraceHeader .PARENT_ID : str (obj .get ("parent-id" )),
145
- TraceHeader .SAMPLING_PRIORITY : str (obj .get ("sampling-priority" )),
146
- }
134
+ def _is_context_complete (context ):
135
+ return (
136
+ context
137
+ and context .trace_id
138
+ and context .span_id
139
+ and context .sampling_priority is not None
140
+ )
147
141
148
142
149
143
def create_dd_dummy_metadata_subsegment (
@@ -164,28 +158,14 @@ def extract_context_from_lambda_context(lambda_context):
164
158
165
159
dd_trace libraries inject this trace context on synchronous invocations
166
160
"""
161
+ dd_data = None
167
162
client_context = lambda_context .client_context
168
- trace_id = None
169
- parent_id = None
170
- sampling_priority = None
171
163
if client_context and client_context .custom :
164
+ dd_data = client_context .custom
172
165
if "_datadog" in client_context .custom :
173
166
# Legacy trace propagation dict
174
- dd_data = client_context .custom .get ("_datadog" , {})
175
- trace_id = dd_data .get (TraceHeader .TRACE_ID )
176
- parent_id = dd_data .get (TraceHeader .PARENT_ID )
177
- sampling_priority = dd_data .get (TraceHeader .SAMPLING_PRIORITY )
178
- elif (
179
- TraceHeader .TRACE_ID in client_context .custom
180
- and TraceHeader .PARENT_ID in client_context .custom
181
- and TraceHeader .SAMPLING_PRIORITY in client_context .custom
182
- ):
183
- # New trace propagation keys
184
- trace_id = client_context .custom .get (TraceHeader .TRACE_ID )
185
- parent_id = client_context .custom .get (TraceHeader .PARENT_ID )
186
- sampling_priority = client_context .custom .get (TraceHeader .SAMPLING_PRIORITY )
187
-
188
- return trace_id , parent_id , sampling_priority
167
+ dd_data = client_context .custom .get ("_datadog" )
168
+ return propagator .extract (dd_data )
189
169
190
170
191
171
def extract_context_from_http_event_or_context (
@@ -205,33 +185,17 @@ def extract_context_from_http_event_or_context(
205
185
EventTypes .API_GATEWAY , subtype = EventSubtypes .HTTP_API
206
186
)
207
187
injected_authorizer_data = get_injected_authorizer_data (event , is_http_api )
208
- if injected_authorizer_data :
209
- try :
210
- # fail fast on any KeyError here
211
- trace_id = injected_authorizer_data [TraceHeader .TRACE_ID ]
212
- parent_id = injected_authorizer_data [TraceHeader .PARENT_ID ]
213
- sampling_priority = injected_authorizer_data .get (
214
- TraceHeader .SAMPLING_PRIORITY
215
- )
216
- return trace_id , parent_id , sampling_priority
217
- except Exception as e :
218
- logger .debug (
219
- "extract_context_from_authorizer_event returned with error. \
220
- Continue without injecting the authorizer span %s" ,
221
- e ,
222
- )
188
+ context = propagator .extract (injected_authorizer_data )
189
+ if _is_context_complete (context ):
190
+ return context
223
191
224
- headers = event .get ("headers" , {}) or {}
225
- lowercase_headers = { k . lower (): v for k , v in headers . items ()}
192
+ headers = event .get ("headers" )
193
+ context = propagator . extract ( headers )
226
194
227
- trace_id = lowercase_headers .get (TraceHeader .TRACE_ID )
228
- parent_id = lowercase_headers .get (TraceHeader .PARENT_ID )
229
- sampling_priority = lowercase_headers .get (TraceHeader .SAMPLING_PRIORITY )
230
-
231
- if not trace_id or not parent_id or not sampling_priority :
195
+ if not _is_context_complete (context ):
232
196
return extract_context_from_lambda_context (lambda_context )
233
197
234
- return trace_id , parent_id , sampling_priority
198
+ return context
235
199
236
200
237
201
def create_sns_event (message ):
@@ -262,12 +226,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
262
226
263
227
# EventBridge => SQS
264
228
try :
265
- (
266
- trace_id ,
267
- parent_id ,
268
- sampling_priority ,
269
- ) = _extract_context_from_eventbridge_sqs_event (event )
270
- return trace_id , parent_id , sampling_priority
229
+ context = _extract_context_from_eventbridge_sqs_event (event )
230
+ if _is_context_complete (context ):
231
+ return context
271
232
except Exception :
272
233
logger .debug ("Failed extracting context as EventBridge to SQS." )
273
234
@@ -311,11 +272,7 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
311
272
"context from String or Binary SQS/SNS message attributes"
312
273
)
313
274
dd_data = json .loads (dd_json_data )
314
- trace_id = dd_data .get (TraceHeader .TRACE_ID )
315
- parent_id = dd_data .get (TraceHeader .PARENT_ID )
316
- sampling_priority = dd_data .get (TraceHeader .SAMPLING_PRIORITY )
317
-
318
- return trace_id , parent_id , sampling_priority
275
+ return propagator .extract (dd_data )
319
276
except Exception as e :
320
277
logger .debug ("The trace extractor returned with error %s" , e )
321
278
return extract_context_from_lambda_context (lambda_context )
@@ -329,20 +286,12 @@ def _extract_context_from_eventbridge_sqs_event(event):
329
286
This is only possible if first record in `Records` contains a
330
287
`body` field which contains the EventBridge `detail` as a JSON string.
331
288
"""
332
- try :
333
- first_record = event .get ("Records" )[0 ]
334
- if "body" in first_record :
335
- body_str = first_record .get ("body" , {})
336
- body = json .loads (body_str )
337
-
338
- detail = body .get ("detail" )
339
- dd_context = detail .get ("_datadog" )
340
- trace_id = dd_context .get (TraceHeader .TRACE_ID )
341
- parent_id = dd_context .get (TraceHeader .PARENT_ID )
342
- sampling_priority = dd_context .get (TraceHeader .SAMPLING_PRIORITY )
343
- return trace_id , parent_id , sampling_priority
344
- except Exception :
345
- raise
289
+ first_record = event .get ("Records" )[0 ]
290
+ body_str = first_record .get ("body" )
291
+ body = json .loads (body_str )
292
+ detail = body .get ("detail" )
293
+ dd_context = detail .get ("_datadog" )
294
+ return propagator .extract (dd_context )
346
295
347
296
348
297
def extract_context_from_eventbridge_event (event , lambda_context ):
@@ -355,10 +304,7 @@ def extract_context_from_eventbridge_event(event, lambda_context):
355
304
dd_context = detail .get ("_datadog" )
356
305
if not dd_context :
357
306
return extract_context_from_lambda_context (lambda_context )
358
- trace_id = dd_context .get (TraceHeader .TRACE_ID )
359
- parent_id = dd_context .get (TraceHeader .PARENT_ID )
360
- sampling_priority = dd_context .get (TraceHeader .SAMPLING_PRIORITY )
361
- return trace_id , parent_id , sampling_priority
307
+ return propagator .extract (dd_context )
362
308
except Exception as e :
363
309
logger .debug ("The trace extractor returned with error %s" , e )
364
310
return extract_context_from_lambda_context (lambda_context )
@@ -381,10 +327,7 @@ def extract_context_from_kinesis_event(event, lambda_context):
381
327
if not dd_ctx :
382
328
return extract_context_from_lambda_context (lambda_context )
383
329
384
- trace_id = dd_ctx .get (TraceHeader .TRACE_ID )
385
- parent_id = dd_ctx .get (TraceHeader .PARENT_ID )
386
- sampling_priority = dd_ctx .get (TraceHeader .SAMPLING_PRIORITY )
387
- return trace_id , parent_id , sampling_priority
330
+ return propagator .extract (dd_ctx )
388
331
except Exception as e :
389
332
logger .debug ("The trace extractor returned with error %s" , e )
390
333
return extract_context_from_lambda_context (lambda_context )
@@ -417,7 +360,9 @@ def extract_context_from_step_functions(event, lambda_context):
417
360
execution_id + "#" + state_name + "#" + state_entered_time
418
361
)
419
362
sampling_priority = SamplingPriority .AUTO_KEEP
420
- return trace_id , parent_id , sampling_priority
363
+ return Context (
364
+ trace_id = trace_id , span_id = parent_id , sampling_priority = sampling_priority
365
+ )
421
366
except Exception as e :
422
367
logger .debug ("The Step Functions trace extractor returned with error %s" , e )
423
368
return extract_context_from_lambda_context (lambda_context )
@@ -433,12 +378,12 @@ def extract_context_custom_extractor(extractor, event, lambda_context):
433
378
parent_id ,
434
379
sampling_priority ,
435
380
) = extractor (event , lambda_context )
436
- return trace_id , parent_id , sampling_priority
381
+ return Context (
382
+ trace_id = trace_id , span_id = parent_id , sampling_priority = sampling_priority
383
+ )
437
384
except Exception as e :
438
385
logger .debug ("The trace extractor returned with error %s" , e )
439
386
440
- return None , None , None
441
-
442
387
443
388
def is_authorizer_response (response ) -> bool :
444
389
try :
@@ -504,56 +449,27 @@ def extract_dd_trace_context(
504
449
event_source = parse_event_source (event )
505
450
506
451
if extractor is not None :
507
- (
508
- trace_id ,
509
- parent_id ,
510
- sampling_priority ,
511
- ) = extract_context_custom_extractor (extractor , event , lambda_context )
452
+ context = extract_context_custom_extractor (extractor , event , lambda_context )
512
453
elif isinstance (event , (set , dict )) and "headers" in event :
513
- (
514
- trace_id ,
515
- parent_id ,
516
- sampling_priority ,
517
- ) = extract_context_from_http_event_or_context (
454
+ context = extract_context_from_http_event_or_context (
518
455
event , lambda_context , event_source , decode_authorizer_context
519
456
)
520
457
elif event_source .equals (EventTypes .SNS ) or event_source .equals (EventTypes .SQS ):
521
- (
522
- trace_id ,
523
- parent_id ,
524
- sampling_priority ,
525
- ) = extract_context_from_sqs_or_sns_event_or_context (event , lambda_context )
458
+ context = extract_context_from_sqs_or_sns_event_or_context (
459
+ event , lambda_context
460
+ )
526
461
elif event_source .equals (EventTypes .EVENTBRIDGE ):
527
- (
528
- trace_id ,
529
- parent_id ,
530
- sampling_priority ,
531
- ) = extract_context_from_eventbridge_event (event , lambda_context )
462
+ context = extract_context_from_eventbridge_event (event , lambda_context )
532
463
elif event_source .equals (EventTypes .KINESIS ):
533
- (
534
- trace_id ,
535
- parent_id ,
536
- sampling_priority ,
537
- ) = extract_context_from_kinesis_event (event , lambda_context )
464
+ context = extract_context_from_kinesis_event (event , lambda_context )
538
465
elif event_source .equals (EventTypes .STEPFUNCTIONS ):
539
- (
540
- trace_id ,
541
- parent_id ,
542
- sampling_priority ,
543
- ) = extract_context_from_step_functions (event , lambda_context )
466
+ context = extract_context_from_step_functions (event , lambda_context )
544
467
else :
545
- trace_id , parent_id , sampling_priority = extract_context_from_lambda_context (
546
- lambda_context
547
- )
468
+ context = extract_context_from_lambda_context (lambda_context )
548
469
549
- if trace_id and parent_id and sampling_priority :
470
+ if _is_context_complete ( context ) :
550
471
logger .debug ("Extracted Datadog trace context from event or context" )
551
- metadata = {
552
- "trace-id" : trace_id ,
553
- "parent-id" : parent_id ,
554
- "sampling-priority" : sampling_priority ,
555
- }
556
- dd_trace_context = metadata .copy ()
472
+ dd_trace_context = context
557
473
trace_context_source = TraceContextSource .EVENT
558
474
else :
559
475
# AWS Lambda runtime caches global variables between invocations,
@@ -579,8 +495,8 @@ def get_dd_trace_context():
579
495
"""
580
496
if dd_tracing_enabled :
581
497
dd_trace_py_context = _get_dd_trace_py_context ()
582
- if dd_trace_py_context is not None :
583
- return _context_obj_to_headers ( dd_trace_py_context )
498
+ if _is_context_complete ( dd_trace_py_context ) :
499
+ return dd_trace_py_context
584
500
585
501
global dd_trace_context
586
502
@@ -592,16 +508,17 @@ def get_dd_trace_context():
592
508
% e
593
509
)
594
510
if not xray_context :
595
- return {}
596
-
597
- if not dd_trace_context :
598
- return _context_obj_to_headers (xray_context )
511
+ return None
599
512
600
- context = dd_trace_context .copy ()
601
- context ["parent-id" ] = xray_context .get ("parent-id" )
602
- logger .debug ("Set parent id from xray trace context: %s" , context .get ("parent-id" ))
513
+ if not _is_context_complete (dd_trace_context ):
514
+ return xray_context
603
515
604
- return _context_obj_to_headers (context )
516
+ logger .debug ("Set parent id from xray trace context: %s" , xray_context .span_id )
517
+ return Context (
518
+ trace_id = dd_trace_context .trace_id ,
519
+ span_id = xray_context .span_id ,
520
+ sampling_priority = dd_trace_context .sampling_priority ,
521
+ )
605
522
606
523
607
524
def set_correlation_ids ():
@@ -620,13 +537,11 @@ def set_correlation_ids():
620
537
return
621
538
622
539
context = get_dd_trace_context ()
623
- if not context :
540
+ if not _is_context_complete ( context ) :
624
541
return
625
542
626
- span = tracer .trace ("dummy.span" )
627
- span .trace_id = int (context [TraceHeader .TRACE_ID ])
628
- span .span_id = int (context [TraceHeader .PARENT_ID ])
629
-
543
+ tracer .context_provider .activate (context )
544
+ tracer .trace ("dummy.span" )
630
545
logger .debug ("correlation ids set" )
631
546
632
547
@@ -669,18 +584,20 @@ def is_lambda_context():
669
584
670
585
def set_dd_trace_py_root (trace_context_source , merge_xray_traces ):
671
586
if trace_context_source == TraceContextSource .EVENT or merge_xray_traces :
672
- context = dict (dd_trace_context )
587
+ context = Context (
588
+ trace_id = dd_trace_context .trace_id ,
589
+ span_id = dd_trace_context .span_id ,
590
+ sampling_priority = dd_trace_context .sampling_priority ,
591
+ )
673
592
if merge_xray_traces :
674
593
xray_context = _get_xray_trace_context ()
675
- if xray_context is not None :
676
- context [ "parent-id" ] = xray_context .get ( "parent-id" )
594
+ if xray_context . span_id :
595
+ context . span_id = xray_context .span_id
677
596
678
- headers = _context_obj_to_headers (context )
679
- span_context = propagator .extract (headers )
680
- tracer .context_provider .activate (span_context )
597
+ tracer .context_provider .activate (context )
681
598
logger .debug (
682
599
"Set dd trace root context to: %s" ,
683
- (span_context .trace_id , span_context .span_id ),
600
+ (context .trace_id , context .span_id ),
684
601
)
685
602
686
603
0 commit comments