Skip to content

Parametrized trace context extraction tests. #456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 161 additions & 129 deletions tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,167 @@ def _wrap(*args, **kwargs):
return _wrapper


_test_extract_dd_trace_context = (
("api-gateway", Context(trace_id=12345, span_id=67890, sampling_priority=2)),
(
"api-gateway-no-apiid",
Context(trace_id=12345, span_id=67890, sampling_priority=2),
),
(
"api-gateway-non-proxy",
Context(trace_id=12345, span_id=67890, sampling_priority=2),
),
(
"api-gateway-non-proxy-async",
Context(trace_id=12345, span_id=67890, sampling_priority=2),
),
(
"api-gateway-websocket-connect",
Context(trace_id=12345, span_id=67890, sampling_priority=2),
),
(
"api-gateway-websocket-default",
Context(trace_id=12345, span_id=67890, sampling_priority=2),
),
(
"api-gateway-websocket-disconnect",
Context(trace_id=12345, span_id=67890, sampling_priority=2),
),
(
"authorizer-request-api-gateway-v1",
Context(
trace_id=13478705995797221209,
span_id=8471288263384216896,
sampling_priority=1,
),
),
("authorizer-request-api-gateway-v1-cached", None),
(
"authorizer-request-api-gateway-v2",
Context(
trace_id=14356983619852933354,
span_id=12658621083505413809,
sampling_priority=1,
),
),
("authorizer-request-api-gateway-v2-cached", None),
(
"authorizer-request-api-gateway-websocket-connect",
Context(
trace_id=5351047404834723189,
span_id=18230460631156161837,
sampling_priority=1,
),
),
("authorizer-request-api-gateway-websocket-message", None),
(
"authorizer-token-api-gateway-v1",
Context(
trace_id=17874798268144902712,
span_id=16184667399315372101,
sampling_priority=1,
),
),
("authorizer-token-api-gateway-v1-cached", None),
("cloudfront", None),
("cloudwatch-events", None),
("cloudwatch-logs", None),
("custom", None),
("dynamodb", None),
("eventbridge-custom", Context(trace_id=12345, span_id=67890, sampling_priority=2)),
(
"eventbridge-sqs",
Context(
trace_id=7379586022458917877,
span_id=2644033662113726488,
sampling_priority=1,
),
),
("http-api", Context(trace_id=12345, span_id=67890, sampling_priority=2)),
(
"kinesis",
Context(
trace_id=4948377316357291421,
span_id=2876253380018681026,
sampling_priority=1,
),
),
(
"kinesis-batch",
Context(
trace_id=4948377316357291421,
span_id=2876253380018681026,
sampling_priority=1,
),
),
("lambda-url", None),
("s3", None),
(
"sns-b64-msg-attribute",
Context(
trace_id=4948377316357291421,
span_id=6746998015037429512,
sampling_priority=1,
),
),
(
"sns-batch",
Context(
trace_id=4948377316357291421,
span_id=6746998015037429512,
sampling_priority=1,
),
),
(
"sns-string-msg-attribute",
Context(
trace_id=4948377316357291421,
span_id=6746998015037429512,
sampling_priority=1,
),
),
(
"sqs-batch",
Context(
trace_id=2684756524522091840,
span_id=7431398482019833808,
sampling_priority=1,
),
),
(
"sqs-java-upstream",
Context(
trace_id=7925498337868555493,
span_id=5245570649555658903,
sampling_priority=1,
),
),
(
"sqs-string-msg-attribute",
Context(
trace_id=2684756524522091840,
span_id=7431398482019833808,
sampling_priority=1,
),
),
({"headers": None}, None),
)


@pytest.mark.parametrize("event,expect", _test_extract_dd_trace_context)
def test_extract_dd_trace_context(event, expect):
if isinstance(event, str):
with open(f"{event_samples}{event}.json") as f:
event = json.load(f)
ctx = get_mock_context()

