Skip to content

Commit 3061f57

Browse files
committed
Port traces_sample_rate update to potel-base
1 parent eb93c1f commit 3061f57

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

sentry_sdk/integrations/opentelemetry/sampler.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,49 @@ def get_parent_sampled(parent_context, trace_id):
5050
return None
5151

5252

53+
def get_parent_sample_rate(parent_context, trace_id):
54+
# type: (Optional[SpanContext], int) -> Optional[float]
55+
if parent_context is None:
56+
return None
57+
58+
is_span_context_valid = parent_context is not None and parent_context.is_valid
59+
60+
if is_span_context_valid and parent_context.trace_id == trace_id:
61+
parent_sample_rate = parent_context.trace_state.get(TRACESTATE_SAMPLE_RATE_KEY)
62+
if parent_sample_rate is None:
63+
return None
64+
65+
try:
66+
return float(parent_sample_rate)
67+
except Exception:
68+
return None
69+
70+
return None
71+
72+
73+
def _update_sample_rate(sample_rate, trace_state):
74+
# type: (float, TraceState) -> TraceState
75+
if TRACESTATE_SAMPLE_RATE_KEY in trace_state:
76+
trace_state = trace_state.update(TRACESTATE_SAMPLE_RATE_KEY, str(sample_rate))
77+
else:
78+
trace_state = trace_state.add(TRACESTATE_SAMPLE_RATE_KEY, str(sample_rate))
79+
80+
return trace_state
81+
82+
5383
def dropped_result(parent_span_context, attributes, sample_rate=None):
5484
# type: (SpanContext, Attributes, Optional[float]) -> SamplingResult
5585
# these will only be added the first time in a root span sampling decision
86+
# if sample_rate is provided, it'll be updated in trace state
5687
trace_state = parent_span_context.trace_state
5788

5889
if TRACESTATE_SAMPLED_KEY not in trace_state:
5990
trace_state = trace_state.add(TRACESTATE_SAMPLED_KEY, "false")
6091
elif trace_state.get(TRACESTATE_SAMPLED_KEY) == "deferred":
6192
trace_state = trace_state.update(TRACESTATE_SAMPLED_KEY, "false")
6293

63-
if sample_rate and TRACESTATE_SAMPLE_RATE_KEY not in trace_state:
64-
trace_state = trace_state.add(TRACESTATE_SAMPLE_RATE_KEY, str(sample_rate))
94+
if sample_rate is not None:
95+
trace_state = _update_sample_rate(sample_rate, trace_state)
6596

6697
is_root_span = not (
6798
parent_span_context.is_valid and not parent_span_context.is_remote
@@ -88,17 +119,18 @@ def dropped_result(parent_span_context, attributes, sample_rate=None):
88119

89120

90121
def sampled_result(span_context, attributes, sample_rate):
91-
# type: (SpanContext, Attributes, float) -> SamplingResult
122+
# type: (SpanContext, Attributes, Optional[float]) -> SamplingResult
92123
# these will only be added the first time in a root span sampling decision
124+
# if sample_rate is provided, it'll be updated in trace state
93125
trace_state = span_context.trace_state
94126

95127
if TRACESTATE_SAMPLED_KEY not in trace_state:
96128
trace_state = trace_state.add(TRACESTATE_SAMPLED_KEY, "true")
97129
elif trace_state.get(TRACESTATE_SAMPLED_KEY) == "deferred":
98130
trace_state = trace_state.update(TRACESTATE_SAMPLED_KEY, "true")
99131

100-
if TRACESTATE_SAMPLE_RATE_KEY not in trace_state:
101-
trace_state = trace_state.add(TRACESTATE_SAMPLE_RATE_KEY, str(sample_rate))
132+
if sample_rate is not None:
133+
trace_state = _update_sample_rate(sample_rate, trace_state)
102134

103135
return SamplingResult(
104136
Decision.RECORD_AND_SAMPLE,
@@ -142,9 +174,13 @@ def should_sample(
142174
if is_root_span:
143175
sample_rate = float(custom_sampled)
144176
if sample_rate > 0:
145-
return sampled_result(parent_span_context, attributes, sample_rate)
177+
return sampled_result(
178+
parent_span_context, attributes, sample_rate=sample_rate
179+
)
146180
else:
147-
return dropped_result(parent_span_context, attributes)
181+
return dropped_result(
182+
parent_span_context, attributes, sample_rate=sample_rate
183+
)
148184
else:
149185
logger.debug(
150186
f"[Tracing] Ignoring sampled param for non-root span {name}"
@@ -154,19 +190,27 @@ def should_sample(
154190
# Traces_sampler is responsible to check parent sampled to have full transactions.
155191
has_traces_sampler = callable(client.options.get("traces_sampler"))
156192

193+
sample_rate_to_propagate = None
194+
157195
if is_root_span and has_traces_sampler:
158196
sampling_context = create_sampling_context(
159197
name, attributes, parent_span_context, trace_id
160198
)
161199
sample_rate = client.options["traces_sampler"](sampling_context)
200+
sample_rate_to_propagate = sample_rate
162201
else:
163202
# Check if there is a parent with a sampling decision
164203
parent_sampled = get_parent_sampled(parent_span_context, trace_id)
204+
parent_sample_rate = get_parent_sample_rate(parent_span_context, trace_id)
165205
if parent_sampled is not None:
166-
sample_rate = parent_sampled
206+
sample_rate = bool(parent_sampled)
207+
sample_rate_to_propagate = (
208+
parent_sample_rate if parent_sample_rate else sample_rate
209+
)
167210
else:
168211
# Check if there is a traces_sample_rate
169212
sample_rate = client.options.get("traces_sample_rate")
213+
sample_rate_to_propagate = sample_rate
170214

171215
# If the sample rate is invalid, drop the span
172216
if not is_valid_sample_rate(sample_rate, source=self.__class__.__name__):
@@ -178,15 +222,21 @@ def should_sample(
178222
# Down-sample in case of back pressure monitor says so
179223
if is_root_span and client.monitor:
180224
sample_rate /= 2**client.monitor.downsample_factor
225+
if client.monitor.downsample_factor > 0:
226+
sample_rate_to_propagate = sample_rate
181227

182228
# Roll the dice on sample rate
183229
sample_rate = float(cast("Union[bool, float, int]", sample_rate))
184230
sampled = random.random() < sample_rate
185231

186232
if sampled:
187-
return sampled_result(parent_span_context, attributes, sample_rate)
233+
return sampled_result(
234+
parent_span_context, attributes, sample_rate=sample_rate_to_propagate
235+
)
188236
else:
189-
return dropped_result(parent_span_context, attributes, sample_rate)
237+
return dropped_result(
238+
parent_span_context, attributes, sample_rate=sample_rate_to_propagate
239+
)
190240

191241
def get_description(self) -> str:
192242
return self.__class__.__name__

tests/test_dsc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_dsc_continuation_of_trace(sentry_init, capture_envelopes):
118118

119119
assert "sample_rate" in envelope_trace_header
120120
assert type(envelope_trace_header["sample_rate"]) == str
121-
assert envelope_trace_header["sample_rate"] == "1.0"
121+
assert envelope_trace_header["sample_rate"] == "0.01337"
122122

123123
assert "sampled" in envelope_trace_header
124124
assert type(envelope_trace_header["sampled"]) == str

0 commit comments

Comments
 (0)