Skip to content

Commit 0e4db42

Browse files
fix(futures) Allow calling ThreadPoolExecutor.submit with fn kwarg (#3035) (#3040)
(cherry picked from commit fb1c220) Co-authored-by: Brett Langdon <[email protected]>
1 parent 2cd147e commit 0e4db42

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

ddtrace/contrib/futures/threading.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ def _wrap_submit(func, instance, args, kwargs):
2121
if ddtrace.tracer.context_provider._has_active_context():
2222
current_ctx = ddtrace.tracer.context_provider.active()
2323

24-
# extract the target function that must be executed in
25-
# a new thread and the `target` arguments
26-
fn = args[0]
27-
fn_args = args[1:]
24+
# The target function can be provided as a kwarg argument "fn" or the first positional argument
25+
if "fn" in kwargs:
26+
fn = kwargs.pop("fn")
27+
fn_args = args
28+
else:
29+
fn, fn_args = args[0], args[1:]
2830
return func(_wrap_execution, current_ctx, fn, fn_args, kwargs)
2931

3032

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
fixes:
2+
- |
3+
Fix error when calling ``concurrent.futures.ThreadPoolExecutor.submit`` with ``fn`` keyword argument.

tests/contrib/futures/test_propagation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,28 @@ def fn(value, key=None):
6868
(dict(name="executor.thread"),),
6969
)
7070

71+
def test_propagation_with_kwargs(self):
72+
# instrumentation must work if only kwargs are provided
73+
74+
def fn(value, key=None):
75+
with self.tracer.trace("executor.thread"):
76+
return value, key
77+
78+
with self.override_global_tracer():
79+
with self.tracer.trace("main.thread"):
80+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
81+
future = executor.submit(fn=fn, value=42, key="CheeseShop")
82+
value, key = future.result()
83+
# assert the right result
84+
self.assertEqual(value, 42)
85+
self.assertEqual(key, "CheeseShop")
86+
87+
# the trace must be completed
88+
self.assert_structure(
89+
dict(name="main.thread"),
90+
(dict(name="executor.thread"),),
91+
)
92+
7193
def test_disabled_instrumentation(self):
7294
# it must not propagate if the module is disabled
7395
unpatch()

0 commit comments

Comments
 (0)