diff --git a/datadog_lambda/dsm.py b/datadog_lambda/dsm.py
new file mode 100644
index 00000000..427f5e47
--- /dev/null
+++ b/datadog_lambda/dsm.py
@@ -0,0 +1,51 @@
+from datadog_lambda import logger
+from datadog_lambda.trigger import EventTypes
+
+
+def set_dsm_context(event, event_source):
+    """Set a DSM (data streams) checkpoint for supported event sources.
+
+    Only SQS events are handled; any other event source is a no-op.
+    """
+    if event_source.equals(EventTypes.SQS):
+        _dsm_set_sqs_context(event)
+
+
+def _dsm_set_sqs_context(event):
+    """Set an inbound DSM checkpoint for every record of an SQS event.
+
+    Returns without touching the DSM processor when the event has no
+    "Records" entry. Each record is handled in its own try/except so an
+    error on one record is logged instead of aborting the whole batch.
+    """
+    # datadog_lambda.wrapper imports this module at load time, so it is
+    # imported here at call time to avoid a circular import.
+    from datadog_lambda.wrapper import format_err_with_traceback
+    from ddtrace.internal.datastreams import data_streams_processor
+    from ddtrace.internal.datastreams.processor import DsmPathwayCodec
+    from ddtrace.internal.datastreams.botocore import (
+        get_datastreams_context,
+        calculate_sqs_payload_size,
+    )
+
+    records = event.get("Records")
+    if records is None:
+        return
+    processor = data_streams_processor()
+
+    for record in records:
+        try:
+            queue_arn = record.get("eventSourceARN", "")
+
+            contextjson = get_datastreams_context(record)
+            payload_size = calculate_sqs_payload_size(record)
+
+            ctx = DsmPathwayCodec.decode(contextjson, processor)
+            ctx.set_checkpoint(
+                ["direction:in", f"topic:{queue_arn}", "type:sqs"],
+                payload_size=payload_size,
+            )
+        except Exception as e:
+            # NOTE(review): confirm `datadog_lambda.logger` exposes `error`;
+            # if it is only a logging-setup module, this handler itself raises.
+            logger.error(format_err_with_traceback(e))
diff --git a/datadog_lambda/wrapper.py b/datadog_lambda/wrapper.py
index 86bbf04d..0e23b721 100644
--- a/datadog_lambda/wrapper.py
+++ b/datadog_lambda/wrapper.py
@@ -9,6 +9,7 @@
 from importlib import import_module
 from time import time_ns
 
