diff --git a/sentry_sdk/integrations/opentelemetry/consts.py b/sentry_sdk/integrations/opentelemetry/consts.py index 790dac15ec..cb088f13a5 100644 --- a/sentry_sdk/integrations/opentelemetry/consts.py +++ b/sentry_sdk/integrations/opentelemetry/consts.py @@ -11,6 +11,8 @@ SENTRY_USE_CURRENT_SCOPE_KEY = create_key("sentry_use_current_scope") SENTRY_USE_ISOLATION_SCOPE_KEY = create_key("sentry_use_isolation_scope") +SENTRY_TRACE_STATE_DROPPED = "sentry_dropped" + OTEL_SENTRY_CONTEXT = "otel" SPAN_ORIGIN = "auto.otel" diff --git a/sentry_sdk/integrations/opentelemetry/integration.py b/sentry_sdk/integrations/opentelemetry/integration.py index 4cd969f0e0..3f71e86f02 100644 --- a/sentry_sdk/integrations/opentelemetry/integration.py +++ b/sentry_sdk/integrations/opentelemetry/integration.py @@ -12,6 +12,7 @@ from sentry_sdk.integrations.opentelemetry.contextvars_context import ( SentryContextVarsRuntimeContext, ) +from sentry_sdk.integrations.opentelemetry.sampler import SentrySampler from sentry_sdk.utils import logger try: @@ -55,7 +56,7 @@ def _setup_sentry_tracing(): opentelemetry.context._RUNTIME_CONTEXT = SentryContextVarsRuntimeContext() - provider = TracerProvider() + provider = TracerProvider(sampler=SentrySampler()) provider.add_span_processor(PotelSentrySpanProcessor()) trace.set_tracer_provider(provider) diff --git a/sentry_sdk/integrations/opentelemetry/sampler.py b/sentry_sdk/integrations/opentelemetry/sampler.py new file mode 100644 index 0000000000..445c2edd02 --- /dev/null +++ b/sentry_sdk/integrations/opentelemetry/sampler.py @@ -0,0 +1,122 @@ +from random import random + +from opentelemetry import trace + +from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision +from opentelemetry.trace.span import TraceState + +import sentry_sdk +from sentry_sdk.integrations.opentelemetry.consts import SENTRY_TRACE_STATE_DROPPED +from sentry_sdk.tracing_utils import has_tracing_enabled +from sentry_sdk.utils import is_valid_sample_rate, logger + 
from typing import TYPE_CHECKING, Optional, Sequence

if TYPE_CHECKING:
    from opentelemetry.context import Context
    from opentelemetry.trace import Link, SpanKind
    from opentelemetry.trace.span import SpanContext
    from opentelemetry.util.types import Attributes


def get_parent_sampled(parent_context, trace_id):
    # type: (Optional[SpanContext], int) -> Optional[bool]
    """
    Return the sampling decision inherited from the parent span, if any.

    :param parent_context: Span context of the parent span, or ``None``.
    :param trace_id: Trace id of the span that is about to be created.
    :returns: ``True`` if the parent is sampled, ``False`` if the parent's
        trace state carries the Sentry "dropped" marker, and ``None`` when
        there is no usable parent or the parent made no decision.
    """
    if parent_context is None:
        return None

    # Only inherit the decision if the parent is valid and `trace_id` is the same
    if not parent_context.is_valid or parent_context.trace_id != trace_id:
        return None

    # this is getSamplingDecision in JS
    if parent_context.trace_flags.sampled:
        return True

    if parent_context.trace_state.get(SENTRY_TRACE_STATE_DROPPED) == "true":
        return False

    # TODO-anton: fall back to sampling decision in DSC (for this the DSC needs to be set in the trace_state)

    return None


def dropped(parent_context=None):
    # type: (Optional[SpanContext]) -> SamplingResult
    """
    Build a ``DROP`` sampling result.

    The parent's trace state (or a fresh one when there is no parent or the
    parent's trace state is empty) is extended with the Sentry "dropped"
    marker so that child spans can see the decision in `get_parent_sampled`.
    """
    trace_state = parent_context.trace_state if parent_context is not None else None
    updated_trace_context = trace_state or TraceState()
    updated_trace_context = updated_trace_context.update(
        SENTRY_TRACE_STATE_DROPPED, "true"
    )
    return SamplingResult(
        Decision.DROP,
        trace_state=updated_trace_context,
    )


