You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Joint work with @vfdev-5
We found the following TSAN race report in JAX's CI:
jax-ml/jax#28551
```
WARNING: ThreadSanitizer: data race (pid=35893)
Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0):
#0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54)
#1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d)
#2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d)
#3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d)
...
Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0):
#0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54)
#1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012)
#2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54)
#3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54)
#4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
#5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
#6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
#7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54)
...
```
At the simplest level, the `valid` field of a PyOperation must be
protected by a lock, because it may be concurrently accessed from
multiple threads. Much more interesting, however is how we get into the
situation described by the two stack traces above in the first place.
The scenario that triggers this is the following:
* thread T56 holds the last Python reference on a PyOperation, and
decides to release it.
* After T56 starts to release its reference, but before T56 removes the
PyOperation from the liveOperations map a second thread T57 comes
along and looks up the same MlirOperation in the liveOperations map.
* Finding the operation to be present, thread T57 increments the
reference count of that PyOperation and returns it to the caller.
This is illegal! Python is in the process of calling the destructor of
that object, and once an object is in that state it cannot be safely
revived.
To fix this, whenever we increment the reference count of a PyOperation
that we found via the liveOperations map and to which we only hold a
non-owning reference, we must use the Python 3.14+ API
`PyUnstable_TryIncRef`, which exists precisely for this
scenario (python/cpython#128844). That API
does not exist under Python 3.13, so we need a backport of it in that
case, for which we the backport that both nanobind and pybind11 also
use.
Fixesjax-ml/jax#28551
0 commit comments