Skip to content

Commit 5b61a86

Browse files
committed
[MLIR:Python] Fix race on PyOperations.
Joint work with @vfdev-5 We found the following TSAN race report in JAX's CI: jax-ml/jax#28551 ``` WARNING: ThreadSanitizer: data race (pid=35893) Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0): #0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54) #1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d) #2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d) #3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, 
nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d) ... Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0): #0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54) #1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012) #2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54) #3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54) #4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54) #5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b) #6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422) #7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54) ... ``` At the simplest level, the `valid` field of a PyOperation must be protected by a lock, because it may be concurrently accessed from multiple threads. Much more interesting, however is how we get into the situation described by the two stack traces above in the first place. The scenario that triggers this is the following: * thread T56 holds the last Python reference on a PyOperation, and decides to release it. 
* After T56 starts to release its reference, but before T56 removes the PyOperation from the liveOperations map, a second thread T57 comes along and looks up the same MlirOperation in the liveOperations map. * Finding the operation to be present, thread T57 increments the reference count of that PyOperation and returns it to the caller. This is illegal! Python is in the process of calling the destructor of that object, and once an object is in that state it cannot be safely revived. To fix this, whenever we increment the reference count of a PyOperation that we found via the liveOperations map and to which we only hold a non-owning reference, we must use the Python 3.14+ API `PyUnstable_TryIncRef`, which exists precisely for this scenario (python/cpython#128844). That API does not exist under Python 3.13, so we need a backport of it in that case, for which we use the same backport that both nanobind and pybind11 use. Fixes jax-ml/jax#28551
1 parent 2d287f5 commit 5b61a86

File tree

3 files changed

+199
-40
lines changed

3 files changed

+199
-40
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

+128-33
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,75 @@ class PyOpOperandIterator {
635635
MlirOpOperand opOperand;
636636
};
637637

638+
639+
640+
#if !defined(Py_GIL_DISABLED)
641+
inline void enableTryIncRef(nb::handle obj) noexcept { }
642+
inline bool tryIncRef(nb::handle obj) noexcept {
643+
if (Py_REFCNT(obj.ptr()) > 0) {
644+
Py_INCREF(obj.ptr());
645+
return true;
646+
}
647+
return false;
648+
}
649+
650+
#elif PY_VERSION_HEX >= 0x030E00A5
651+
652+
// CPython 3.14 provides an unstable API for these.
653+
inline void enableTryIncRef(nb::handle obj) noexcept {
654+
PyUnstable_EnableTryIncRef(obj.ptr());
655+
}
656+
inline bool tryIncRef(nb::handle obj) noexcept {
657+
return PyUnstable_TryIncRef(obj.ptr());
658+
}
659+
660+
#else
661+
662+
// For CPython 3.13 there is no API for this, and so we must implement our own.
663+
// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
664+
// Opts the object referenced by `h` in to tryIncRef() by storing
// _Py_REF_MAYBE_WEAKREF into its shared reference count field. tryIncRef()
// refuses to revive an object whose shared count is zero, so without this
// marker a lookup from a non-owning thread could never succeed.
void enableTryIncRef(nb::handle h) noexcept {
  // Since this is called during object construction, we know that we have
  // the only reference to the object and can use a non-atomic write.
  PyObject *obj = h.ptr();
  // The refcount fields are members of the PyObject, so access them through
  // `obj`; the nanobind handle wrapper itself cannot be dereferenced.
  assert(obj->ob_ref_shared == 0);
  obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
}
671+
672+
bool tryIncRef(nb::handle h) noexcept {
673+
PyObject *obj = h.ptr();
674+
// See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
675+
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
676+
local += 1;
677+
if (local == 0) {
678+
// immortal
679+
return true;
680+
}
681+
if (_Py_IsOwnedByCurrentThread(obj)) {
682+
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
683+
#ifdef Py_REF_DEBUG
684+
_Py_INCREF_IncRefTotal();
685+
#endif
686+
return true;
687+
}
688+
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
689+
for (;;) {
690+
// If the shared refcount is zero and the object is either merged
691+
// or may not have weak references, then we cannot incref it.
692+
if (shared == 0 || shared == _Py_REF_MERGED) {
693+
return false;
694+
}
695+
696+
if (_Py_atomic_compare_exchange_ssize(
697+
&obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
698+
#ifdef Py_REF_DEBUG
699+
_Py_INCREF_IncRefTotal();
700+
#endif
701+
return true;
702+
}
703+
}
704+
}
705+
#endif
706+
638707
} // namespace
639708

640709
//------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
706775
return liveOperations.size();
707776
}
708777

709-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
710-
std::vector<PyOperation *> liveObjects;
778+
std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
779+
std::vector<nb::object> liveObjects;
711780
nb::ft_lock_guard lock(liveOperationsMutex);
712-
for (auto &entry : liveOperations)
713-
liveObjects.push_back(entry.second.second);
781+
for (auto &entry : liveOperations) {
782+
// It is not safe to unconditionally increment the reference count here
783+
// because an operation that is in the process of being deleted by another
784+
// thread may still be present in the map.
785+
if (tryIncRef(entry.second.first)) {
786+
liveObjects.push_back(nb::steal(entry.second.first));
787+
}
788+
}
714789
return liveObjects;
715790
}
716791

