Skip to content

Commit d1059dd

Browse files
Fix Celery tests in POTel (#3772)
Co-authored-by: Neel Shah <[email protected]>
1 parent e41b24a commit d1059dd

File tree

2 files changed

+61
-37
lines changed

2 files changed

+61
-37
lines changed

sentry_sdk/integrations/celery/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def _capture_exception(task, exc_info):
112112
return
113113

114114
if isinstance(exc_info[1], CELERY_CONTROL_FLOW_EXCEPTIONS):
115-
# ??? Doesn't map to anything
116115
_set_status("aborted")
117116
return
118117

@@ -276,6 +275,7 @@ def apply_async(*args, **kwargs):
276275
op=OP.QUEUE_SUBMIT_CELERY,
277276
name=task_name,
278277
origin=CeleryIntegration.origin,
278+
only_if_parent=True,
279279
)
280280
if not task_started_from_beat
281281
else NoOpMgr()
@@ -306,11 +306,13 @@ def _inner(*args, **kwargs):
306306
with isolation_scope() as scope:
307307
scope._name = "celery"
308308
scope.clear_breadcrumbs()
309+
scope.set_transaction_name(task.name, source=TRANSACTION_SOURCE_TASK)
309310
scope.add_event_processor(_make_event_processor(task, *args, **kwargs))
310311

311312
# Celery task objects are not a thing to be trusted. Even
312313
# something such as attribute access can fail.
313314
headers = args[3].get("headers") or {}
315+
314316
with sentry_sdk.continue_trace(headers):
315317
with sentry_sdk.start_span(
316318
op=OP.QUEUE_TASK_CELERY,
@@ -320,9 +322,13 @@ def _inner(*args, **kwargs):
320322
# for some reason, args[1] is a list if non-empty but a
321323
# tuple if empty
322324
attributes=_prepopulate_attributes(task, list(args[1]), args[2]),
323-
) as transaction:
324-
transaction.set_status(SPANSTATUS.OK)
325-
return f(*args, **kwargs)
325+
) as root_span:
326+
return_value = f(*args, **kwargs)
327+
328+
if root_span.status is None:
329+
root_span.set_status(SPANSTATUS.OK)
330+
331+
return return_value
326332

327333
return _inner # type: ignore
328334

@@ -359,6 +365,7 @@ def _inner(*args, **kwargs):
359365
op=OP.QUEUE_PROCESS,
360366
name=task.name,
361367
origin=CeleryIntegration.origin,
368+
only_if_parent=True,
362369
) as span:
363370
_set_messaging_destination_name(task, span)
364371

@@ -390,6 +397,7 @@ def _inner(*args, **kwargs):
390397
)
391398

392399
return f(*args, **kwargs)
400+
393401
except Exception:
394402
exc_info = sys.exc_info()
395403
with capture_internal_exceptions():

tests/integrations/celery/test_celery.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import pytest
66
from celery import Celery, VERSION
77
from celery.bin import worker
8+
from celery.app.task import Task
9+
from opentelemetry import trace as otel_trace, context
810

