6
6
import logging
7
7
import os
8
8
import json
9
+ import base64
9
10
from datetime import datetime , timezone
10
11
from typing import Optional , Dict
11
12
@@ -184,23 +185,98 @@ def extract_context_from_http_event_or_context(event, lambda_context):
184
185
return trace_id , parent_id , sampling_priority
185
186
186
187
187
- def extract_context_from_sqs_event_or_context (event , lambda_context ):
188
def create_sns_event(message):
    """Wrap *message* in the envelope shape of a native SNS Lambda event.

    Used to re-process an SNS notification that arrived embedded inside an
    SQS record body as if it had been delivered directly by SNS.
    """
    record = {
        "EventSource": "aws:sns",
        "EventVersion": "1.0",
        "Sns": message,
    }
    return {"Records": [record]}
198
+
199
+
200
def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
    """Extract Datadog trace context from the first record of an SQS or SNS
    event.

    Falls back to the lambda context when the record carries no trace data
    or when extraction fails for any reason.
    """
    try:
        record = event["Records"][0]

        # An SNS notification delivered through an SQS subscription arrives
        # serialized inside the SQS record body; unwrap it when detected.
        if "body" in record:
            raw_body = record.get("body", {})
            try:
                parsed = json.loads(raw_body)
                looks_like_sns = (
                    parsed.get("Type", "") == "Notification" and "TopicArn" in parsed
                )
                if looks_like_sns:
                    logger.debug("Found SNS message inside SQS event")
                    record = get_first_record(create_sns_event(parsed))
            except Exception:
                # Body was not JSON (or not SNS-shaped); keep the raw record.
                record = event["Records"][0]

        # SQS exposes "messageAttributes"; an unwrapped SNS record exposes
        # "Sns" -> "MessageAttributes" instead.
        attributes = record.get(
            "messageAttributes",
            record.get("Sns", {}).get("MessageAttributes", {}),
        )
        carrier = attributes.get("_datadog", {})
        # SQS stores the payload under "stringValue", SNS under "Value".
        payload = carrier.get("stringValue", carrier.get("Value", r"{}"))
        context = json.loads(payload)

        return (
            context.get(TraceHeader.TRACE_ID),
            context.get(TraceHeader.PARENT_ID),
            context.get(TraceHeader.SAMPLING_PRIORITY),
        )
    except Exception as e:
        logger.debug("The trace extractor returned with error %s", e)
        return extract_context_from_lambda_context(lambda_context)
236
+
237
+
238
def extract_context_from_eventbridge_event(event, lambda_context):
    """Extract datadog trace context from an EventBridge message's ``detail``
    field, falling back to the lambda context when none is present.

    Details is often a weirdly escaped almost-JSON string. Here we have to
    correct for that.
    """
    try:
        carrier = event["detail"].get("_datadog")
        if not carrier:
            return extract_context_from_lambda_context(lambda_context)
        return (
            carrier.get(TraceHeader.TRACE_ID),
            carrier.get(TraceHeader.PARENT_ID),
            carrier.get(TraceHeader.SAMPLING_PRIORITY),
        )
    except Exception as e:
        logger.debug("The trace extractor returned with error %s", e)
        return extract_context_from_lambda_context(lambda_context)
255
+
256
+
257
def extract_context_from_kinesis_event(event, lambda_context):
    """
    Extract datadog trace context from a Kinesis Stream's base64 encoded data string.

    Falls back to the lambda context when the record has no data payload,
    the payload carries no trace context, or extraction fails.
    """
    try:
        record = get_first_record(event)
        data = record.get("kinesis", {}).get("data", None)
        if data:
            # Kinesis delivers the payload base64-encoded; decode it and
            # parse the embedded JSON to find the Datadog carrier.
            b64_bytes = data.encode("ascii")
            str_bytes = base64.b64decode(b64_bytes)
            data_str = str_bytes.decode("ascii")
            data_obj = json.loads(data_str)
            dd_ctx = data_obj.get("_datadog")

            if dd_ctx:
                trace_id = dd_ctx.get(TraceHeader.TRACE_ID)
                parent_id = dd_ctx.get(TraceHeader.PARENT_ID)
                sampling_priority = dd_ctx.get(TraceHeader.SAMPLING_PRIORITY)
                return trace_id, parent_id, sampling_priority
    except Exception as e:
        logger.debug("The trace extractor returned with error %s", e)
    # Bug fix: previously a record whose "data" field was absent or empty
    # fell off the end of the try block and returned None; every
    # unsuccessful path now falls back to the lambda context.
    return extract_context_from_lambda_context(lambda_context)