@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
720795
{
721796
nb::ft_lock_guard lock(liveOperationsMutex);
722797
std::swap(operations, liveOperations);
798+
for (auto &op : operations)
799+
op.second.second->setInvalidLocked();
723800
}
724-
for (auto &op : operations)
725-
op.second.second->setInvalid();
726801
size_t numInvalidated = operations.size();
727802
return numInvalidated;
728803
}
729804

730-
void PyMlirContext::clearOperation(MlirOperation op) {
731-
PyOperation *py_op;
732-
{
733-
nb::ft_lock_guard lock(liveOperationsMutex);
734-
auto it = liveOperations.find(op.ptr);
735-
if (it == liveOperations.end()) {
736-
return;
737-
}
738-
py_op = it->second.second;
739-
liveOperations.erase(it);
805+
void PyMlirContext::clearOperationLocked(MlirOperation op) {
806+
auto it = liveOperations.find(op.ptr);
807+
if (it == liveOperations.end()) {
808+
return;
740809
}
741-
py_op->setInvalid();
810+
PyOperation *py_op = it->second.second;
811+
py_op->setInvalidLocked();
812+
liveOperations.erase(it);
813+
}
814+
815+
void PyMlirContext::clearOperation(MlirOperation op) {
816+
nb::ft_lock_guard lock(liveOperationsMutex);
817+
clearOperationLocked(op);
742818
}
743819

744820
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -770,7 +846,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
770846
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
771847
void *userData) {
772848
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
773-
contextRef->clearOperation(op);
849+
contextRef->clearOperationLocked(op);
774850
return MlirWalkResult::MlirWalkResultAdvance;
775851
};
776852
mlirOperationWalk(op.getOperation(), invalidatingCallback,
@@ -1203,19 +1279,23 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
12031279
: BaseContextObject(std::move(contextRef)), operation(operation) {}
12041280

12051281
PyOperation::~PyOperation() {
1282+
PyMlirContextRef context = getContext();
1283+
nb::ft_lock_guard lock(context->liveOperationsMutex);
12061284
// If the operation has already been invalidated there is nothing to do.
12071285
if (!valid)
12081286
return;
12091287

12101288
// Otherwise, invalidate the operation and remove it from live map when it is
12111289
// attached.
12121290
if (isAttached()) {
1213-
getContext()->clearOperation(*this);
1291+
// Since the operation was valid, we know that it is this object present
1292+
// in the map, not some other object.
1293+
context->liveOperations.erase(operation.ptr);
12141294
} else {
12151295
// And destroy it when it is detached, i.e. owned by Python, in which case
12161296
// all nested operations must be invalidated and removed from the live map as
12171297
// well.
1218-
erase();
1298+
eraseLocked();
12191299
}
12201300
}
12211301

@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12411321
// Create.
12421322
PyOperationRef unownedOperation =
12431323
makeObjectRef<PyOperation>(std::move(contextRef), operation);
1324+
enableTryIncRef(unownedOperation.getObject());
12441325
unownedOperation->handle = unownedOperation.getObject();
12451326
if (parentKeepAlive) {
12461327
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12541335
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
12551336
auto &liveOperations = contextRef->liveOperations;
12561337
auto it = liveOperations.find(operation.ptr);
1257-
if (it == liveOperations.end()) {
1258-
// Create.
1259-
PyOperationRef result = createInstance(std::move(contextRef), operation,
1260-
std::move(parentKeepAlive));
1261-
liveOperations[operation.ptr] =
1262-
std::make_pair(result.getObject(), result.get());
1263-
return result;
1338+
if (it != liveOperations.end()) {
1339+
PyOperation *existing = it->second.second;
1340+
nb::handle pyRef = it->second.first;
1341+
1342+
// Try to increment the reference count of the existing entry. This can fail
1343+
// if the object is in the process of being destroyed by another thread.
1344+
if (tryIncRef(pyRef)) {
1345+
return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
1346+
}
1347+
1348+
// Mark the existing entry as invalid, since we are about to replace it.
1349+
existing->valid = false;
12641350
}
1265-
// Use existing.
1266-
PyOperation *existing = it->second.second;
1267-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1268-
return PyOperationRef(existing, std::move(pyRef));
1351+
1352+
// Create a new wrapper object.
1353+
PyOperationRef result = createInstance(std::move(contextRef), operation,
1354+
std::move(parentKeepAlive));
1355+
liveOperations[operation.ptr] =
1356+
std::make_pair(result.getObject(), result.get());
1357+
return result;
12691358
}
12701359

12711360
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,6 +1386,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
12971386
}
12981387

12991388
void PyOperation::checkValid() const {
1389+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
13001390
if (!valid) {
13011391
throw std::runtime_error("the operation has been invalidated");
13021392
}
@@ -1638,12 +1728,17 @@ nb::object PyOperation::createOpView() {
16381728
return nb::cast(PyOpView(getRef().getObject()));
16391729
}
16401730

1641-
void PyOperation::erase() {
1731+
void PyOperation::eraseLocked() {
16421732
checkValid();
16431733
getContext()->clearOperationAndInside(*this);
16441734
mlirOperationDestroy(operation);
16451735
}
16461736

1737+
void PyOperation::erase() {
1738+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
1739+
eraseLocked();
1740+
}
1741+
16471742
namespace {
16481743
/// CRTP base class for Python MLIR values that subclass Value and should be
16491744
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2419,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
23242419
// The operation is also erased, so we must invalidate it. There may be Python
23252420
// references to this operation so we don't want to delete it from the list of
23262421
// live operations here.
2327-
symbol.getOperation().valid = false;
2422+
symbol.getOperation().setInvalid();
23282423
}
23292424

23302425
void PySymbolTable::dunderDel(const std::string &name) {

mlir/lib/Bindings/Python/IRModule.h

+28-7
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class PyObjectRef {
8383
}
8484

8585
T *get() { return referrent; }
86-
T *operator->() {
86+
T *operator->() const {
8787
assert(referrent && object);
8888
return referrent;
8989
}
@@ -229,7 +229,7 @@ class PyMlirContext {
229229
static size_t getLiveCount();
230230

231231
/// Get a list of Python objects which are still in the live context map.
232-
std::vector<PyOperation *> getLiveOperationObjects();
232+
std::vector<nanobind::object> getLiveOperationObjects();
233233

234234
/// Gets the count of live operations associated with this context.
235235
/// Used for testing.
@@ -254,8 +254,9 @@ class PyMlirContext {
254254
void clearOperationsInside(PyOperationBase &op);
255255
void clearOperationsInside(MlirOperation op);
256256

257-
/// Clears the operaiton _and_ all operations inside using
258-
/// `clearOperation(MlirOperation)`.
257+
/// Clears the operation _and_ all operations inside using
258+
/// `clearOperationLocked(MlirOperation)`. Requires that liveOperationsMutex is
259+
/// held.
259260
void clearOperationAndInside(PyOperationBase &op);
260261

261262
/// Gets the count of live modules associated with this context.
@@ -278,6 +279,9 @@ class PyMlirContext {
278279
struct ErrorCapture;
279280

280281
private:
282+
// Similar to clearOperation, but requires the liveOperations mutex to be held
283+
void clearOperationLocked(MlirOperation op);
284+
281285
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
282286
// preserving the relationship that an MlirContext maps to a single
283287
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -302,6 +306,9 @@ class PyMlirContext {
302306
// attempt to access it will raise an error.
303307
using LiveOperationMap =
304308
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
309+
310+
// liveOperationsMutex guards both liveOperations and the valid field of
311+
// PyOperation objects in free-threading mode.
305312
nanobind::ft_mutex liveOperationsMutex;
306313

307314
// Guarded by liveOperationsMutex in free-threading mode.
@@ -336,6 +343,7 @@ class BaseContextObject {
336343
}
337344

338345
/// Accesses the context reference.
346+
const PyMlirContextRef &getContext() const { return contextRef; }
339347
PyMlirContextRef &getContext() { return contextRef; }
340348

341349
private:
@@ -725,19 +733,29 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
725733
/// parent context's live operations map, and sets the valid bit false.
726734
void erase();
727735

728-
/// Invalidate the operation.
729-
void setInvalid() { valid = false; }
730-
731736
/// Clones this operation.
732737
nanobind::object clone(const nanobind::object &ip);
733738

739+
/// Invalidate the operation.
740+
void setInvalid() {
741+
nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex);
742+
setInvalidLocked();
743+
}
744+
/// Like setInvalid(), but requires the liveOperations mutex to be held.
745+
void setInvalidLocked() {
746+
valid = false;
747+
}
748+
734749
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
735750

736751
private:
737752
static PyOperationRef createInstance(PyMlirContextRef contextRef,
738753
MlirOperation operation,
739754
nanobind::object parentKeepAlive);
740755

756+
// Like erase(), but requires the caller to hold the liveOperationsMutex.
757+
void eraseLocked();
758+
741759
MlirOperation operation;
742760
nanobind::handle handle;
743761
// Keeps the parent alive, regardless of whether it is an Operation or
@@ -748,6 +766,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
748766
// ir_operation.py regarding testing corresponding lifetime guarantees.
749767
nanobind::object parentKeepAlive;
750768
bool attached = true;
769+
770+
// Guarded by 'context->liveOperationsMutex'. Valid objects must be present
771+
// in context->liveOperations.
751772
bool valid = true;
752773

753774
friend class PyOperationBase;

0 commit comments

Comments
 (0)