def sampled(parent_context=None, attributes=None):
    # type: (Optional[SpanContext], Attributes) -> SamplingResult
    """
    Build a ``RECORD_AND_SAMPLE`` sampling result.

    The OpenTelemetry SDK takes the new span's initial attributes and trace
    state from the sampling result, so both have to be passed through here;
    otherwise attributes given at span creation and the incoming trace state
    are lost.

    :param parent_context: Span context of the parent span, or ``None``.
    :param attributes: Attributes the span was requested with.
    """
    trace_state = parent_context.trace_state if parent_context is not None else None

    # A sampled span must not advertise the "dropped" marker of its parent.
    if (
        trace_state is not None
        and trace_state.get(SENTRY_TRACE_STATE_DROPPED) is not None
    ):
        trace_state = trace_state.delete(SENTRY_TRACE_STATE_DROPPED)

    return SamplingResult(
        Decision.RECORD_AND_SAMPLE,
        attributes=attributes,
        trace_state=trace_state,
    )


class SentrySampler(Sampler):
    """
    OpenTelemetry sampler driven by the Sentry client options.

    The sample rate is taken from, in this order: the ``traces_sampler``
    callback, the decision inherited from the parent span, and the
    ``traces_sample_rate`` option.
    """

    def should_sample(
        self,
        parent_context,  # type: Optional[Context]
        trace_id,  # type: int
        name,  # type: str
        kind=None,  # type: Optional[SpanKind]
        attributes=None,  # type: Attributes
        links=None,  # type: Optional[Sequence[Link]]
        trace_state=None,  # type: Optional[TraceState]
    ):
        # type: (...) -> SamplingResult
        """
        Decide whether the span called ``name`` is recorded or dropped.

        Spans are dropped when tracing is disabled or when the resolved
        sample rate is invalid; otherwise the span is kept when a random
        draw falls below the sample rate.
        """
        client = sentry_sdk.get_client()

        parent_span = trace.get_current_span(parent_context)
        parent_span_context = parent_span.get_span_context() if parent_span else None

        # No tracing enabled, thus no sampling
        if not has_tracing_enabled(client.options):
            return dropped(parent_span_context)

        # Check if sampled=True was passed to start_transaction
        # TODO-anton: Do we want to keep the start_transaction(sampled=True) thing?

        # Needed by both the traces_sampler and the inheritance branch below.
        parent_sampled = get_parent_sampled(parent_span_context, trace_id)

        traces_sampler = client.options.get("traces_sampler")
        if callable(traces_sampler):
            # Traces_sampler is responsible to check parent sampled to have full transactions.
            # TODO-anton: Make proper sampling_context
            sampling_context = {
                "transaction_context": {
                    "name": name,
                },
                "parent_sampled": parent_sampled,
            }
            sample_rate = traces_sampler(sampling_context)
        elif parent_sampled is not None:
            # Inherit the decision of the parent
            sample_rate = parent_sampled
        else:
            sample_rate = client.options.get("traces_sample_rate")

        # If the sample rate is invalid, drop the span
        if not is_valid_sample_rate(sample_rate, source=self.__class__.__name__):
            logger.warning(
                f"[Tracing] Discarding {name} because of invalid sample rate."
            )
            return dropped(parent_span_context)

        # Roll the dice on sample rate
        if random() < float(sample_rate):
            return sampled(parent_span_context, attributes)

        return dropped(parent_span_context)

    def get_description(self):
        # type: () -> str
        """Return the name of this sampler class."""
        return self.__class__.__name__
import pytest
from unittest import mock

from opentelemetry import trace

import sentry_sdk


tracer = trace.get_tracer(__name__)


@pytest.mark.parametrize(
    "traces_sample_rate, expected_num_of_envelopes",
    [
        # special case for testing, do not pass any traces_sample_rate to init() (the default traces_sample_rate=None will be used)
        (-1, 0),
        # traces_sample_rate=None means do not create new traces, and also do not continue incoming traces. So, no envelopes at all.
        (None, 0),
        # traces_sample_rate=0 means do not create new traces (0% of the requests), but continue incoming traces. So envelopes will be created only if there is an incoming trace.
        (0, 0),
        # traces_sample_rate=1 means create new traces for 100% of requests (and also continue incoming traces, of course).
        (1, 2),
    ],
)
def test_sampling_traces_sample_rate_0_or_100(
    sentry_init,
    capture_envelopes,
    traces_sample_rate,
    expected_num_of_envelopes,
):
    """Without an incoming trace, only traces_sample_rate=1 produces envelopes."""
    kwargs = {}
    if traces_sample_rate != -1:
        kwargs["traces_sample_rate"] = traces_sample_rate

    sentry_init(**kwargs)

    envelopes = capture_envelopes()

    with sentry_sdk.start_span(description="request a"):
        with sentry_sdk.start_span(description="cache a"):
            with sentry_sdk.start_span(description="db a"):
                ...

    with sentry_sdk.start_span(description="request b"):
        with sentry_sdk.start_span(description="cache b"):
            with sentry_sdk.start_span(description="db b"):
                ...

    assert len(envelopes) == expected_num_of_envelopes

    if expected_num_of_envelopes == 2:
        (transaction_a, transaction_b) = [
            envelope.items[0].payload.json for envelope in envelopes
        ]

        assert transaction_a["transaction"] == "request a"
        assert transaction_b["transaction"] == "request b"

        spans_a = transaction_a["spans"]
        assert len(spans_a) == 2
        assert spans_a[0]["description"] == "cache a"
        assert spans_a[1]["description"] == "db a"
        spans_b = transaction_b["spans"]
        assert len(spans_b) == 2
        assert spans_b[0]["description"] == "cache b"
        assert spans_b[1]["description"] == "db b"