actual, _, _ = extract_dd_trace_context(event, ctx)
assert (expect is None) is (actual is None)
assert (expect is None) or actual.trace_id == expect.trace_id
assert (expect is None) or actual.span_id == expect.span_id
assert (expect is None) or actual.sampling_priority == expect.sampling_priority


class TestExtractAndGetDDTraceContext(unittest.TestCase):
def setUp(self):
global dd_tracing_enabled
Expand Down Expand Up @@ -1773,127 +1934,6 @@ def test_create_inferred_span(mock_span_finish, source, expect):


class TestInferredSpans(unittest.TestCase):
def test_extract_context_from_eventbridge_event(self):
event_sample_source = "eventbridge-custom"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_type = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 12345)
self.assertEqual(context.span_id, 67890),
self.assertEqual(context.sampling_priority, 2)

def test_extract_dd_trace_context_for_eventbridge(self):
event_sample_source = "eventbridge-custom"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_type = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 12345)
self.assertEqual(context.span_id, 67890)

def test_extract_context_from_eventbridge_sqs_event(self):
event_sample_source = "eventbridge-sqs"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)

ctx = get_mock_context()
context, source, event_type = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 7379586022458917877)
self.assertEqual(context.span_id, 2644033662113726488)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_sqs_event_with_string_msg_attr(self):
event_sample_source = "sqs-string-msg-attribute"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_type = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 2684756524522091840)
self.assertEqual(context.span_id, 7431398482019833808)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_sqs_batch_event(self):
event_sample_source = "sqs-batch"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_source = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 2684756524522091840)
self.assertEqual(context.span_id, 7431398482019833808)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_sqs_java_upstream_event(self):
event_sample_source = "sqs-java-upstream"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_type = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 7925498337868555493)
self.assertEqual(context.span_id, 5245570649555658903)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_sns_event_with_string_msg_attr(self):
event_sample_source = "sns-string-msg-attribute"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_source = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 4948377316357291421)
self.assertEqual(context.span_id, 6746998015037429512)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_sns_event_with_b64_msg_attr(self):
event_sample_source = "sns-b64-msg-attribute"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_source = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 4948377316357291421)
self.assertEqual(context.span_id, 6746998015037429512)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_sns_batch_event(self):
event_sample_source = "sns-batch"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_source = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 4948377316357291421)
self.assertEqual(context.span_id, 6746998015037429512)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_kinesis_event(self):
event_sample_source = "kinesis"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_source = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 4948377316357291421)
self.assertEqual(context.span_id, 2876253380018681026)
self.assertEqual(context.sampling_priority, 1)

def test_extract_context_from_kinesis_batch_event(self):
event_sample_source = "kinesis-batch"
test_file = event_samples + event_sample_source + ".json"
with open(test_file, "r") as event:
event = json.load(event)
ctx = get_mock_context()
context, source, event_source = extract_dd_trace_context(event, ctx)
self.assertEqual(context.trace_id, 4948377316357291421)
self.assertEqual(context.span_id, 2876253380018681026)
self.assertEqual(context.sampling_priority, 1)

@patch("datadog_lambda.tracing.submit_errors_metric")
def test_mark_trace_as_error_for_5xx_responses_getting_400_response_code(
self, mock_submit_errors_metric
Expand All @@ -1915,14 +1955,6 @@ def test_mark_trace_as_error_for_5xx_responses_sends_error_metric_and_set_error_
mock_submit_errors_metric.assert_called_once()
self.assertEqual(1, mock_span.error)

def test_no_error_with_nonetype_headers(self):
lambda_ctx = get_mock_context()
ctx, source, event_type = extract_dd_trace_context(
{"headers": None},
lambda_ctx,
)
self.assertEqual(ctx, None)


class TestStepFunctionsTraceContext(unittest.TestCase):
def test_deterministic_m5_hash(self):
Expand Down