Skip to content

Commit 2153fbf

Browse files
committed
Defer trace context extraction to ddtrace.
1 parent aa5a1c9 commit 2153fbf

File tree

2 files changed

+207
-326
lines changed

2 files changed

+207
-326
lines changed

datadog_lambda/tracing.py

Lines changed: 78 additions & 161 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from ddtrace import tracer, patch, Span
3333
from ddtrace import __version__ as ddtrace_version
3434
from ddtrace.propagation.http import HTTPPropagator
35+
from ddtrace.context import Context
3536
from datadog_lambda import __version__ as datadog_lambda_version
3637
from datadog_lambda.trigger import (
3738
_EventSource,
@@ -53,7 +54,7 @@
5354

5455
logger = logging.getLogger(__name__)
5556

56-
dd_trace_context = {}
57+
dd_trace_context = None
5758
dd_tracing_enabled = os.environ.get("DD_TRACE_ENABLED", "false").lower() == "true"
5859
if dd_tracing_enabled:
5960
# Enable the telemetry client if the user has opted in
@@ -102,11 +103,11 @@ def _get_xray_trace_context():
102103
)
103104
if xray_trace_entity is None:
104105
return None
105-
trace_context = {
106-
"trace-id": _convert_xray_trace_id(xray_trace_entity.get("trace_id")),
107-
"parent-id": _convert_xray_entity_id(xray_trace_entity.get("parent_id")),
108-
"sampling-priority": _convert_xray_sampling(xray_trace_entity.get("sampled")),
109-
}
106+
trace_context = Context(
107+
trace_id=_convert_xray_trace_id(xray_trace_entity.get("trace_id")),
108+
span_id=_convert_xray_entity_id(xray_trace_entity.get("parent_id")),
109+
sampling_priority=_convert_xray_sampling(xray_trace_entity.get("sampled")),
110+
)
110111
logger.debug(
111112
"Converted trace context %s from X-Ray segment %s",
112113
trace_context,
@@ -124,26 +125,19 @@ def _get_dd_trace_py_context():
124125
if not span:
125126
return None
126127

127-
parent_id = span.context.span_id
128-
trace_id = span.context.trace_id
129-
sampling_priority = span.context.sampling_priority
130128
logger.debug(
131129
"found dd trace context: %s", (span.context.trace_id, span.context.span_id)
132130
)
133-
return {
134-
"parent-id": str(parent_id),
135-
"trace-id": str(trace_id),
136-
"sampling-priority": str(sampling_priority),
137-
"source": TraceContextSource.DDTRACE,
138-
}
131+
return span.context
139132

140133

141-
def _context_obj_to_headers(obj):
142-
return {
143-
TraceHeader.TRACE_ID: str(obj.get("trace-id")),
144-
TraceHeader.PARENT_ID: str(obj.get("parent-id")),
145-
TraceHeader.SAMPLING_PRIORITY: str(obj.get("sampling-priority")),
146-
}
134+
def _is_context_complete(context):
135+
return (
136+
context
137+
and context.trace_id
138+
and context.span_id
139+
and context.sampling_priority is not None
140+
)
147141

148142

149143
def create_dd_dummy_metadata_subsegment(
@@ -164,28 +158,14 @@ def extract_context_from_lambda_context(lambda_context):
164158
165159
dd_trace libraries inject this trace context on synchronous invocations
166160
"""
161+
dd_data = None
167162
client_context = lambda_context.client_context
168-
trace_id = None
169-
parent_id = None
170-
sampling_priority = None
171163
if client_context and client_context.custom:
164+
dd_data = client_context.custom
172165
if "_datadog" in client_context.custom:
173166
# Legacy trace propagation dict
174-
dd_data = client_context.custom.get("_datadog", {})
175-
trace_id = dd_data.get(TraceHeader.TRACE_ID)
176-
parent_id = dd_data.get(TraceHeader.PARENT_ID)
177-
sampling_priority = dd_data.get(TraceHeader.SAMPLING_PRIORITY)
178-
elif (
179-
TraceHeader.TRACE_ID in client_context.custom
180-
and TraceHeader.PARENT_ID in client_context.custom
181-
and TraceHeader.SAMPLING_PRIORITY in client_context.custom
182-
):
183-
# New trace propagation keys
184-
trace_id = client_context.custom.get(TraceHeader.TRACE_ID)
185-
parent_id = client_context.custom.get(TraceHeader.PARENT_ID)
186-
sampling_priority = client_context.custom.get(TraceHeader.SAMPLING_PRIORITY)
187-
188-
return trace_id, parent_id, sampling_priority
167+
dd_data = client_context.custom.get("_datadog")
168+
return propagator.extract(dd_data)
189169

190170

191171
def extract_context_from_http_event_or_context(
@@ -205,33 +185,17 @@ def extract_context_from_http_event_or_context(
205185
EventTypes.API_GATEWAY, subtype=EventSubtypes.HTTP_API
206186
)
207187
injected_authorizer_data = get_injected_authorizer_data(event, is_http_api)
208-
if injected_authorizer_data:
209-
try:
210-
# fail fast on any KeyError here
211-
trace_id = injected_authorizer_data[TraceHeader.TRACE_ID]
212-
parent_id = injected_authorizer_data[TraceHeader.PARENT_ID]
213-
sampling_priority = injected_authorizer_data.get(
214-
TraceHeader.SAMPLING_PRIORITY
215-
)
216-
return trace_id, parent_id, sampling_priority
217-
except Exception as e:
218-
logger.debug(
219-
"extract_context_from_authorizer_event returned with error. \
220-
Continue without injecting the authorizer span %s",
221-
e,
222-
)
188+
context = propagator.extract(injected_authorizer_data)
189+
if _is_context_complete(context):
190+
return context
223191

224-
headers = event.get("headers", {}) or {}
225-
lowercase_headers = {k.lower(): v for k, v in headers.items()}
192+
headers = event.get("headers")
193+
context = propagator.extract(headers)
226194

227-
trace_id = lowercase_headers.get(TraceHeader.TRACE_ID)
228-
parent_id = lowercase_headers.get(TraceHeader.PARENT_ID)
229-
sampling_priority = lowercase_headers.get(TraceHeader.SAMPLING_PRIORITY)
230-
231-
if not trace_id or not parent_id or not sampling_priority:
195+
if not _is_context_complete(context):
232196
return extract_context_from_lambda_context(lambda_context)
233197

234-
return trace_id, parent_id, sampling_priority
198+
return context
235199

236200

237201
def create_sns_event(message):
@@ -262,12 +226,9 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
262226

263227
# EventBridge => SQS
264228
try:
265-
(
266-
trace_id,
267-
parent_id,
268-
sampling_priority,
269-
) = _extract_context_from_eventbridge_sqs_event(event)
270-
return trace_id, parent_id, sampling_priority
229+
context = _extract_context_from_eventbridge_sqs_event(event)
230+
if _is_context_complete(context):
231+
return context
271232
except Exception:
272233
logger.debug("Failed extracting context as EventBridge to SQS.")
273234

@@ -311,11 +272,7 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
311272
"context from String or Binary SQS/SNS message attributes"
312273
)
313274
dd_data = json.loads(dd_json_data)
314-
trace_id = dd_data.get(TraceHeader.TRACE_ID)
315-
parent_id = dd_data.get(TraceHeader.PARENT_ID)
316-
sampling_priority = dd_data.get(TraceHeader.SAMPLING_PRIORITY)
317-
318-
return trace_id, parent_id, sampling_priority
275+
return propagator.extract(dd_data)
319276
except Exception as e:
320277
logger.debug("The trace extractor returned with error %s", e)
321278
return extract_context_from_lambda_context(lambda_context)
@@ -329,20 +286,12 @@ def _extract_context_from_eventbridge_sqs_event(event):
329286
This is only possible if first record in `Records` contains a
330287
`body` field which contains the EventBridge `detail` as a JSON string.
331288
"""
332-
try:
333-
first_record = event.get("Records")[0]
334-
if "body" in first_record:
335-
body_str = first_record.get("body", {})
336-
body = json.loads(body_str)
337-
338-
detail = body.get("detail")
339-
dd_context = detail.get("_datadog")
340-
trace_id = dd_context.get(TraceHeader.TRACE_ID)
341-
parent_id = dd_context.get(TraceHeader.PARENT_ID)
342-
sampling_priority = dd_context.get(TraceHeader.SAMPLING_PRIORITY)
343-
return trace_id, parent_id, sampling_priority
344-
except Exception:
345-
raise
289+
first_record = event.get("Records")[0]
290+
body_str = first_record.get("body")
291+
body = json.loads(body_str)
292+
detail = body.get("detail")
293+
dd_context = detail.get("_datadog")
294+
return propagator.extract(dd_context)
346295

347296

348297
def extract_context_from_eventbridge_event(event, lambda_context):
@@ -355,10 +304,7 @@ def extract_context_from_eventbridge_event(event, lambda_context):
355304
dd_context = detail.get("_datadog")
356305
if not dd_context:
357306
return extract_context_from_lambda_context(lambda_context)
358-
trace_id = dd_context.get(TraceHeader.TRACE_ID)
359-
parent_id = dd_context.get(TraceHeader.PARENT_ID)
360-
sampling_priority = dd_context.get(TraceHeader.SAMPLING_PRIORITY)
361-
return trace_id, parent_id, sampling_priority
307+
return propagator.extract(dd_context)
362308
except Exception as e:
363309
logger.debug("The trace extractor returned with error %s", e)
364310
return extract_context_from_lambda_context(lambda_context)
@@ -381,10 +327,7 @@ def extract_context_from_kinesis_event(event, lambda_context):
381327
if not dd_ctx:
382328
return extract_context_from_lambda_context(lambda_context)
383329

384-
trace_id = dd_ctx.get(TraceHeader.TRACE_ID)
385-
parent_id = dd_ctx.get(TraceHeader.PARENT_ID)
386-
sampling_priority = dd_ctx.get(TraceHeader.SAMPLING_PRIORITY)
387-
return trace_id, parent_id, sampling_priority
330+
return propagator.extract(dd_ctx)
388331
except Exception as e:
389332
logger.debug("The trace extractor returned with error %s", e)
390333
return extract_context_from_lambda_context(lambda_context)
@@ -417,7 +360,9 @@ def extract_context_from_step_functions(event, lambda_context):
417360
execution_id + "#" + state_name + "#" + state_entered_time
418361
)
419362
sampling_priority = SamplingPriority.AUTO_KEEP
420-
return trace_id, parent_id, sampling_priority
363+
return Context(
364+
trace_id=trace_id, span_id=parent_id, sampling_priority=sampling_priority
365+
)
421366
except Exception as e:
422367
logger.debug("The Step Functions trace extractor returned with error %s", e)
423368
return extract_context_from_lambda_context(lambda_context)
@@ -433,12 +378,12 @@ def extract_context_custom_extractor(extractor, event, lambda_context):
433378
parent_id,
434379
sampling_priority,
435380
) = extractor(event, lambda_context)
436-
return trace_id, parent_id, sampling_priority
381+
return Context(
382+
trace_id=trace_id, span_id=parent_id, sampling_priority=sampling_priority
383+
)
437384
except Exception as e:
438385
logger.debug("The trace extractor returned with error %s", e)
439386