def test_sampling_traces_sample_rate_50(sentry_init, capture_envelopes):
    """With a 50% rate a span is kept only when the random draw is below 0.5."""
    sentry_init(traces_sample_rate=0.5)

    envelopes = capture_envelopes()

    with mock.patch(
        "sentry_sdk.integrations.opentelemetry.sampler.random", return_value=0.2
    ):  # keep (0.2 < 0.5)
        with sentry_sdk.start_span(description="request a"):
            with sentry_sdk.start_span(description="cache a"):
                with sentry_sdk.start_span(description="db a"):
                    ...

    with mock.patch(
        "sentry_sdk.integrations.opentelemetry.sampler.random", return_value=0.7
    ):  # drop (0.7 >= 0.5)
        with sentry_sdk.start_span(description="request b"):
            with sentry_sdk.start_span(description="cache b"):
                with sentry_sdk.start_span(description="db b"):
                    ...

    assert len(envelopes) == 1

    (envelope,) = envelopes
    transaction = envelope.items[0].payload.json
    assert transaction["transaction"] == "request a"
    spans = transaction["spans"]
    assert len(spans) == 2
    assert spans[0]["description"] == "cache a"
    assert spans[1]["description"] == "db a"


def test_sampling_traces_sampler(sentry_init, capture_envelopes):
    """A traces_sampler returning a rate decides per span, based on its name."""

    def keep_only_a(sampling_context):
        if " a" in sampling_context["transaction_context"]["name"]:
            return 0.05
        else:
            return 0

    sentry_init(
        traces_sample_rate=1.0,
        traces_sampler=keep_only_a,
    )

    envelopes = capture_envelopes()

    # Make sure random() always returns the same value, however many spans
    # are started below (0.04 < 0.05 keeps, 0.04 >= 0 drops).
    with mock.patch(
        "sentry_sdk.integrations.opentelemetry.sampler.random",
        return_value=0.04,
    ):

        with sentry_sdk.start_span(description="request a"):  # keep
            with sentry_sdk.start_span(description="cache a"):  # keep
                with sentry_sdk.start_span(description="db a"):  # keep
                    ...

        with sentry_sdk.start_span(description="request b"):  # drop
            with sentry_sdk.start_span(description="cache b"):  # drop
                with sentry_sdk.start_span(description="db b"):  # drop
                    ...

        with sentry_sdk.start_span(description="request c"):  # drop
            with sentry_sdk.start_span(
                description="cache a c"
            ):  # keep (but trx dropped, so not collected)
                with sentry_sdk.start_span(
                    description="db a c"
                ):  # keep (but trx dropped, so not collected)
                    ...

        with sentry_sdk.start_span(description="new a c"):  # keep
            with sentry_sdk.start_span(description="cache c"):  # drop
                with sentry_sdk.start_span(description="db c"):  # drop
                    ...

    assert len(envelopes) == 2
    (envelope1, envelope2) = envelopes
    transaction1 = envelope1.items[0].payload.json
    transaction2 = envelope2.items[0].payload.json

    assert transaction1["transaction"] == "request a"
    assert len(transaction1["spans"]) == 2
    assert transaction2["transaction"] == "new a c"
    assert len(transaction2["spans"]) == 0


def test_sampling_traces_sampler_boolean(sentry_init, capture_envelopes):
    """A traces_sampler may return plain booleans instead of a rate."""

    def keep_only_a(sampling_context):
        if " a" in sampling_context["transaction_context"]["name"]:
            return True
        else:
            return False

    sentry_init(
        traces_sample_rate=1.0,
        traces_sampler=keep_only_a,
    )

    envelopes = capture_envelopes()

    with sentry_sdk.start_span(description="request a"):  # keep
        with sentry_sdk.start_span(description="cache a"):  # keep
            with sentry_sdk.start_span(description="db X"):  # drop
                ...

    with sentry_sdk.start_span(description="request b"):  # drop
        with sentry_sdk.start_span(description="cache b"):  # drop
            with sentry_sdk.start_span(description="db b"):  # drop
                ...

    assert len(envelopes) == 1
    (envelope,) = envelopes
    transaction = envelope.items[0].payload.json

    assert transaction["transaction"] == "request a"
    assert len(transaction["spans"]) == 1