911
import sentry_sdk
10-
from sentry_sdk import start_span, get_current_span
12+
from sentry_sdk import get_current_span
1113
from sentry_sdk.integrations.celery import (
1214
CeleryIntegration,
1315
_wrap_task_run,
@@ -126,14 +128,14 @@ def dummy_task(x, y):
126128
foo = 42 # noqa
127129
return x / y
128130

129-
with start_span(op="unit test transaction") as transaction:
131+
with sentry_sdk.start_span(op="unit test transaction") as root_span:
130132
celery_invocation(dummy_task, 1, 2)
131133
_, expected_context = celery_invocation(dummy_task, 1, 0)
132134

133135
(_, error_event, _, _) = events
134136

135-
assert error_event["contexts"]["trace"]["trace_id"] == transaction.trace_id
136-
assert error_event["contexts"]["trace"]["span_id"] != transaction.span_id
137+
assert error_event["contexts"]["trace"]["trace_id"] == root_span.trace_id
138+
assert error_event["contexts"]["trace"]["span_id"] != root_span.span_id
137139
assert error_event["transaction"] == "dummy_task"
138140
assert "celery_task_id" in error_event["tags"]
139141
assert error_event["extra"]["celery-job"] == dict(
@@ -190,17 +192,14 @@ def test_transaction_events(capture_events, init_celery, celery_invocation, task
190192
def dummy_task(x, y):
191193
return x / y
192194

193-
# XXX: For some reason the first call does not get instrumented properly.
194-
celery_invocation(dummy_task, 1, 1)
195-
196195
events = capture_events()
197196

198-
with start_span(name="submission") as transaction:
197+
with sentry_sdk.start_span(name="submission") as root_span:
199198
celery_invocation(dummy_task, 1, 0 if task_fails else 1)
200199

201200
if task_fails:
202201
error_event = events.pop(0)
203-
assert error_event["contexts"]["trace"]["trace_id"] == transaction.trace_id
202+
assert error_event["contexts"]["trace"]["trace_id"] == root_span.trace_id
204203
assert error_event["exception"]["values"][0]["type"] == "ZeroDivisionError"
205204

206205
execution_event, submission_event = events
@@ -211,24 +210,21 @@ def dummy_task(x, y):
211210
assert submission_event["transaction_info"] == {"source": "custom"}
212211

213212
assert execution_event["type"] == submission_event["type"] == "transaction"
214-
assert execution_event["contexts"]["trace"]["trace_id"] == transaction.trace_id
215-
assert submission_event["contexts"]["trace"]["trace_id"] == transaction.trace_id
213+
assert execution_event["contexts"]["trace"]["trace_id"] == root_span.trace_id
214+
assert submission_event["contexts"]["trace"]["trace_id"] == root_span.trace_id
216215

217216
if task_fails:
218217
assert execution_event["contexts"]["trace"]["status"] == "internal_error"
219218
else:
220219
assert execution_event["contexts"]["trace"]["status"] == "ok"
221220

222221
assert len(execution_event["spans"]) == 1
223-
assert (
224-
execution_event["spans"][0].items()
225-
>= {
226-
"trace_id": str(transaction.trace_id),
227-
"same_process_as_parent": True,
222+
assert execution_event["spans"][0] == ApproxDict(
223+
{
224+
"trace_id": str(root_span.trace_id),
228225
"op": "queue.process",
229226
"description": "dummy_task",
230-
"data": ApproxDict(),
231-
}.items()
227+
}
232228
)
233229
assert submission_event["spans"] == [
234230
{
@@ -237,11 +233,14 @@ def dummy_task(x, y):
237233
"op": "queue.submit.celery",
238234
"origin": "auto.queue.celery",
239235
"parent_span_id": submission_event["contexts"]["trace"]["span_id"],
240-
"same_process_as_parent": True,
241236
"span_id": submission_event["spans"][0]["span_id"],
242237
"start_timestamp": submission_event["spans"][0]["start_timestamp"],
243238
"timestamp": submission_event["spans"][0]["timestamp"],
244-
"trace_id": str(transaction.trace_id),
239+
"trace_id": str(root_span.trace_id),
240+
"status": "ok",
241+
"tags": {
242+
"status": "ok",
243+
},
245244
}
246245
]
247246

@@ -275,7 +274,7 @@ def test_simple_no_propagation(capture_events, init_celery):
275274
def dummy_task():
276275
1 / 0
277276

278-
with start_span(name="task") as root_span:
277+
with sentry_sdk.start_span(name="task") as root_span:
279278
dummy_task.delay()
280279

281280
(event,) = events
@@ -350,7 +349,7 @@ def dummy_task(self):
350349
runs.append(1)
351350
1 / 0
352351

353-
with start_span(name="submit_celery"):
352+
with sentry_sdk.start_span(name="submit_celery"):
354353
# Curious: Cannot use delay() here or py2.7-celery-4.2 crashes
355354
res = dummy_task.apply_async()
356355

@@ -445,7 +444,7 @@ def walk_dogs(x, y):
445444
walk_dogs, [["Maisey", "Charlie", "Bodhi", "Cory"], "Dog park round trip"], 1
446445
)
447446

448-
sampling_context = traces_sampler.call_args_list[1][0][0]
447+
sampling_context = traces_sampler.call_args_list[0][0][0]
449448
assert sampling_context["celery.job.task"] == "dog_walk"
450449
for i, arg in enumerate(args_kwargs["args"]):
451450
assert sampling_context[f"celery.job.args.{i}"] == str(arg)
@@ -469,7 +468,7 @@ def __call__(self, *args, **kwargs):
469468
def dummy_task(x, y):
470469
return x / y
471470

472-
with start_span(name="celery"):
471+
with sentry_sdk.start_span(name="celery"):
473472
celery_invocation(dummy_task, 1, 0)
474473

475474
assert not events
@@ -510,7 +509,7 @@ def test_baggage_propagation(init_celery):
510509
def dummy_task(self, x, y):
511510
return _get_headers(self)
512511

513-
with start_span(name="task") as root_span:
512+
with sentry_sdk.start_span(name="task") as root_span:
514513
result = dummy_task.apply_async(
515514
args=(1, 0),
516515
headers={"baggage": "custom=value"},
@@ -520,6 +519,7 @@ def dummy_task(self, x, y):
520519
[
521520
"sentry-release=abcdef",
522521
"sentry-trace_id={}".format(root_span.trace_id),
522+
"sentry-transaction=task",
523523
"sentry-environment=production",
524524
"sentry-sample_rate=1.0",
525525
"sentry-sampled=true",
@@ -537,26 +537,42 @@ def test_sentry_propagate_traces_override(init_celery):
537537
propagate_traces=True, traces_sample_rate=1.0, release="abcdef"
538538
)
539539

540+
# Since we're applying the task inline eagerly,
541+
# we need to cleanup the otel context for this test.
542+
# and since we patch build_tracer, we need to do this before that runs...
543+
# TODO: the right way is to not test this inline
544+
original_apply = Task.apply
545+
546+
def cleaned_apply(*args, **kwargs):
547+
token = context.attach(otel_trace.set_span_in_context(otel_trace.INVALID_SPAN))
548+
rv = original_apply(*args, **kwargs)
549+
context.detach(token)
550+
return rv
551+
552+
Task.apply = cleaned_apply
553+
540554
@celery.task(name="dummy_task", bind=True)
541555
def dummy_task(self, message):
542556
trace_id = get_current_span().trace_id
543557
return trace_id
544558

545-
with start_span(name="task") as root_span:
546-
transaction_trace_id = root_span.trace_id
559+
with sentry_sdk.start_span(name="task") as root_span:
560+
root_span_trace_id = root_span.trace_id
547561

548562
# should propagate trace
549-
task_transaction_id = dummy_task.apply_async(
563+
task_trace_id = dummy_task.apply_async(
550564
args=("some message",),
551565
).get()
552-
assert transaction_trace_id == task_transaction_id
566+
assert root_span_trace_id == task_trace_id, "Trace should be propagated"
553567

554568
# should NOT propagate trace (overrides `propagate_traces` parameter in integration constructor)
555-
task_transaction_id = dummy_task.apply_async(
569+
task_trace_id = dummy_task.apply_async(
556570
args=("another message",),
557571
headers={"sentry-propagate-traces": False},
558572
).get()
559-
assert transaction_trace_id != task_transaction_id
573+
assert root_span_trace_id != task_trace_id, "Trace should NOT be propagated"
574+
575+
Task.apply = original_apply
560576

561577

562578
def test_apply_async_manually_span(sentry_init):
@@ -710,7 +726,7 @@ def publish(*args, **kwargs):
710726
@celery.task()
711727
def task(): ...
712728

713-
with start_span(name="task"):
729+
with sentry_sdk.start_span(name="task"):
714730
task.apply_async()
715731

716732
(event,) = events
@@ -773,7 +789,7 @@ def publish(*args, **kwargs):
773789
@celery.task()
774790
def task(): ...
775791

776-
with start_span(name="custom_transaction"):
792+
with sentry_sdk.start_span(name="custom_transaction"):
777793
task.apply_async()
778794

779795
(event,) = events

0 commit comments

Comments
 (0)