@@ -50,18 +50,49 @@ def get_parent_sampled(parent_context, trace_id):
50
50
return None
51
51
52
52
53
+ def get_parent_sample_rate (parent_context , trace_id ):
54
+ # type: (Optional[SpanContext], int) -> Optional[float]
55
+ if parent_context is None :
56
+ return None
57
+
58
+ is_span_context_valid = parent_context is not None and parent_context .is_valid
59
+
60
+ if is_span_context_valid and parent_context .trace_id == trace_id :
61
+ parent_sample_rate = parent_context .trace_state .get (TRACESTATE_SAMPLE_RATE_KEY )
62
+ if parent_sample_rate is None :
63
+ return None
64
+
65
+ try :
66
+ return float (parent_sample_rate )
67
+ except Exception :
68
+ return None
69
+
70
+ return None
71
+
72
+
73
+ def _update_sample_rate (sample_rate , trace_state ):
74
+ # type: (float, TraceState) -> TraceState
75
+ if TRACESTATE_SAMPLE_RATE_KEY in trace_state :
76
+ trace_state = trace_state .update (TRACESTATE_SAMPLE_RATE_KEY , str (sample_rate ))
77
+ else :
78
+ trace_state = trace_state .add (TRACESTATE_SAMPLE_RATE_KEY , str (sample_rate ))
79
+
80
+ return trace_state
81
+
82
+
53
83
def dropped_result (parent_span_context , attributes , sample_rate = None ):
54
84
# type: (SpanContext, Attributes, Optional[float]) -> SamplingResult
55
85
# these will only be added the first time in a root span sampling decision
86
+ # if sample_rate is provided, it'll be updated in trace state
56
87
trace_state = parent_span_context .trace_state
57
88
58
89
if TRACESTATE_SAMPLED_KEY not in trace_state :
59
90
trace_state = trace_state .add (TRACESTATE_SAMPLED_KEY , "false" )
60
91
elif trace_state .get (TRACESTATE_SAMPLED_KEY ) == "deferred" :
61
92
trace_state = trace_state .update (TRACESTATE_SAMPLED_KEY , "false" )
62
93
63
- if sample_rate and TRACESTATE_SAMPLE_RATE_KEY not in trace_state :
64
- trace_state = trace_state . add ( TRACESTATE_SAMPLE_RATE_KEY , str ( sample_rate ) )
94
+ if sample_rate is not None :
95
+ trace_state = _update_sample_rate ( sample_rate , trace_state )
65
96
66
97
is_root_span = not (
67
98
parent_span_context .is_valid and not parent_span_context .is_remote
@@ -88,17 +119,18 @@ def dropped_result(parent_span_context, attributes, sample_rate=None):
88
119
89
120
90
121
def sampled_result (span_context , attributes , sample_rate ):
91
- # type: (SpanContext, Attributes, float) -> SamplingResult
122
+ # type: (SpanContext, Attributes, Optional[ float] ) -> SamplingResult
92
123
# these will only be added the first time in a root span sampling decision
124
+ # if sample_rate is provided, it'll be updated in trace state
93
125
trace_state = span_context .trace_state
94
126
95
127
if TRACESTATE_SAMPLED_KEY not in trace_state :
96
128
trace_state = trace_state .add (TRACESTATE_SAMPLED_KEY , "true" )
97
129
elif trace_state .get (TRACESTATE_SAMPLED_KEY ) == "deferred" :
98
130
trace_state = trace_state .update (TRACESTATE_SAMPLED_KEY , "true" )
99
131
100
- if TRACESTATE_SAMPLE_RATE_KEY not in trace_state :
101
- trace_state = trace_state . add ( TRACESTATE_SAMPLE_RATE_KEY , str ( sample_rate ) )
132
+ if sample_rate is not None :
133
+ trace_state = _update_sample_rate ( sample_rate , trace_state )
102
134
103
135
return SamplingResult (
104
136
Decision .RECORD_AND_SAMPLE ,
@@ -142,9 +174,13 @@ def should_sample(
142
174
if is_root_span :
143
175
sample_rate = float (custom_sampled )
144
176
if sample_rate > 0 :
145
- return sampled_result (parent_span_context , attributes , sample_rate )
177
+ return sampled_result (
178
+ parent_span_context , attributes , sample_rate = sample_rate
179
+ )
146
180
else :
147
- return dropped_result (parent_span_context , attributes )
181
+ return dropped_result (
182
+ parent_span_context , attributes , sample_rate = sample_rate
183
+ )
148
184
else :
149
185
logger .debug (
150
186
f"[Tracing] Ignoring sampled param for non-root span { name } "
@@ -154,19 +190,27 @@ def should_sample(
154
190
# Traces_sampler is responsible to check parent sampled to have full transactions.
155
191
has_traces_sampler = callable (client .options .get ("traces_sampler" ))
156
192
193
+ sample_rate_to_propagate = None
194
+
157
195
if is_root_span and has_traces_sampler :
158
196
sampling_context = create_sampling_context (
159
197
name , attributes , parent_span_context , trace_id
160
198
)
161
199
sample_rate = client .options ["traces_sampler" ](sampling_context )
200
+ sample_rate_to_propagate = sample_rate
162
201
else :
163
202
# Check if there is a parent with a sampling decision
164
203
parent_sampled = get_parent_sampled (parent_span_context , trace_id )
204
+ parent_sample_rate = get_parent_sample_rate (parent_span_context , trace_id )
165
205
if parent_sampled is not None :
166
- sample_rate = parent_sampled
206
+ sample_rate = bool (parent_sampled )
207
+ sample_rate_to_propagate = (
208
+ parent_sample_rate if parent_sample_rate else sample_rate
209
+ )
167
210
else :
168
211
# Check if there is a traces_sample_rate
169
212
sample_rate = client .options .get ("traces_sample_rate" )
213
+ sample_rate_to_propagate = sample_rate
170
214
171
215
# If the sample rate is invalid, drop the span
172
216
if not is_valid_sample_rate (sample_rate , source = self .__class__ .__name__ ):
@@ -178,15 +222,21 @@ def should_sample(
178
222
# Down-sample in case of back pressure monitor says so
179
223
if is_root_span and client .monitor :
180
224
sample_rate /= 2 ** client .monitor .downsample_factor
225
+ if client .monitor .downsample_factor > 0 :
226
+ sample_rate_to_propagate = sample_rate
181
227
182
228
# Roll the dice on sample rate
183
229
sample_rate = float (cast ("Union[bool, float, int]" , sample_rate ))
184
230
sampled = random .random () < sample_rate
185
231
186
232
if sampled :
187
- return sampled_result (parent_span_context , attributes , sample_rate )
233
+ return sampled_result (
234
+ parent_span_context , attributes , sample_rate = sample_rate_to_propagate
235
+ )
188
236
else :
189
- return dropped_result (parent_span_context , attributes , sample_rate )
237
+ return dropped_result (
238
+ parent_span_context , attributes , sample_rate = sample_rate_to_propagate
239
+ )
190
240
191
241
def get_description (self ) -> str :
192
242
return self .__class__ .__name__
0 commit comments