def _check_continued_trace(
    sentry_init,
    capture_envelopes,
    traces_sample_rate,
    expected_num_of_envelopes,
    sentry_trace,
):
    """
    Continue the trace described by the ``sentry_trace`` header value, start
    three nested spans and verify the captured envelopes.

    A ``traces_sample_rate`` of -1 means "do not pass the option to init()".
    """
    init_kwargs = (
        {} if traces_sample_rate == -1 else {"traces_sample_rate": traces_sample_rate}
    )
    sentry_init(**init_kwargs)

    envelopes = capture_envelopes()

    with sentry_sdk.continue_trace({"sentry-trace": sentry_trace}):
        with sentry_sdk.start_span(
            description="request a"
        ), sentry_sdk.start_span(
            description="cache a"
        ), sentry_sdk.start_span(
            description="db X"
        ):
            pass

    assert len(envelopes) == expected_num_of_envelopes

    if expected_num_of_envelopes != 1:
        return

    # The single transaction has to belong to the upstream trace and hang
    # below the upstream span.
    transaction = envelopes[0].items[0].payload.json
    trace_context = transaction["contexts"]["trace"]
    assert transaction["transaction"] == "request a"
    assert trace_context["trace_id"] == "771a43a4192642f0b136d5159a501700"
    assert trace_context["span_id"] != "1234567890abcdef"
    assert trace_context["parent_span_id"] == "1234567890abcdef"


@pytest.mark.parametrize(
    "traces_sample_rate, expected_num_of_envelopes",
    [
        # special case for testing, do not pass any traces_sample_rate to init() (the default traces_sample_rate=None will be used)
        (-1, 0),
        # traces_sample_rate=None means do not create new traces, and also do not continue incoming traces. So, no envelopes at all.
        (None, 0),
        # traces_sample_rate=0 means do not create new traces (0% of the requests), but continue incoming traces. So envelopes will be created only if there is an incoming trace.
        (0, 1),
        # traces_sample_rate=1 means create new traces for 100% of requests (and also continue incoming traces, of course).
        (1, 1),
    ],
)
def test_sampling_parent_sampled(
    sentry_init,
    traces_sample_rate,
    expected_num_of_envelopes,
    capture_envelopes,
):
    """The upstream service has sampled the request (header flag ``-1``)."""
    _check_continued_trace(
        sentry_init,
        capture_envelopes,
        traces_sample_rate,
        expected_num_of_envelopes,
        "771a43a4192642f0b136d5159a501700-1234567890abcdef-1",
    )


@pytest.mark.parametrize(
    "traces_sample_rate, expected_num_of_envelopes",
    [
        # special case for testing, do not pass any traces_sample_rate to init() (the default traces_sample_rate=None will be used)
        (-1, 0),
        # traces_sample_rate=None means do not create new traces, and also do not continue incoming traces. So, no envelopes at all.
        (None, 0),
        # traces_sample_rate=0 means do not create new traces (0% of the requests), but continue incoming traces. So envelopes will be created only if there is an incoming trace.
        (0, 0),
        # traces_sample_rate=1 means create new traces for 100% of requests (and also continue incoming traces, of course).
        (1, 1),
    ],
)
def test_sampling_parent_dropped(
    sentry_init,
    traces_sample_rate,
    expected_num_of_envelopes,
    capture_envelopes,
):
    """The upstream service has dropped the request (header flag ``-0``)."""
    _check_continued_trace(
        sentry_init,
        capture_envelopes,
        traces_sample_rate,
        expected_num_of_envelopes,
        "771a43a4192642f0b136d5159a501700-1234567890abcdef-0",
    )


@pytest.mark.parametrize(
    "traces_sample_rate, expected_num_of_envelopes",
    [
        # special case for testing, do not pass any traces_sample_rate to init() (the default traces_sample_rate=None will be used)
        (-1, 0),
        # traces_sample_rate=None means do not create new traces, and also do not continue incoming traces. So, no envelopes at all.
        (None, 0),
        # traces_sample_rate=0 means do not create new traces (0% of the requests), but continue incoming traces. So envelopes will be created only if there is an incoming trace.
        (0, 0),
        # traces_sample_rate=1 means create new traces for 100% of requests (and also continue incoming traces, of course).
        (1, 1),
    ],
)
def test_sampling_parent_deferred(
    sentry_init,
    traces_sample_rate,
    expected_num_of_envelopes,
    capture_envelopes,
):
    """The upstream service has deferred the sampling decision to us (no flag)."""
    _check_continued_trace(
        sentry_init,
        capture_envelopes,
        traces_sample_rate,
        expected_num_of_envelopes,
        "771a43a4192642f0b136d5159a501700-1234567890abcdef-",
    )