+from datadog_lambda.dsm import set_dsm_context
 from datadog_lambda.extension import should_use_extension, flush_extension
 from datadog_lambda.cold_start import (
     set_cold_start,
@@ -79,6 +80,7 @@
 DD_REQUESTS_SERVICE_NAME = "DD_REQUESTS_SERVICE_NAME"
 DD_SERVICE = "DD_SERVICE"
 DD_ENV = "DD_ENV"
+DD_DATA_STREAMS_ENABLED = "DD_DATA_STREAMS_ENABLED"
 
 
 def get_env_as_int(env_key, default_value: int) -> int:
@@ -190,6 +192,9 @@ def __init__(self, func):
             self.min_cold_start_trace_duration = get_env_as_int(
                 DD_MIN_COLD_START_DURATION, 3
             )
+            self.data_streams_enabled = (
+                os.environ.get(DD_DATA_STREAMS_ENABLED, "false").lower() == "true"
+            )
             self.local_testing_mode = os.environ.get(
                 DD_LOCAL_TEST, "false"
             ).lower() in ("true", "1")
@@ -322,6 +327,8 @@ def _before(self, event, context):
                     self.inferred_span = create_inferred_span(
                         event, context, event_source, self.decode_authorizer_context
                     )
+                if self.data_streams_enabled:
+                    set_dsm_context(event, event_source)
                 self.span = create_function_execution_span(
                     context=context,
                     function_name=self.function_name,
diff --git a/tests/test_dsm.py b/tests/test_dsm.py
new file mode 100644
index 00000000..544212d8
--- /dev/null
+++ b/tests/test_dsm.py
@@ -0,0 +1,112 @@
+import unittest
+from unittest.mock import patch, MagicMock
+
+from datadog_lambda.dsm import set_dsm_context, _dsm_set_sqs_context
+from datadog_lambda.trigger import EventTypes, _EventSource
+
+
+class TestDsmSQSContext(unittest.TestCase):
+    def setUp(self):
+        patcher = patch("datadog_lambda.dsm._dsm_set_sqs_context")
+        self.mock_dsm_set_sqs_context = patcher.start()
+        self.addCleanup(patcher.stop)
+
+        patcher = patch("ddtrace.internal.datastreams.data_streams_processor")
+        self.mock_data_streams_processor = patcher.start()
+        self.addCleanup(patcher.stop)
+
+        patcher = patch("ddtrace.internal.datastreams.botocore.get_datastreams_context")
+        self.mock_get_datastreams_context = patcher.start()
+        self.mock_get_datastreams_context.return_value = {}
+        self.addCleanup(patcher.stop)
+
+        patcher = patch(
+            "ddtrace.internal.datastreams.botocore.calculate_sqs_payload_size"
+        )
+        self.mock_calculate_sqs_payload_size = patcher.start()
+        self.mock_calculate_sqs_payload_size.return_value = 100
+        self.addCleanup(patcher.stop)
+
+        patcher = patch("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode")
+        self.mock_dsm_pathway_codec_decode = patcher.start()
+        self.addCleanup(patcher.stop)
+
+    def test_non_sqs_event_source_does_nothing(self):
+        """Test that non-SQS event sources don't trigger DSM context setting"""
+        event = {}
+        # Use Unknown Event Source
+        event_source = _EventSource(EventTypes.UNKNOWN)
+        set_dsm_context(event, event_source)
+
+        # DSM context should not be set for non-SQS events
+        self.mock_dsm_set_sqs_context.assert_not_called()
+
+    def test_sqs_event_with_no_records_does_nothing(self):
+        """Test that events where Records is None don't trigger DSM processing"""
+        events_with_no_records = [
+            {},
+            {"Records": None},
+            {"someOtherField": "value"},
+        ]
+
+        for event in events_with_no_records:
+            _dsm_set_sqs_context(event)
+            self.mock_data_streams_processor.assert_not_called()
+
+    def test_sqs_event_triggers_dsm_sqs_context(self):
+        """Test that SQS event sources trigger the SQS-specific DSM context function"""
+        sqs_event = {
+            "Records": [
+                {
+                    "eventSource": "aws:sqs",
+                    "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:my-queue",
+                    "body": "Hello from SQS!",
+                }
+            ]
+        }
+
+        event_source = _EventSource(EventTypes.SQS)
+        set_dsm_context(sqs_event, event_source)
+
+        self.mock_dsm_set_sqs_context.assert_called_once_with(sqs_event)
+
+    def test_sqs_multiple_records_process_each_record(self):
+        """Test that each record in an SQS event gets processed individually"""
+        multi_record_event = {
+            "Records": [
+                {
+                    "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue1",
+                    "body": "Message 1",
+                },
+                {
+                    "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue2",
+                    "body": "Message 2",
+                },
+                {
+                    "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:queue3",
+                    "body": "Message 3",
+                },
+            ]
+        }
+
+        mock_context = MagicMock()
+        self.mock_dsm_pathway_codec_decode.return_value = mock_context
+
+        _dsm_set_sqs_context(multi_record_event)
+
+        self.assertEqual(mock_context.set_checkpoint.call_count, 3)
+
+        calls = mock_context.set_checkpoint.call_args_list
+        expected_arns = [
+            "arn:aws:sqs:us-east-1:123456789012:queue1",
+            "arn:aws:sqs:us-east-1:123456789012:queue2",
+            "arn:aws:sqs:us-east-1:123456789012:queue3",
+        ]
+
+        for i, call in enumerate(calls):
+            args, kwargs = call
+            tags = args[0]
+            self.assertIn("direction:in", tags)
+            self.assertIn(f"topic:{expected_arns[i]}", tags)
+            self.assertIn("type:sqs", tags)
+            self.assertEqual(kwargs["payload_size"], 100)
diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py
index f46b365e..f482fa3d 100644
--- a/tests/test_wrapper.py
+++ b/tests/test_wrapper.py
@@ -76,6 +76,10 @@ def setUp(self):
         self.mock_dd_lambda_layer_tag = patcher.start()
         self.addCleanup(patcher.stop)
 
+        patcher = patch("datadog_lambda.wrapper.set_dsm_context")
+        self.mock_set_dsm_context = patcher.start()
+        self.addCleanup(patcher.stop)
+
     def test_datadog_lambda_wrapper(self):
         wrapper.dd_tracing_enabled = False
 
@@ -563,6 +567,54 @@ def return_type_test(event, context):
             self.assertEqual(result, test_result)
             self.assertFalse(MockPrintExc.called)
 
+    @patch.dict(os.environ, {"DD_DATA_STREAMS_ENABLED": "true"})
+    def test_set_dsm_context_called_when_DSM_and_tracing_enabled(self):
+        wrapper.dd_tracing_enabled = True
+
+        @wrapper.datadog_lambda_wrapper
+        def lambda_handler(event, context):
+            return "ok"
+
+        result = lambda_handler({}, get_mock_context())
+        self.assertEqual(result, "ok")
+        self.mock_set_dsm_context.assert_called_once()
+
+    @patch.dict(os.environ, {"DD_DATA_STREAMS_ENABLED": "true"})
+    def test_set_dsm_context_not_called_when_only_DSM_enabled(self):
+        wrapper.dd_tracing_enabled = False
+
+        @wrapper.datadog_lambda_wrapper
+        def lambda_handler(event, context):
+            return "ok"
+
+        result = lambda_handler({}, get_mock_context())
+        self.assertEqual(result, "ok")
+        self.mock_set_dsm_context.assert_not_called()
+
+    @patch.dict(os.environ, {"DD_DATA_STREAMS_ENABLED": "false"})
+    def test_set_dsm_context_not_called_when_only_tracing_enabled(self):
+        wrapper.dd_tracing_enabled = True
+
+        @wrapper.datadog_lambda_wrapper
+        def lambda_handler(event, context):
+            return "ok"
+
+        result = lambda_handler({}, get_mock_context())
+        self.assertEqual(result, "ok")
+        self.mock_set_dsm_context.assert_not_called()
+
+    @patch.dict(os.environ, {"DD_DATA_STREAMS_ENABLED": "false"})
+    def test_set_dsm_context_not_called_when_tracing_and_DSM_disabled(self):
+        wrapper.dd_tracing_enabled = False
+
+        @wrapper.datadog_lambda_wrapper
+        def lambda_handler(event, context):
+            return "ok"
+
+        result = lambda_handler({}, get_mock_context())
+        self.assertEqual(result, "ok")
+        self.mock_set_dsm_context.assert_not_called()
+
 
 class TestLambdaDecoratorSettings(unittest.TestCase):
     def test_some_envs_should_depend_on_dd_tracing_enabled(self):