Skip to content

Commit fa19ef7

Browse files
[mlir][python] Clear PyOperations instead of invalidating them. (#70044)
`PyOperations` are Python-level handles to `Operation *` instances. When the latter are modified by C++, the former need to be invalidated. #69746 implements such invalidation mechanism by setting all `PyReferences` to `invalid`. However, that is not enough: they also need to be removed from the `liveOperations` map since other parts of the code (such as `PyOperation::createDetached`) assume that that map only contains valid refs. This is required to actually solve the issue in #69730.
1 parent 9abf3df commit fa19ef7

File tree

5 files changed

+62
-30
lines changed

5 files changed

+62
-30
lines changed

mlir/include/mlir-c/IR.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -705,12 +705,12 @@ typedef enum MlirWalkOrder {
705705
} MlirWalkOrder;
706706

707707
/// Operation walker type. The handler is passed an (opaque) reference to an
708-
/// operation a pointer to a `userData`.
708+
/// operation and a pointer to a `userData`.
709709
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
710710

711711
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
712712
/// `*userData` is passed to the callback as well and can be used to tunnel some
713-
/// some context or other data into the callback.
713+
/// context or other data into the callback.
714714
MLIR_CAPI_EXPORTED
715715
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
716716
void *userData, MlirWalkOrder walkOrder);

mlir/lib/Bindings/Python/IRCore.cpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -635,9 +635,32 @@ size_t PyMlirContext::clearLiveOperations() {
635635
return numInvalidated;
636636
}
637637

638-
void PyMlirContext::setOperationInvalid(MlirOperation op) {
639-
if (liveOperations.contains(op.ptr))
640-
liveOperations[op.ptr].second->setInvalid();
638+
void PyMlirContext::clearOperation(MlirOperation op) {
639+
auto it = liveOperations.find(op.ptr);
640+
if (it != liveOperations.end()) {
641+
it->second.second->setInvalid();
642+
liveOperations.erase(it);
643+
}
644+
}
645+
646+
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
647+
typedef struct {
648+
PyOperation &rootOp;
649+
bool rootSeen;
650+
} callBackData;
651+
callBackData data{op.getOperation(), false};
652+
// Mark all ops below the op that the passmanager will be rooted
653+
// at (but not op itself - note the preorder) as invalid.
654+
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
655+
void *userData) {
656+
callBackData *data = static_cast<callBackData *>(userData);
657+
if (LLVM_LIKELY(data->rootSeen))
658+
data->rootOp.getOperation().getContext()->clearOperation(op);
659+
else
660+
data->rootSeen = true;
661+
};
662+
mlirOperationWalk(op.getOperation(), invalidatingCallback,
663+
static_cast<void *>(&data), MlirWalkPreOrder);
641664
}
642665

643666
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

mlir/lib/Bindings/Python/IRModule.h

+10-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class PyMlirContext;
3737
class DefaultingPyMlirContext;
3838
class PyModule;
3939
class PyOperation;
40+
class PyOperationBase;
4041
class PyType;
4142
class PySymbolTable;
4243
class PyValue;
@@ -209,10 +210,15 @@ class PyMlirContext {
209210
/// place.
210211
size_t clearLiveOperations();
211212

212-
/// Sets an operation invalid. This is useful for when some non-bindings
213-
/// code destroys the operation and the bindings need to made aware. For
214-
/// example, in the case when pass manager is run.
215-
void setOperationInvalid(MlirOperation op);
213+
/// Removes an operation from the live operations map and sets it invalid.
214+
/// This is useful for when some non-bindings code destroys the operation and
215+
/// the bindings need to made aware. For example, in the case when pass
216+
/// manager is run.
217+
void clearOperation(MlirOperation op);
218+
219+
/// Clears all operations nested inside the given op using
220+
/// `clearOperation(MlirOperation)`.
221+
void clearOperationsInside(PyOperationBase &op);
216222

217223
/// Gets the count of live modules associated with this context.
218224
/// Used for testing.

mlir/lib/Bindings/Python/Pass.cpp

+1-19
Original file line numberDiff line numberDiff line change
@@ -119,25 +119,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
119119
[](PyPassManager &passManager, PyOperationBase &op,
120120
bool invalidateOps) {
121121
if (invalidateOps) {
122-
typedef struct {
123-
PyOperation &rootOp;
124-
bool rootSeen;
125-
} callBackData;
126-
callBackData data{op.getOperation(), false};
127-
// Mark all ops below the op that the passmanager will be rooted
128-
// at (but not op itself - note the preorder) as invalid.
129-
MlirOperationWalkCallback invalidatingCallback =
130-
[](MlirOperation op, void *userData) {
131-
callBackData *data = static_cast<callBackData *>(userData);
132-
if (LLVM_LIKELY(data->rootSeen))
133-
data->rootOp.getOperation()
134-
.getContext()
135-
->setOperationInvalid(op);
136-
else
137-
data->rootSeen = true;
138-
};
139-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
140-
static_cast<void *>(&data), MlirWalkPreOrder);
122+
op.getOperation().getContext()->clearOperationsInside(op);
141123
}
142124
// Actually run the pass manager.
143125
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());

mlir/test/python/pass_manager.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ def testRunPipelineError():
176176
@run
177177
def testPostPassOpInvalidation():
178178
with Context() as ctx:
179+
log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
180+
181+
# CHECK: invalidate_ops=False
182+
log("invalidate_ops=False")
183+
184+
# CHECK: live ops: 0
185+
log_op_count()
186+
179187
module = ModuleOp.parse(
180188
"""
181189
module {
@@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
188196
"""
189197
)
190198

191-
# CHECK: invalidate_ops=False
192-
log("invalidate_ops=False")
199+
# CHECK: live ops: 1
200+
log_op_count()
193201

194202
outer_const_op = module.body.operations[0]
195203
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
@@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
206214
# CHECK: %[[VAL1]] = arith.constant 10 : i64
207215
log(inner_const_op)
208216

217+
# CHECK: live ops: 4
218+
log_op_count()
219+
209220
PassManager.parse("builtin.module(canonicalize)").run(
210221
module, invalidate_ops=False
211222
)
@@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
222233
# CHECK: invalidate_ops=True
223234
log("invalidate_ops=True")
224235

236+
# CHECK: live ops: 4
237+
log_op_count()
238+
225239
module = ModuleOp.parse(
226240
"""
227241
module {
@@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
237251
func_op = module.body.operations[1]
238252
inner_const_op = func_op.body.blocks[0].operations[0]
239253

254+
# CHECK: live ops: 4
255+
log_op_count()
256+
240257
PassManager.parse("builtin.module(canonicalize)").run(module)
258+
259+
# CHECK: live ops: 1
260+
log_op_count()
261+
241262
try:
242263
log(func_op)
243264
except RuntimeError as e:

0 commit comments

Comments
 (0)