Skip to content

[mlir][python] Clear PyOperations instead of invalidating them. #70044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,12 +705,12 @@ typedef enum MlirWalkOrder {
} MlirWalkOrder;

/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation a pointer to a `userData`.
/// operation and a pointer to a `userData`.
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);

/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
/// some context or other data into the callback.
/// context or other data into the callback.
MLIR_CAPI_EXPORTED
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder);
Expand Down
29 changes: 26 additions & 3 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,32 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}

void PyMlirContext::setOperationInvalid(MlirOperation op) {
if (liveOperations.contains(op.ptr))
liveOperations[op.ptr].second->setInvalid();
void PyMlirContext::clearOperation(MlirOperation op) {
auto it = liveOperations.find(op.ptr);
if (it != liveOperations.end()) {
it->second.second->setInvalid();
liveOperations.erase(it);
}
}

void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
typedef struct {
PyOperation &rootOp;
bool rootSeen;
} callBackData;
callBackData data{op.getOperation(), false};
// Mark all ops below the op that the passmanager will be rooted
// at (but not op itself - note the preorder) as invalid.
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
callBackData *data = static_cast<callBackData *>(userData);
if (LLVM_LIKELY(data->rootSeen))
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
Expand Down
14 changes: 10 additions & 4 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class PyMlirContext;
class DefaultingPyMlirContext;
class PyModule;
class PyOperation;
class PyOperationBase;
class PyType;
class PySymbolTable;
class PyValue;
Expand Down Expand Up @@ -209,10 +210,15 @@ class PyMlirContext {
/// place.
size_t clearLiveOperations();

/// Sets an operation invalid. This is useful for when some non-bindings
/// code destroys the operation and the bindings need to made aware. For
/// example, in the case when pass manager is run.
void setOperationInvalid(MlirOperation op);
/// Removes an operation from the live operations map and sets it invalid.
/// This is useful for when some non-bindings code destroys the operation and
/// the bindings need to made aware. For example, in the case when pass
/// manager is run.
void clearOperation(MlirOperation op);

/// Clears all operations nested inside the given op using
/// `clearOperation(MlirOperation)`.
void clearOperationsInside(PyOperationBase &op);

/// Gets the count of live modules associated with this context.
/// Used for testing.
Expand Down
20 changes: 1 addition & 19 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
[](PyPassManager &passManager, PyOperationBase &op,
bool invalidateOps) {
if (invalidateOps) {
typedef struct {
PyOperation &rootOp;
bool rootSeen;
} callBackData;
callBackData data{op.getOperation(), false};
// Mark all ops below the op that the passmanager will be rooted
// at (but not op itself - note the preorder) as invalid.
MlirOperationWalkCallback invalidatingCallback =
[](MlirOperation op, void *userData) {
callBackData *data = static_cast<callBackData *>(userData);
if (LLVM_LIKELY(data->rootSeen))
data->rootOp.getOperation()
.getContext()
->setOperationInvalid(op);
else
data->rootSeen = true;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
op.getOperation().getContext()->clearOperationsInside(op);
}
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
Expand Down
25 changes: 23 additions & 2 deletions mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ def testRunPipelineError():
@run
def testPostPassOpInvalidation():
with Context() as ctx:
log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())

# CHECK: invalidate_ops=False
log("invalidate_ops=False")

# CHECK: live ops: 0
log_op_count()

module = ModuleOp.parse(
"""
module {
Expand All @@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
"""
)

# CHECK: invalidate_ops=False
log("invalidate_ops=False")
# CHECK: live ops: 1
log_op_count()

outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
Expand All @@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)

# CHECK: live ops: 4
log_op_count()

PassManager.parse("builtin.module(canonicalize)").run(
module, invalidate_ops=False
)
Expand All @@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
# CHECK: invalidate_ops=True
log("invalidate_ops=True")

# CHECK: live ops: 4
log_op_count()

module = ModuleOp.parse(
"""
module {
Expand All @@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
func_op = module.body.operations[1]
inner_const_op = func_op.body.blocks[0].operations[0]

# CHECK: live ops: 4
log_op_count()

PassManager.parse("builtin.module(canonicalize)").run(module)

# CHECK: live ops: 1
log_op_count()

try:
log(func_op)
except RuntimeError as e:
Expand Down