diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 9cdf21f7ad..7a0988ab05 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -54,9 +54,20 @@ struct type_caster> { } } - value = [func](Args... args) -> Return { + // ensure GIL is held during functor destruction + struct func_handle { + function f; + func_handle(function&& f_) : f(std::move(f_)) {} + func_handle(const func_handle&) = default; + ~func_handle() { + gil_scoped_acquire acq; + function kill_f(std::move(f)); + } + }; + + value = [hfunc = func_handle(std::move(func))](Args... args) -> Return { gil_scoped_acquire acq; - object retval(func(std::forward(args)...)); + object retval(hfunc.f(std::forward(args)...)); /* Visual studio 2015 parser issue: need parentheses around this expression */ return (retval.template cast()); }; diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index 273eacc30c..71b88c44c7 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -10,6 +10,7 @@ #include "pybind11_tests.h" #include "constructor_stats.h" #include +#include int dummy_function(int i) { return i + 1; } @@ -146,4 +147,22 @@ TEST_SUBMODULE(callbacks, m) { py::class_(m, "CppBoundMethodTest") .def(py::init<>()) .def("triple", [](CppBoundMethodTest &, int val) { return 3 * val; }); + + // test async Python callbacks + using callback_f = std::function; + m.def("test_async_callback", [](callback_f f, py::list work) { + // make detached thread that calls `f` with piece of work after a little delay + auto start_f = [f](int j) { + auto invoke_f = [f, j] { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + f(j); + }; + auto t = std::thread(std::move(invoke_f)); + t.detach(); + }; + + // spawn worker threads + for (auto i : work) + start_f(py::cast(i)); + }); } diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 93c42c22b8..6439c8e72a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,5 +1,6 @@ import pytest from pybind11_tests import callbacks as m +from threading import Thread def test_callbacks(): @@ -105,3 +106,31 @@ def test_function_signatures(doc): def test_movable_object(): assert m.callback_with_movable(lambda _: None) is True + + +def test_async_callbacks(): + # serves as state for async callback + class Item: + def __init__(self, value): + self.value = value + + res = [] + + # generate stateful lambda that will store result in `res` + def gen_f(): + s = Item(3) + return lambda j: res.append(s.value + j) + + # do some work async + work = [1, 2, 3, 4] + m.test_async_callback(gen_f(), work) + # wait until work is done + from time import sleep + sleep(0.5) + assert sum(res) == sum([x + 3 for x in work]) + + +def test_async_async_callbacks(): + t = Thread(target=test_async_callbacks) + t.start() + t.join()