diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 01678a9719f90..f03c540d618cd 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -684,6 +684,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) { clearOperationsInside(opRef->getOperation()); } +void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { + MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, + void *userData) { + PyMlirContextRef &contextRef = *static_cast(userData); + contextRef->clearOperation(op); + return MlirWalkResult::MlirWalkResultAdvance; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + &op.getOperation().getContext(), MlirWalkPreOrder); +} + size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } pybind11::object PyMlirContext::contextEnter() { @@ -1112,12 +1123,16 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - auto &liveOperations = getContext()->liveOperations; - assert(liveOperations.count(operation.ptr) == 1 && - "destroying operation not in live map"); - liveOperations.erase(operation.ptr); - if (!isAttached()) { - mlirOperationDestroy(operation); + + // Otherwise, invalidate the operation and remove it from live map when it is + // attached. + if (isAttached()) { + getContext()->clearOperation(*this); + } else { + // And destroy it when it is detached, i.e. owned by Python, in which case + // all nested operations must be invalidated at removed from the live map as + // well. + erase(); } } @@ -1527,14 +1542,8 @@ py::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - // TODO: Fix memory hazards when erasing a tree of operations for which a deep - // Python reference to a child operation is live. All children should also - // have their `valid` bit set to false. - auto &liveOperations = getContext()->liveOperations; - if (liveOperations.count(operation.ptr)) - liveOperations.erase(operation.ptr); + getContext()->clearOperationAndInside(*this); mlirOperationDestroy(operation); - valid = false; } //------------------------------------------------------------------------------ diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b038a0c54d29b..8c34c11f70950 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -218,6 +218,8 @@ class PyMlirContext { /// This is useful for when some non-bindings code destroys the operation and /// the bindings need to made aware. For example, in the case when pass /// manager is run. + /// + /// Note that this does *NOT* clear the nested operations. void clearOperation(MlirOperation op); /// Clears all operations nested inside the given op using @@ -225,6 +227,10 @@ class PyMlirContext { void clearOperationsInside(PyOperationBase &op); void clearOperationsInside(MlirOperation op); + /// Clears the operaiton _and_ all operations inside using + /// `clearOperation(MlirOperation)`. + void clearOperationAndInside(PyOperationBase &op); + /// Gets the count of live modules associated with this context. /// Used for testing. size_t getLiveModuleCount(); @@ -246,6 +252,7 @@ class PyMlirContext { private: PyMlirContext(MlirContext context); + // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an diff --git a/mlir/test/python/live_operations.py b/mlir/test/python/live_operations.py new file mode 100644 index 0000000000000..892ed1715f6c7 --- /dev/null +++ b/mlir/test/python/live_operations.py @@ -0,0 +1,46 @@ +# RUN: %PYTHON %s +# It is sufficient that this doesn't assert. + +from mlir.ir import * + + +def createDetachedModule(): + module = Module.create() + with InsertionPoint(module.body): + # TODO: Python bindings are currently unaware that modules are also + # operations, so having a module erased won't trigger the cascading + # removal of live operations (#93337). Use a non-module operation + # instead. + nested = Operation.create("test.some_operation", regions=1) + + # When the operation is detached from parent, it is considered to be + # owned by Python. It will therefore be erased when the Python object + # is destroyed. + nested.detach_from_parent() + + # However, we create and maintain references to operations within + # `nested`. These references keep the corresponding operations in the + # "live" list even if they have been erased in C++, making them + # "zombie". If the C++ allocator reuses one of the address previously + # used for a now-"zombie" operation, this used to result in an + # assertion "cannot create detached operation that already exists" from + # the bindings code. Erasing the detached operation should result in + # removing all nested operations from the live list. + # + # Note that the assertion is not guaranteed since it depends on the + # behavior of the allocator on the C++ side, so this test mail fail + # intermittently. + with InsertionPoint(nested.regions[0].blocks.append()): + a = [Operation.create("test.some_other_operation") for i in range(100)] + return a + + +def createManyDetachedModules(): + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + for j in range(100): + a = createDetachedModule() + + +if __name__ == "__main__": + createManyDetachedModules()