Skip to content

Commit 1a07f3c

Browse files
committed
Fix data race in all_type_info_populate in free-threading mode
Description: - fixed a data race in `all_type_info_populate` in free-threading mode - added a test. For example, suppose 2 threads enter `all_type_info`. Both enter the `all_type_info_get_cache` function; the first one inserts a tuple (type, empty_vector) into the map while the second one waits. The inserting thread gets (iter_to_key, True), and the non-inserting thread, after waiting, gets (iter_to_key, False). The inserting thread will then add a weakref and call into `all_type_info_populate`. However, the non-inserting thread does not enter the `if (ins.second) {` clause and returns `ins.first->second;`, which is still just the empty_vector. Finally, the non-inserting thread fails the check in `allocate_layout`:
```c++
if (n_types == 0) { pybind11_fail( "instance allocation failed: new instance has no pybind11-registered base types"); }
```
1 parent f7e14e9 commit 1a07f3c

File tree

5 files changed

+77
-10
lines changed

5 files changed

+77
-10
lines changed

include/pybind11/detail/type_caster_base.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
117117
for (handle parent : reinterpret_borrow<tuple>(t->tp_bases)) {
118118
check.push_back((PyTypeObject *) parent.ptr());
119119
}
120-
121120
auto const &type_dict = get_internals().registered_types_py;
122121
for (size_t i = 0; i < check.size(); i++) {
123122
auto *type = check[i];
@@ -177,11 +176,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
177176
*/
178177
inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
179178
auto ins = all_type_info_get_cache(type);
180-
if (ins.second) {
181-
// New cache entry: populate it
182-
all_type_info_populate(type, ins.first->second);
183-
}
184-
185179
return ins.first->second;
186180
}
187181

include/pybind11/pybind11.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) {
23262326
inline std::pair<decltype(internals::registered_types_py)::iterator, bool>
23272327
all_type_info_get_cache(PyTypeObject *type) {
23282328
auto res = with_internals([type](internals &internals) {
2329-
return internals
2330-
.registered_types_py
2329+
auto ins = internals
2330+
.registered_types_py
23312331
#ifdef __cpp_lib_unordered_map_try_emplace
2332-
.try_emplace(type);
2332+
.try_emplace(type);
23332333
#else
2334-
.emplace(type, std::vector<detail::type_info *>());
2334+
.emplace(type, std::vector<detail::type_info *>());
23352335
#endif
2336+
if (ins.second) {
2337+
// In free-threading mode this function should be called
2338+
// under the pymutex lock so that other threads do not
2339+
// continue running with an empty ins.first->second
2340+
all_type_info_populate(type, ins.first->second);
2341+
}
2342+
return ins;
23362343
});
23372344
if (res.second) {
23382345
// New cache entry created; set up a weak reference to automatically remove it if the type

tests/pybind11_tests.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) {
128128
for (const auto &initializer : initializers()) {
129129
initializer(m);
130130
}
131+
132+
py::class_<TestContext>(m, "TestContext")
133+
.def(py::init<>(&TestContext::createNewContextForInit))
134+
.def("__enter__", &TestContext::contextEnter)
135+
.def("__exit__", &TestContext::contextExit);
131136
}

tests/pybind11_tests.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,23 @@ void ignoreOldStyleInitWarnings(F &&body) {
9696
)",
9797
py::dict(py::arg("body") = py::cpp_function(body)));
9898
}
99+
100+
class TestContext {
101+
public:
102+
TestContext() = delete;
103+
TestContext(const TestContext &) = delete;
104+
TestContext(TestContext &&) = delete;
105+
static TestContext *createNewContextForInit() { return new TestContext("new-context"); }
106+
107+
pybind11::object contextEnter() {
108+
py::object contextObj = py::cast(*this);
109+
return contextObj;
110+
}
111+
void contextExit(const pybind11::object & /*excType*/,
112+
const pybind11::object & /*excVal*/,
113+
const pybind11::object & /*excTb*/) {}
114+
115+
private:
116+
TestContext(std::string context) : context(context) {}
117+
std::string context;
118+
};

tests/test_class.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,44 @@ def test_pr4220_tripped_over_this():
501501
m.Empty0().get_msg()
502502
== "This is really only meant to exercise successful compilation."
503503
)
504+
505+
506+
def test_all_type_info_multithreaded():
507+
# Test data race in all_type_info method in free-threading mode.
508+
# For example, we have 2 threads entering `all_type_info`.
509+
# Both enter the `all_type_info_get_cache` function and
510+
# there is a first one which inserts a tuple (type, empty_vector) to the map
511+
# and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread
512+
# after waiting gets (iter_to_key, False).
513+
# The inserting thread will then add a weakref and call into `all_type_info_populate`.
514+
# However, non-inserting thread is not entering `if (ins.second) {` clause and
515+
# returns `ins.first->second;`, which is just empty_vector.
516+
# Finally, non-inserting thread is failing the check in `allocate_layout`:
517+
# if (n_types == 0) {
518+
# pybind11_fail(
519+
# "instance allocation failed: new instance has no pybind11-registered base types");
520+
# }
521+
import threading
522+
523+
from pybind11_tests import TestContext
524+
525+
class Context(TestContext):
526+
def __init__(self, *args, **kwargs):
527+
super().__init__(*args, **kwargs)
528+
529+
num_runs = 4
530+
num_threads = 5
531+
barrier = threading.Barrier(num_threads)
532+
533+
def func():
534+
barrier.wait()
535+
with Context():
536+
pass
537+
538+
for _ in range(num_runs):
539+
threads = [threading.Thread(target=func) for _ in range(num_threads)]
540+
for thread in threads:
541+
thread.start()
542+
543+
for thread in threads:
544+
thread.join()

0 commit comments

Comments
 (0)