440-
return None, None, None
441-
442387

443388
def is_authorizer_response(response) -> bool:
444389
try:
@@ -504,56 +449,27 @@ def extract_dd_trace_context(
504449
event_source = parse_event_source(event)
505450

506451
if extractor is not None:
507-
(
508-
trace_id,
509-
parent_id,
510-
sampling_priority,
511-
) = extract_context_custom_extractor(extractor, event, lambda_context)
452+
context = extract_context_custom_extractor(extractor, event, lambda_context)
512453
elif isinstance(event, (set, dict)) and "headers" in event:
513-
(
514-
trace_id,
515-
parent_id,
516-
sampling_priority,
517-
) = extract_context_from_http_event_or_context(
454+
context = extract_context_from_http_event_or_context(
518455
event, lambda_context, event_source, decode_authorizer_context
519456
)
520457
elif event_source.equals(EventTypes.SNS) or event_source.equals(EventTypes.SQS):
521-
(
522-
trace_id,
523-
parent_id,
524-
sampling_priority,
525-
) = extract_context_from_sqs_or_sns_event_or_context(event, lambda_context)
458+
context = extract_context_from_sqs_or_sns_event_or_context(
459+
event, lambda_context
460+
)
526461
elif event_source.equals(EventTypes.EVENTBRIDGE):
527-
(
528-
trace_id,
529-
parent_id,
530-
sampling_priority,
531-
) = extract_context_from_eventbridge_event(event, lambda_context)
462+
context = extract_context_from_eventbridge_event(event, lambda_context)
532463
elif event_source.equals(EventTypes.KINESIS):
533-
(
534-
trace_id,
535-
parent_id,
536-
sampling_priority,
537-
) = extract_context_from_kinesis_event(event, lambda_context)
464+
context = extract_context_from_kinesis_event(event, lambda_context)
538465
elif event_source.equals(EventTypes.STEPFUNCTIONS):
539-
(
540-
trace_id,
541-
parent_id,
542-
sampling_priority,
543-
) = extract_context_from_step_functions(event, lambda_context)
466+
context = extract_context_from_step_functions(event, lambda_context)
544467
else:
545-
trace_id, parent_id, sampling_priority = extract_context_from_lambda_context(
546-
lambda_context
547-
)
468+
context = extract_context_from_lambda_context(lambda_context)
548469

549-
if trace_id and parent_id and sampling_priority:
470+
if _is_context_complete(context):
550471
logger.debug("Extracted Datadog trace context from event or context")
551-
metadata = {
552-
"trace-id": trace_id,
553-
"parent-id": parent_id,
554-
"sampling-priority": sampling_priority,
555-
}
556-
dd_trace_context = metadata.copy()
472+
dd_trace_context = context
557473
trace_context_source = TraceContextSource.EVENT
558474
else:
559475
# AWS Lambda runtime caches global variables between invocations,
@@ -579,8 +495,8 @@ def get_dd_trace_context():
579495
"""
580496
if dd_tracing_enabled:
581497
dd_trace_py_context = _get_dd_trace_py_context()
582-
if dd_trace_py_context is not None:
583-
return _context_obj_to_headers(dd_trace_py_context)
498+
if _is_context_complete(dd_trace_py_context):
499+
return dd_trace_py_context
584500

585501
global dd_trace_context
586502

@@ -592,16 +508,17 @@ def get_dd_trace_context():
592508
% e
593509
)
594510
if not xray_context:
595-
return {}
596-
597-
if not dd_trace_context:
598-
return _context_obj_to_headers(xray_context)
511+
return None
599512

600-
context = dd_trace_context.copy()
601-
context["parent-id"] = xray_context.get("parent-id")
602-
logger.debug("Set parent id from xray trace context: %s", context.get("parent-id"))
513+
if not _is_context_complete(dd_trace_context):
514+
return xray_context
603515

604-
return _context_obj_to_headers(context)
516+
logger.debug("Set parent id from xray trace context: %s", xray_context.span_id)
517+
return Context(
518+
trace_id=dd_trace_context.trace_id,
519+
span_id=xray_context.span_id,
520+
sampling_priority=dd_trace_context.sampling_priority,
521+
)
605522

606523

607524
def set_correlation_ids():
@@ -620,13 +537,11 @@ def set_correlation_ids():
620537
return
621538

622539
context = get_dd_trace_context()
623-
if not context:
540+
if not _is_context_complete(context):
624541
return
625542

626-
span = tracer.trace("dummy.span")
627-
span.trace_id = int(context[TraceHeader.TRACE_ID])
628-
span.span_id = int(context[TraceHeader.PARENT_ID])
629-
543+
tracer.context_provider.activate(context)
544+
tracer.trace("dummy.span")
630545
logger.debug("correlation ids set")
631546

632547

@@ -669,18 +584,20 @@ def is_lambda_context():
669584

670585
def set_dd_trace_py_root(trace_context_source, merge_xray_traces):
    """Activate the module-level Datadog trace context as the ddtrace root.

    Nothing happens unless the context came from the event
    (``TraceContextSource.EVENT``) or X-Ray merging is requested.

    Args:
        trace_context_source: Where the current trace context came from;
            compared against ``TraceContextSource.EVENT``.
        merge_xray_traces: When truthy, the activated context uses the
            X-Ray context's span id as its span id, if one is available.
    """
    if not (trace_context_source == TraceContextSource.EVENT or merge_xray_traces):
        return

    # dd_trace_context defaults to None at module level; without an
    # extracted context there is no trace id to root the trace on.
    if dd_trace_context is None:
        return

    # Build a fresh Context so the module-level one is never mutated.
    context = Context(
        trace_id=dd_trace_context.trace_id,
        span_id=dd_trace_context.span_id,
        sampling_priority=dd_trace_context.sampling_priority,
    )
    if merge_xray_traces:
        xray_context = _get_xray_trace_context()
        # _get_xray_trace_context() returns None when no X-Ray trace entity
        # is available, so guard before reading span_id.
        if xray_context is not None and xray_context.span_id:
            context.span_id = xray_context.span_id

    tracer.context_provider.activate(context)
    logger.debug(
        "Set dd trace root context to: %s",
        (context.trace_id, context.span_id),
    )
685602

686603

0 commit comments

Comments
 (0)