Skip to content

Commit 129659c

Browse files
[mlir][python] Clear PyOperations instead of invalidating them.
`PyOperations` are Python-level handles to `Operation *` instances. When the latter are modified by C++, the former need to be invalidated. llvm#69746 implements such invalidation mechanism by setting all `PyReferences` to `invalid`. However, that is not enough: they also need to be removed from the `liveOperations` map since other parts of the code (such as `PyOperation::createDetached`) assume that that map only contains valid refs. This is required to actually solve the issue in llvm#69730.
1 parent c45466c commit 129659c

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -635,9 +635,12 @@ size_t PyMlirContext::clearLiveOperations() {
635635
return numInvalidated;
636636
}
637637

638-
void PyMlirContext::setOperationInvalid(MlirOperation op) {
639-
if (liveOperations.contains(op.ptr))
640-
liveOperations[op.ptr].second->setInvalid();
638+
void PyMlirContext::clearOperation(MlirOperation op) {
639+
auto it = liveOperations.find(op.ptr);
640+
if (it != liveOperations.end()) {
641+
it->second.second->setInvalid();
642+
liveOperations.erase(it);
643+
}
641644
}
642645

643646
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,11 @@ class PyMlirContext {
209209
/// place.
210210
size_t clearLiveOperations();
211211

212-
/// Sets an operation invalid. This is useful for when some non-bindings
213-
/// code destroys the operation and the bindings need to made aware. For
214-
/// example, in the case when pass manager is run.
215-
void setOperationInvalid(MlirOperation op);
212+
/// Removes an operation from the live operations map and sets it invalid.
213+
/// This is useful for when some non-bindings code destroys the operation and
214+
/// the bindings need to made aware. For example, in the case when pass
215+
/// manager is run.
216+
void clearOperation(MlirOperation op);
216217

217218
/// Gets the count of live modules associated with this context.
218219
/// Used for testing.

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
130130
[](MlirOperation op, void *userData) {
131131
callBackData *data = static_cast<callBackData *>(userData);
132132
if (LLVM_LIKELY(data->rootSeen))
133-
data->rootOp.getOperation()
134-
.getContext()
135-
->setOperationInvalid(op);
133+
data->rootOp.getOperation().getContext()->clearOperation(
134+
op);
136135
else
137136
data->rootSeen = true;
138137
};

mlir/test/python/pass_manager.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ def testRunPipelineError():
176176
@run
177177
def testPostPassOpInvalidation():
178178
with Context() as ctx:
179+
log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
180+
181+
# CHECK: invalidate_ops=False
182+
log("invalidate_ops=False")
183+
184+
# CHECK: live ops: 0
185+
log_op_count()
186+
179187
module = ModuleOp.parse(
180188
"""
181189
module {
@@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
188196
"""
189197
)
190198

191-
# CHECK: invalidate_ops=False
192-
log("invalidate_ops=False")
199+
# CHECK: live ops: 1
200+
log_op_count()
193201

194202
outer_const_op = module.body.operations[0]
195203
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
@@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
206214
# CHECK: %[[VAL1]] = arith.constant 10 : i64
207215
log(inner_const_op)
208216

217+
# CHECK: live ops: 4
218+
log_op_count()
219+
209220
PassManager.parse("builtin.module(canonicalize)").run(
210221
module, invalidate_ops=False
211222
)
@@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
222233
# CHECK: invalidate_ops=True
223234
log("invalidate_ops=True")
224235

236+
# CHECK: live ops: 4
237+
log_op_count()
238+
225239
module = ModuleOp.parse(
226240
"""
227241
module {
@@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
237251
func_op = module.body.operations[1]
238252
inner_const_op = func_op.body.blocks[0].operations[0]
239253

254+
# CHECK: live ops: 4
255+
log_op_count()
256+
240257
PassManager.parse("builtin.module(canonicalize)").run(module)
258+
259+
# CHECK: live ops: 1
260+
log_op_count()
261+
241262
try:
242263
log(func_op)
243264
except RuntimeError as e:

0 commit comments

Comments
 (0)