Skip to content

Commit 67897d7

Browse files
authored
[mlir][py] invalidate nested operations when parent is deleted (#93339)
When an operation is erased in Python, its children may still be in the "live" list inside Python bindings. After this, if some of the newly allocated operations happen to reuse the same pointer address, this will trigger an assertion in the bindings. This assertion would be incorrect because the operations aren't actually live. Make sure we remove the child operations from the "live" list when erasing the parent. This also concentrates responsibility for the removal from the "live" list and invalidation in a single place. Note that this requires the IR to be sufficiently structurally valid so that a walk through it can succeed. If this invariant was broken by, e.g., a C++ pass called from Python, there isn't much we can do.
1 parent 540a36a commit 67897d7

File tree

3 files changed

+75
-13
lines changed

3 files changed

+75
-13
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

+22-13
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
697697
clearOperationsInside(opRef->getOperation());
698698
}
699699

700+
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
701+
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
702+
void *userData) {
703+
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
704+
contextRef->clearOperation(op);
705+
return MlirWalkResult::MlirWalkResultAdvance;
706+
};
707+
mlirOperationWalk(op.getOperation(), invalidatingCallback,
708+
&op.getOperation().getContext(), MlirWalkPreOrder);
709+
}
710+
700711
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
701712

702713
pybind11::object PyMlirContext::contextEnter() {
@@ -1125,12 +1136,16 @@ PyOperation::~PyOperation() {
11251136
// If the operation has already been invalidated there is nothing to do.
11261137
if (!valid)
11271138
return;
1128-
auto &liveOperations = getContext()->liveOperations;
1129-
assert(liveOperations.count(operation.ptr) == 1 &&
1130-
"destroying operation not in live map");
1131-
liveOperations.erase(operation.ptr);
1132-
if (!isAttached()) {
1133-
mlirOperationDestroy(operation);
1139+
1140+
// Otherwise, invalidate the operation and remove it from live map when it is
1141+
// attached.
1142+
if (isAttached()) {
1143+
getContext()->clearOperation(*this);
1144+
} else {
1145+
// And destroy it when it is detached, i.e. owned by Python, in which case
1146+
// all nested operations must be invalidated and removed from the live map as
1147+
// well.
1148+
erase();
11341149
}
11351150
}
11361151

@@ -1540,14 +1555,8 @@ py::object PyOperation::createOpView() {
15401555

15411556
void PyOperation::erase() {
15421557
checkValid();
1543-
// TODO: Fix memory hazards when erasing a tree of operations for which a deep
1544-
// Python reference to a child operation is live. All children should also
1545-
// have their `valid` bit set to false.
1546-
auto &liveOperations = getContext()->liveOperations;
1547-
if (liveOperations.count(operation.ptr))
1548-
liveOperations.erase(operation.ptr);
1558+
getContext()->clearOperationAndInside(*this);
15491559
mlirOperationDestroy(operation);
1550-
valid = false;
15511560
}
15521561

15531562
//------------------------------------------------------------------------------

mlir/lib/Bindings/Python/IRModule.h

+7
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,19 @@ class PyMlirContext {
218218
/// This is useful for when some non-bindings code destroys the operation and
219219
/// the bindings need to be made aware. For example, in the case when pass
220220
/// manager is run.
221+
///
222+
/// Note that this does *NOT* clear the nested operations.
221223
void clearOperation(MlirOperation op);
222224

223225
/// Clears all operations nested inside the given op using
224226
/// `clearOperation(MlirOperation)`.
225227
void clearOperationsInside(PyOperationBase &op);
226228
void clearOperationsInside(MlirOperation op);
227229

230+
/// Clears the operation _and_ all operations inside using
231+
/// `clearOperation(MlirOperation)`.
232+
void clearOperationAndInside(PyOperationBase &op);
233+
228234
/// Gets the count of live modules associated with this context.
229235
/// Used for testing.
230236
size_t getLiveModuleCount();
@@ -246,6 +252,7 @@ class PyMlirContext {
246252

247253
private:
248254
PyMlirContext(MlirContext context);
255+
249256
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
250257
// preserving the relationship that an MlirContext maps to a single
251258
// PyMlirContext wrapper. This could be replaced in the future with an

mlir/test/python/live_operations.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# RUN: %PYTHON %s
2+
# It is sufficient that this doesn't assert.
3+
4+
from mlir.ir import *
5+
6+
7+
def createDetachedModule():
8+
module = Module.create()
9+
with InsertionPoint(module.body):
10+
# TODO: Python bindings are currently unaware that modules are also
11+
# operations, so having a module erased won't trigger the cascading
12+
# removal of live operations (#93337). Use a non-module operation
13+
# instead.
14+
nested = Operation.create("test.some_operation", regions=1)
15+
16+
# When the operation is detached from parent, it is considered to be
17+
# owned by Python. It will therefore be erased when the Python object
18+
# is destroyed.
19+
nested.detach_from_parent()
20+
21+
# However, we create and maintain references to operations within
22+
# `nested`. These references keep the corresponding operations in the
23+
# "live" list even if they have been erased in C++, making them
24+
# "zombie". If the C++ allocator reuses one of the addresses previously
25+
# used for a now-"zombie" operation, this used to result in an
26+
# assertion "cannot create detached operation that already exists" from
27+
# the bindings code. Erasing the detached operation should result in
28+
# removing all nested operations from the live list.
29+
#
30+
# Note that the assertion is not guaranteed since it depends on the
31+
# behavior of the allocator on the C++ side, so this test may fail
32+
# intermittently.
33+
with InsertionPoint(nested.regions[0].blocks.append()):
34+
a = [Operation.create("test.some_other_operation") for i in range(100)]
35+
return a
36+
37+
38+
def createManyDetachedModules():
39+
with Context() as ctx, Location.unknown():
40+
ctx.allow_unregistered_dialects = True
41+
for j in range(100):
42+
a = createDetachedModule()
43+
44+
45+
if __name__ == "__main__":
46+
createManyDetachedModules()

0 commit comments

Comments
 (0)