205
281
206
282
@@ -230,6 +306,7 @@ def extract_dd_trace_context(event, lambda_context, extractor=None):
230
306
"""
231
307
global dd_trace_context
232
308
trace_context_source = None
309
+ event_source = parse_event_source (event )
233
310
234
311
if extractor is not None :
235
312
(
@@ -243,12 +320,24 @@ def extract_dd_trace_context(event, lambda_context, extractor=None):
243
320
parent_id ,
244
321
sampling_priority ,
245
322
) = extract_context_from_http_event_or_context (event , lambda_context )
246
- elif "Records" in event :
323
+ elif event_source . equals ( EventTypes . SNS ) or event_source . equals ( EventTypes . SQS ) :
247
324
(
248
325
trace_id ,
249
326
parent_id ,
250
327
sampling_priority ,
251
- ) = extract_context_from_sqs_event_or_context (event , lambda_context )
328
+ ) = extract_context_from_sqs_or_sns_event_or_context (event , lambda_context )
329
+ elif event_source .equals (EventTypes .EVENTBRIDGE ):
330
+ (
331
+ trace_id ,
332
+ parent_id ,
333
+ sampling_priority ,
334
+ ) = extract_context_from_eventbridge_event (event , lambda_context )
335
+ elif event_source .equals (EventTypes .KINESIS ):
336
+ (
337
+ trace_id ,
338
+ parent_id ,
339
+ sampling_priority ,
340
+ ) = extract_context_from_kinesis_event (event , lambda_context )
252
341
else :
253
342
trace_id , parent_id , sampling_priority = extract_context_from_lambda_context (
254
343
lambda_context
@@ -556,6 +645,8 @@ def create_inferred_span_from_http_api_event(event, context):
556
645
557
646
558
647
def create_inferred_span_from_sqs_event (event , context ):
648
+ trace_ctx = tracer .current_trace_context ()
649
+
559
650
event_record = get_first_record (event )
560
651
event_source_arn = event_record ["eventSourceARN" ]
561
652
queue_name = event_source_arn .split (":" )[- 1 ]
@@ -574,11 +665,37 @@ def create_inferred_span_from_sqs_event(event, context):
574
665
"resource" : queue_name ,
575
666
"span_type" : "web" ,
576
667
}
668
+ start_time = int (request_time_epoch ) / 1000
669
+
670
+ # logic to deal with SNS => SQS event
671
+ sns_span = None
672
+ if "body" in event_record :
673
+ body_str = event_record .get ("body" , {})
674
+ try :
675
+ body = json .loads (body_str )
676
+ if body .get ("Type" , "" ) == "Notification" and "TopicArn" in body :
677
+ logger .debug ("Found SNS message inside SQS event" )
678
+ sns_span = create_inferred_span_from_sns_event (
679
+ create_sns_event (body ), context
680
+ )
681
+ sns_span .finish (finish_time = start_time )
682
+ except Exception as e :
683
+ logger .debug (
684
+ "Unable to create SNS span from SQS message, with error %s" % e
685
+ )
686
+ pass
687
+
688
+ # trace context needs to be set again as it is reset
689
+ # when sns_span.finish executes
690
+ tracer .context_provider .activate (trace_ctx )
577
691
tracer .set_tags ({"_dd.origin" : "lambda" })
578
692
span = tracer .trace ("aws.sqs" , ** args )
579
693
if span :
580
694
span .set_tags (tags )
581
- span .start = int (request_time_epoch ) / 1000
695
+ span .start = start_time
696
+ if sns_span :
697
+ span .parent_id = sns_span .span_id
698
+
582
699
return span
583
700
584
701
@@ -594,9 +711,12 @@ def create_inferred_span_from_sns_event(event, context):
594
711
"topic_arn" : topic_arn ,
595
712
"message_id" : sns_message ["MessageId" ],
596
713
"type" : sns_message ["Type" ],
597
- "subject" : sns_message ["Subject" ],
598
- "event_subscription_arn" : event_record ["EventSubscriptionArn" ],
599
714
}
715
+
716
+ # Subject not available in SNS => SQS scenario
717
+ if "Subject" in sns_message and sns_message ["Subject" ]:
718
+ tags ["subject" ] = sns_message ["Subject" ]
719
+
600
720
InferredSpanInfo .set_tags (tags , tag_source = "self" , synchronicity = "async" )
601
721
sns_dt_format = "%Y-%m-%dT%H:%M:%S.%fZ"
602
722
timestamp = event_record ["Sns" ]["Timestamp" ]
@@ -644,7 +764,7 @@ def create_inferred_span_from_kinesis_event(event, context):
644
764
span = tracer .trace ("aws.kinesis" , ** args )
645
765
if span :
646
766
span .set_tags (tags )
647
- span .start = int ( request_time_epoch )
767
+ span .start = request_time_epoch
648
768
return span
649
769
650
770
@@ -662,7 +782,7 @@ def create_inferred_span_from_dynamodb_event(event, context):
662
782
"event_name" : event_record ["eventName" ],
663
783
"event_version" : event_record ["eventVersion" ],
664
784
"stream_view_type" : dynamodb_message ["StreamViewType" ],
665
- "size_bytes" : dynamodb_message ["SizeBytes" ],
785
+ "size_bytes" : str ( dynamodb_message ["SizeBytes" ]) ,
666
786
}
667
787
InferredSpanInfo .set_tags (tags , synchronicity = "async" , tag_source = "self" )
668
788
request_time_epoch = event_record ["dynamodb" ]["ApproximateCreationDateTime" ]
@@ -690,8 +810,8 @@ def create_inferred_span_from_s3_event(event, context):
690
810
"bucketname" : bucket_name ,
691
811
"bucket_arn" : event_record ["s3" ]["bucket" ]["arn" ],
692
812
"object_key" : event_record ["s3" ]["object" ]["key" ],
693
- "object_size" : event_record ["s3" ]["object" ]["size" ],
694
- "object_etag" : event_record ["s3" ]["etag " ],
813
+ "object_size" : str ( event_record ["s3" ]["object" ]["size" ]) ,
814
+ "object_etag" : event_record ["s3" ]["object" ][ "eTag " ],
695
815
}
696
816
InferredSpanInfo .set_tags (tags , synchronicity = "async" , tag_source = "self" )
697
817
dt_format = "%Y-%m-%dT%H:%M:%S.%fZ"
@@ -786,7 +906,7 @@ def create_function_execution_span(
786
906
787
907
788
908
class InferredSpanInfo (object ):
789
- BASE_NAME = "inferred_span "
909
+ BASE_NAME = "_inferred_span "
790
910
SYNCHRONICITY = f"{ BASE_NAME } .synchronicity"
791
911
TAG_SOURCE = f"{ BASE_NAME } .tag_source"
792
912
0 commit comments