diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b5720b7ad8b21..076abb6c54249 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -635,6 +635,75 @@ class PyOpOperandIterator { MlirOpOperand opOperand; }; +#if !defined(Py_GIL_DISABLED) +inline void enableTryIncRef(nb::handle obj) noexcept {} +inline bool tryIncRef(nb::handle obj) noexcept { + if (Py_REFCNT(obj.ptr()) > 0) { + Py_INCREF(obj.ptr()); + return true; + } + return false; +} + +#elif PY_VERSION_HEX >= 0x030E00A5 + +// CPython 3.14 provides an unstable API for these. +inline void enableTryIncRef(nb::handle obj) noexcept { + PyUnstable_EnableTryIncRef(obj.ptr()); +} +inline bool tryIncRef(nb::handle obj) noexcept { + return PyUnstable_TryIncRef(obj.ptr()); +} + +#else + +// For CPython 3.13 there is no API for this, and so we must implement our own. +// This code originates from https://github.com/wjakob/nanobind/pull/865/files. +void enableTryIncRef(nb::handle h) noexcept { + // Since this is called during object construction, we know that we have + // the only reference to the object and can use a non-atomic write. 
+  PyObject *obj = h.ptr(); + assert(obj->ob_ref_shared == 0); + obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF; + } + + bool tryIncRef(nb::handle h) noexcept { + PyObject *obj = h.ptr(); + // See + // https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761 + uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local); + local += 1; + if (local == 0) { + // immortal + return true; + } + if (_Py_IsOwnedByCurrentThread(obj)) { + _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local); +#ifdef Py_REF_DEBUG + _Py_INCREF_IncRefTotal(); +#endif + return true; + } + Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared); + for (;;) { + // If the shared refcount is zero and the object is either merged + // or may not have weak references, then we cannot incref it. + if (shared == 0 || shared == _Py_REF_MERGED) { + return false; + } + + if (_Py_atomic_compare_exchange_ssize(&obj->ob_ref_shared, &shared, + shared + + (1 << _Py_REF_SHARED_SHIFT))) { +#ifdef Py_REF_DEBUG + _Py_INCREF_IncRefTotal(); +#endif + return true; + } + } +} +#endif + } // namespace //------------------------------------------------------------------------------ @@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } -std::vector PyMlirContext::getLiveOperationObjects() { - std::vector liveObjects; +std::vector PyMlirContext::getLiveOperationObjects() { + std::vector liveObjects; nb::ft_lock_guard lock(liveOperationsMutex); - for (auto &entry : liveOperations) - liveObjects.push_back(entry.second.second); + for (auto &entry : liveOperations) { + // It is not safe to unconditionally increment the reference count here + // because an operation that is in the process of being deleted by another + // thread may still be present in the map. 
+ if (tryIncRef(entry.second.first)) { + liveObjects.push_back(nb::steal(entry.second.first)); + } + } return liveObjects; } @@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() { { nb::ft_lock_guard lock(liveOperationsMutex); std::swap(operations, liveOperations); + for (auto &op : operations) + op.second.second->setInvalidLocked(); } - for (auto &op : operations) - op.second.second->setInvalid(); size_t numInvalidated = operations.size(); return numInvalidated; } -void PyMlirContext::clearOperation(MlirOperation op) { - PyOperation *py_op; - { - nb::ft_lock_guard lock(liveOperationsMutex); - auto it = liveOperations.find(op.ptr); - if (it == liveOperations.end()) { - return; - } - py_op = it->second.second; - liveOperations.erase(it); +void PyMlirContext::clearOperationLocked(MlirOperation op) { + auto it = liveOperations.find(op.ptr); + if (it == liveOperations.end()) { + return; } - py_op->setInvalid(); + PyOperation *py_op = it->second.second; + py_op->setInvalidLocked(); + liveOperations.erase(it); +} + +void PyMlirContext::clearOperation(MlirOperation op) { + nb::ft_lock_guard lock(liveOperationsMutex); + clearOperationLocked(op); } void PyMlirContext::clearOperationsInside(PyOperationBase &op) { @@ -766,14 +842,14 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) { clearOperationsInside(opRef->getOperation()); } -void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { +void PyMlirContext::clearOperationAndInsideLocked(PyOperationBase &op) { MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, void *userData) { PyMlirContextRef &contextRef = *static_cast(userData); - contextRef->clearOperation(op); + contextRef->clearOperationLocked(op); return MlirWalkResult::MlirWalkResultAdvance; }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, + mlirOperationWalk(op.getOperation().getLocked(), invalidatingCallback, &op.getOperation().getContext(), MlirWalkPreOrder); } @@ -1203,6 +1279,8 @@ 
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) : BaseContextObject(std::move(contextRef)), operation(operation) {} PyOperation::~PyOperation() { + PyMlirContextRef context = getContext(); + nb::ft_lock_guard lock(context->liveOperationsMutex); // If the operation has already been invalidated there is nothing to do. if (!valid) return; @@ -1210,12 +1288,14 @@ PyOperation::~PyOperation() { // Otherwise, invalidate the operation and remove it from live map when it is // attached. if (isAttached()) { - getContext()->clearOperation(*this); + // Since the operation was valid, we know that it is this object present + // in the map, not some other object. + context->liveOperations.erase(operation.ptr); } else { // And destroy it when it is detached, i.e. owned by Python, in which case // all nested operations must be invalidated at removed from the live map as // well. - erase(); + eraseLocked(); } } @@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, // Create. PyOperationRef unownedOperation = makeObjectRef(std::move(contextRef), operation); + enableTryIncRef(unownedOperation.getObject()); unownedOperation->handle = unownedOperation.getObject(); if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); @@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, nb::ft_lock_guard lock(contextRef->liveOperationsMutex); auto &liveOperations = contextRef->liveOperations; auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. 
- PyOperationRef result = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(result.getObject(), result.get()); - return result; + if (it != liveOperations.end()) { + PyOperation *existing = it->second.second; + nb::handle pyRef = it->second.first; + + // Try to increment the reference count of the existing entry. This can fail + // if the object is in the process of being destroyed by another thread. + if (tryIncRef(pyRef)) { + return PyOperationRef(existing, nb::steal(pyRef)); + } + + // Mark the existing entry as invalid, since we are about to replace it. + existing->setInvalidLocked(); } - // Use existing. - PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); + + // Create a new wrapper object. + PyOperationRef result = createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + liveOperations[operation.ptr] = + std::make_pair(result.getObject(), result.get()); + return result; } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, @@ -1297,6 +1386,11 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, } void PyOperation::checkValid() const { + nb::ft_lock_guard lock(getContext()->liveOperationsMutex); + checkValidLocked(); +} + +void PyOperation::checkValidLocked() const { if (!valid) { throw std::runtime_error("the operation has been invalidated"); } @@ -1638,12 +1732,17 @@ nb::object PyOperation::createOpView() { return nb::cast(PyOpView(getRef().getObject())); } -void PyOperation::erase() { - checkValid(); - getContext()->clearOperationAndInside(*this); +void PyOperation::eraseLocked() { + checkValidLocked(); + getContext()->clearOperationAndInsideLocked(*this); mlirOperationDestroy(operation); } +void PyOperation::erase() { + nb::ft_lock_guard lock(getContext()->liveOperationsMutex); + eraseLocked(); +} + namespace { /// CRTP base 
class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed @@ -2324,7 +2423,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) { // The operation is also erased, so we must invalidate it. There may be Python // references to this operation so we don't want to delete it from the list of // live operations here. - symbol.getOperation().valid = false; + symbol.getOperation().setInvalid(); } void PySymbolTable::dunderDel(const std::string &name) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9befcce725bb7..c823b6deb4b26 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -83,7 +83,7 @@ class PyObjectRef { } T *get() { return referrent; } - T *operator->() { + T *operator->() const { assert(referrent && object); return referrent; } @@ -229,7 +229,7 @@ class PyMlirContext { static size_t getLiveCount(); /// Get a list of Python objects which are still in the live context map. - std::vector getLiveOperationObjects(); + std::vector getLiveOperationObjects(); /// Gets the count of live operations associated with this context. /// Used for testing. @@ -254,9 +254,10 @@ class PyMlirContext { void clearOperationsInside(PyOperationBase &op); void clearOperationsInside(MlirOperation op); - /// Clears the operaiton _and_ all operations inside using - /// `clearOperation(MlirOperation)`. - void clearOperationAndInside(PyOperationBase &op); + /// Clears the operation _and_ all operations inside using + /// `clearOperation(MlirOperation)`. Requires that liveOperations mutex is + /// held. + void clearOperationAndInsideLocked(PyOperationBase &op); /// Gets the count of live modules associated with this context. /// Used for testing. 
@@ -278,6 +279,9 @@ class PyMlirContext { struct ErrorCapture; private: + // Similar to clearOperation, but requires the liveOperations mutex to be held + void clearOperationLocked(MlirOperation op); + // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an @@ -302,6 +306,9 @@ class PyMlirContext { // attempt to access it will raise an error. using LiveOperationMap = llvm::DenseMap>; + + // liveOperationsMutex guards both liveOperations and the valid field of + // PyOperation objects in free-threading mode. nanobind::ft_mutex liveOperationsMutex; // Guarded by liveOperationsMutex in free-threading mode. @@ -336,6 +343,7 @@ class BaseContextObject { } /// Accesses the context reference. + const PyMlirContextRef &getContext() const { return contextRef; } PyMlirContextRef &getContext() { return contextRef; } private: @@ -677,6 +685,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject { checkValid(); return operation; } + MlirOperation getLocked() const { + checkValidLocked(); + return operation; + } PyOperationRef getRef() { return PyOperationRef(this, nanobind::borrow(handle)); @@ -692,6 +704,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { attached = false; } void checkValid() const; + void checkValidLocked() const; /// Gets the owning block or raises an exception if the operation has no /// owning block. @@ -725,12 +738,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// parent context's live operations map, and sets the valid bit false. void erase(); - /// Invalidate the operation. - void setInvalid() { valid = false; } - /// Clones this operation. nanobind::object clone(const nanobind::object &ip); + /// Invalidate the operation. 
+  void setInvalid() { + nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex); + setInvalidLocked(); + } + /// Like setInvalid(), but requires the liveOperations mutex to be held. + void setInvalidLocked() { valid = false; } + PyOperation(PyMlirContextRef contextRef, MlirOperation operation); private: @@ -738,6 +756,9 @@ MlirOperation operation, nanobind::object parentKeepAlive); + // Like erase(), but requires the caller to hold the liveOperationsMutex. + void eraseLocked(); + MlirOperation operation; nanobind::handle handle; // Keeps the parent alive, regardless of whether it is an Operation or @@ -748,6 +769,9 @@ // ir_operation.py regarding testing corresponding lifetime guarantees. nanobind::object parentKeepAlive; bool attached = true; + + // Guarded by 'context->liveOperationsMutex'. Valid objects must be present + // in context->liveOperations. bool valid = true; friend class PyOperationBase; diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py index 6e1a668346872..0c74e6c5d74f4 100644 --- a/mlir/test/python/multithreaded_tests.py +++ b/mlir/test/python/multithreaded_tests.py @@ -40,6 +40,7 @@ import importlib.util import os import sys +import textwrap import threading import tempfile import unittest @@ -512,6 +513,51 @@ def _original_test_create_module_with_consts(self): arith.constant(dtype, py_values[2]) + def test_check_pyoperation_race(self): + # Regression test for a race where: + # * one thread is in the process of destroying a PyOperation, + # * while simultaneously another thread looks up the PyOperation in + # the liveOperations map and attempts to increase its reference count. + # It is illegal to attempt to revive an object that is in the process of + # being deleted, and this was producing races and heap use-after-frees. 
+ num_workers = 40 + num_runs = 20 + + barrier = threading.Barrier(num_workers) + + def walk_operations(op): + _ = op.operation.name + for region in op.operation.regions: + for block in region: + for op in block: + walk_operations(op) + + with Context(): + mlir_module = Module.parse( + textwrap.dedent( + """ + module @m { + func.func public @main(%arg0: tensor) -> (tensor) { + return %arg0 : tensor + } + } + """ + ) + ) + + def closure(): + barrier.wait() + + for _ in range(num_runs): + walk_operations(mlir_module) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for i in range(num_workers): + futures.append(executor.submit(closure)) + assert len(list(f.result() for f in futures)) == num_workers + + if __name__ == "__main__": # Do not run the tests on CPython with GIL if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():