[MLIR:Python] Fix race on PyOperations. #139721

Open. Wants to merge 1 commit into base: main.
173 changes: 136 additions & 37 deletions mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,75 @@ class PyOpOperandIterator {
MlirOpOperand opOperand;
};

#if !defined(Py_GIL_DISABLED)
inline void enableTryIncRef(nb::handle obj) noexcept {}
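// Without free-threading, reference-count updates are serialized by the GIL,
// so checking for a positive count before incrementing is sufficient: a zero
// count means the object is already being destroyed and must not be
// resurrected.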
inline bool tryIncRef(nb::handle obj) noexcept {
if (Py_REFCNT(obj.ptr()) > 0) {
Py_INCREF(obj.ptr());
return true;
}
return false;
}

#elif PY_VERSION_HEX >= 0x030E00A5

// CPython 3.14 provides an unstable API for these.
inline void enableTryIncRef(nb::handle obj) noexcept {
PyUnstable_EnableTryIncRef(obj.ptr());
}
inline bool tryIncRef(nb::handle obj) noexcept {
return PyUnstable_TryIncRef(obj.ptr());
}

#else

// For CPython 3.13 there is no API for this, and so we must implement our own.
// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
void enableTryIncRef(nb::handle h) noexcept {
// Since this is called during object construction, we know that we have
// the only reference to the object and can use a non-atomic write.
PyObject *obj = h.ptr();
assert(obj->ob_ref_shared == 0);
obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
}

bool tryIncRef(nb::handle h) noexcept {
PyObject *obj = h.ptr();
// See
// https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
local += 1;
if (local == 0) {
// immortal
return true;
}
if (_Py_IsOwnedByCurrentThread(obj)) {
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
#ifdef Py_REF_DEBUG
_Py_INCREF_IncRefTotal();
#endif
return true;
}
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
for (;;) {
// If the shared refcount is zero and the object is either merged
// or may not have weak references, then we cannot incref it.
if (shared == 0 || shared == _Py_REF_MERGED) {
return false;
}

if (_Py_atomic_compare_exchange_ssize(&obj->ob_ref_shared, &shared,
shared +
(1 << _Py_REF_SHARED_SHIFT))) {
#ifdef Py_REF_DEBUG
_Py_INCREF_IncRefTotal();
#endif
return true;
}
}
}
#endif

} // namespace

//------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
return liveOperations.size();
}

std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
std::vector<PyOperation *> liveObjects;
std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
std::vector<nb::object> liveObjects;
nb::ft_lock_guard lock(liveOperationsMutex);
for (auto &entry : liveOperations)
liveObjects.push_back(entry.second.second);
for (auto &entry : liveOperations) {
// It is not safe to unconditionally increment the reference count here
// because an operation that is in the process of being deleted by another
// thread may still be present in the map.
if (tryIncRef(entry.second.first)) {
liveObjects.push_back(nb::steal(entry.second.first));
}
}
return liveObjects;
}

@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
{
nb::ft_lock_guard lock(liveOperationsMutex);
std::swap(operations, liveOperations);
for (auto &op : operations)
op.second.second->setInvalidLocked();
}
for (auto &op : operations)
op.second.second->setInvalid();
size_t numInvalidated = operations.size();
return numInvalidated;
}

void PyMlirContext::clearOperation(MlirOperation op) {
PyOperation *py_op;
{
nb::ft_lock_guard lock(liveOperationsMutex);
auto it = liveOperations.find(op.ptr);
if (it == liveOperations.end()) {
return;
}
py_op = it->second.second;
liveOperations.erase(it);
void PyMlirContext::clearOperationLocked(MlirOperation op) {
auto it = liveOperations.find(op.ptr);
if (it == liveOperations.end()) {
return;
}
py_op->setInvalid();
PyOperation *py_op = it->second.second;
py_op->setInvalidLocked();
liveOperations.erase(it);
}

void PyMlirContext::clearOperation(MlirOperation op) {
nb::ft_lock_guard lock(liveOperationsMutex);
clearOperationLocked(op);
}

void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -766,14 +842,14 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
clearOperationsInside(opRef->getOperation());
}

void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
void PyMlirContext::clearOperationAndInsideLocked(PyOperationBase &op) {
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
contextRef->clearOperation(op);
contextRef->clearOperationLocked(op);
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
mlirOperationWalk(op.getOperation().getLocked(), invalidatingCallback,
&op.getOperation().getContext(), MlirWalkPreOrder);
}

@@ -1203,19 +1279,23 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
: BaseContextObject(std::move(contextRef)), operation(operation) {}

PyOperation::~PyOperation() {
PyMlirContextRef context = getContext();
nb::ft_lock_guard lock(context->liveOperationsMutex);
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;

// Otherwise, invalidate the operation and remove it from the live map when it
// is attached.
if (isAttached()) {
getContext()->clearOperation(*this);
// Since the operation was valid, we know that it is this object that is
// present in the map, not some other object.
context->liveOperations.erase(operation.ptr);
} else {
// And destroy it when it is detached, i.e. owned by Python, in which case
// all nested operations must be invalidated and removed from the live map as
// well.
erase();
eraseLocked();
}
}

@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
// Create.
PyOperationRef unownedOperation =
makeObjectRef<PyOperation>(std::move(contextRef), operation);
enableTryIncRef(unownedOperation.getObject());
unownedOperation->handle = unownedOperation.getObject();
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it == liveOperations.end()) {
// Create.
PyOperationRef result = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
liveOperations[operation.ptr] =
std::make_pair(result.getObject(), result.get());
return result;
if (it != liveOperations.end()) {
PyOperation *existing = it->second.second;
nb::handle pyRef = it->second.first;

// Try to increment the reference count of the existing entry. This can fail
// if the object is in the process of being destroyed by another thread.
if (tryIncRef(pyRef)) {
return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
}

// Mark the existing entry as invalid, since we are about to replace it.
existing->setInvalidLocked();
}
// Use existing.
PyOperation *existing = it->second.second;
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyOperationRef(existing, std::move(pyRef));

// Create a new wrapper object.
PyOperationRef result = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
liveOperations[operation.ptr] =
std::make_pair(result.getObject(), result.get());
return result;
}

PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,6 +1386,11 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
}

void PyOperation::checkValid() const {
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
checkValidLocked();
}

void PyOperation::checkValidLocked() const {
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
}
@@ -1638,12 +1732,17 @@ nb::object PyOperation::createOpView() {
return nb::cast(PyOpView(getRef().getObject()));
}

void PyOperation::erase() {
checkValid();
getContext()->clearOperationAndInside(*this);
void PyOperation::eraseLocked() {
checkValidLocked();
getContext()->clearOperationAndInsideLocked(*this);
mlirOperationDestroy(operation);
}

void PyOperation::erase() {
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
eraseLocked();
}

namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2423,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
// The operation is also erased, so we must invalidate it. There may be Python
// references to this operation so we don't want to delete it from the list of
// live operations here.
symbol.getOperation().valid = false;
symbol.getOperation().setInvalid();
}

void PySymbolTable::dunderDel(const std::string &name) {
40 changes: 32 additions & 8 deletions mlir/lib/Bindings/Python/IRModule.h
@@ -83,7 +83,7 @@ class PyObjectRef {
}

T *get() { return referrent; }
T *operator->() {
T *operator->() const {
assert(referrent && object);
return referrent;
}
@@ -229,7 +229,7 @@ class PyMlirContext {
static size_t getLiveCount();

/// Get a list of Python objects which are still in the live context map.
std::vector<PyOperation *> getLiveOperationObjects();
std::vector<nanobind::object> getLiveOperationObjects();

/// Gets the count of live operations associated with this context.
/// Used for testing.
@@ -254,9 +254,10 @@
void clearOperationsInside(PyOperationBase &op);
void clearOperationsInside(MlirOperation op);

/// Clears the operaiton _and_ all operations inside using
/// `clearOperation(MlirOperation)`.
void clearOperationAndInside(PyOperationBase &op);
/// Clears the operation _and_ all operations inside using
/// `clearOperation(MlirOperation)`. Requires that the liveOperations mutex
/// is held.
void clearOperationAndInsideLocked(PyOperationBase &op);

/// Gets the count of live modules associated with this context.
/// Used for testing.
@@ -278,6 +279,9 @@
struct ErrorCapture;

private:
// Similar to clearOperation, but requires the liveOperations mutex to be held
void clearOperationLocked(MlirOperation op);

// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -302,6 +306,9 @@
// attempt to access it will raise an error.
using LiveOperationMap =
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;

// liveOperationsMutex guards both liveOperations and the valid field of
// PyOperation objects in free-threading mode.
nanobind::ft_mutex liveOperationsMutex;

// Guarded by liveOperationsMutex in free-threading mode.
@@ -336,6 +343,7 @@ class BaseContextObject {
}

/// Accesses the context reference.
const PyMlirContextRef &getContext() const { return contextRef; }
PyMlirContextRef &getContext() { return contextRef; }

private:
@@ -677,6 +685,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
checkValid();
return operation;
}
MlirOperation getLocked() const {
checkValidLocked();
return operation;
}

PyOperationRef getRef() {
return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
@@ -692,6 +704,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
attached = false;
}
void checkValid() const;
void checkValidLocked() const;

/// Gets the owning block or raises an exception if the operation has no
/// owning block.
@@ -725,19 +738,27 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// parent context's live operations map, and sets the valid bit false.
void erase();

/// Invalidate the operation.
void setInvalid() { valid = false; }

/// Clones this operation.
nanobind::object clone(const nanobind::object &ip);

/// Invalidate the operation.
void setInvalid() {
nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex);
setInvalidLocked();
}
/// Like setInvalid(), but requires the liveOperations mutex to be held.
void setInvalidLocked() { valid = false; }
Review comment thread on setInvalidLocked():

@makslevental (Contributor) commented on May 13, 2025:

dumb question: is there a way to put a runtime assert/check here that the mutex is actually held? so that there's some way for people that don't read the doc strings (...like me...) to save themselves via compiling with asserts. e.g. i'm wondering if nanobind::ft_mutex::lock() is a no-op if the mutex is already held/locked by the thread?

if that's too tedious/onerous a change then i propose we rename the method to something like setInvalidWhileLocked (although that implies the method will be a no-op "when unlocked", which is not true)

(Contributor Author) replied:

That would be nice, wouldn't it.

Sadly PyMutex_IsLocked isn't a public CPython API; I filed python/cpython#134009 for that.

It would probably be possible for nanobind to clone that code, much as we're doing for PyUnstable_TryIncRef, but let me wait until we hear back on the CPython issue before I send them a PR adding a .is_locked() method on the nanobind ft_mutex.

(Contributor Author) replied:

I should add: there's no particular reason we have to use a nb::ft_mutex here; they aren't particularly special. We could use another mutex that does have this API. However, nb::ft_mutex has the nice property of being present if and only if we are in GIL-disabled mode.

(Contributor) replied:

> has the nice property of being present if and only if we are in GIL-disabled mode.

seems special enough to me. we can file this as a TODO and just do the tedious thing instead for now (rename the affected methods).

(Contributor) replied:

whoops, should've read everything first - guess you already did rename things. cool
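To make the discussion above concrete, here is a minimal sketch (not part of this PR, and not a nanobind API) of the kind of owner-tracking wrapper the thread alludes to, assuming PyMutex_IsLocked stays unavailable. The names CheckedMutex and assertHeld are invented for illustration; unlike nb::ft_mutex, a std::mutex-based wrapper would also exist, and cost something, in regular GIL builds.

// Hypothetical sketch only: a mutex wrapper that remembers its owning thread
// so that *Locked() methods could assert in debug builds that the lock is
// really held. CheckedMutex and assertHeld() are invented names.
#include <atomic>
#include <cassert>
#include <mutex>
#include <thread>

class CheckedMutex {
public:
  void lock() {
    impl.lock();
    owner.store(std::this_thread::get_id());
  }
  void unlock() {
    owner.store(std::thread::id());
    impl.unlock();
  }
  // Debug-only check; the assert compiles away when NDEBUG is defined.
  void assertHeld() const {
    assert(owner.load() == std::this_thread::get_id());
  }

private:
  std::mutex impl;
  std::atomic<std::thread::id> owner{std::thread::id()};
};

// Usage sketch: if liveOperationsMutex were a CheckedMutex, setInvalidLocked()
// could begin with getContext()->liveOperationsMutex.assertHeld().

Because the wrapper satisfies BasicLockable (lock()/unlock()), std::lock_guard works with it unchanged; the trade-off, as noted in the thread, is that it does not disappear in non-free-threaded builds the way nb::ft_mutex does.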


PyOperation(PyMlirContextRef contextRef, MlirOperation operation);

private:
static PyOperationRef createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
nanobind::object parentKeepAlive);

// Like erase(), but requires the caller to hold the liveOperationsMutex.
void eraseLocked();

MlirOperation operation;
nanobind::handle handle;
// Keeps the parent alive, regardless of whether it is an Operation or
@@ -748,6 +769,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
// ir_operation.py regarding testing corresponding lifetime guarantees.
nanobind::object parentKeepAlive;
bool attached = true;

// Guarded by 'context->liveOperationsMutex'. Valid objects must be present
// in context->liveOperations.
bool valid = true;

friend class PyOperationBase;