diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7b121d4df3286..5659230a03d8c 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -705,12 +705,12 @@ typedef enum MlirWalkOrder { } MlirWalkOrder; /// Operation walker type. The handler is passed an (opaque) reference to an -/// operation a pointer to a `userData`. +/// operation and a pointer to a `userData`. typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData); /// Walks operation `op` in `walkOrder` and calls `callback` on that operation. /// `*userData` is passed to the callback as well and can be used to tunnel some -/// some context or other data into the callback. +/// context or other data into the callback. MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index a8ea1a381edb9..7cfea31dbb2e8 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -635,9 +635,32 @@ size_t PyMlirContext::clearLiveOperations() { return numInvalidated; } -void PyMlirContext::setOperationInvalid(MlirOperation op) { - if (liveOperations.contains(op.ptr)) - liveOperations[op.ptr].second->setInvalid(); +void PyMlirContext::clearOperation(MlirOperation op) { + auto it = liveOperations.find(op.ptr); + if (it != liveOperations.end()) { + it->second.second->setInvalid(); + liveOperations.erase(it); + } +} + +void PyMlirContext::clearOperationsInside(PyOperationBase &op) { + typedef struct { + PyOperation &rootOp; + bool rootSeen; + } callBackData; + callBackData data{op.getOperation(), false}; + // Mark all ops below the op that the passmanager will be rooted + // at (but not op itself - note the preorder) as invalid. + MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, + void *userData) { + callBackData *data = static_cast(userData); + if (LLVM_LIKELY(data->rootSeen)) + data->rootOp.getOperation().getContext()->clearOperation(op); + else + data->rootSeen = true; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + static_cast(&data), MlirWalkPreOrder); } size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 26292885711a4..01ee4975d0e9a 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -37,6 +37,7 @@ class PyMlirContext; class DefaultingPyMlirContext; class PyModule; class PyOperation; +class PyOperationBase; class PyType; class PySymbolTable; class PyValue; @@ -209,10 +210,15 @@ class PyMlirContext { /// place. size_t clearLiveOperations(); - /// Sets an operation invalid. This is useful for when some non-bindings - /// code destroys the operation and the bindings need to made aware. For - /// example, in the case when pass manager is run. - void setOperationInvalid(MlirOperation op); + /// Removes an operation from the live operations map and sets it invalid. + /// This is useful for when some non-bindings code destroys the operation and + /// the bindings need to made aware. For example, in the case when pass + /// manager is run. + void clearOperation(MlirOperation op); + + /// Clears all operations nested inside the given op using + /// `clearOperation(MlirOperation)`. + void clearOperationsInside(PyOperationBase &op); /// Gets the count of live modules associated with this context. /// Used for testing. diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 2175cea79960c..588a8e25414c6 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -119,25 +119,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](PyPassManager &passManager, PyOperationBase &op, bool invalidateOps) { if (invalidateOps) { - typedef struct { - PyOperation &rootOp; - bool rootSeen; - } callBackData; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. - MlirOperationWalkCallback invalidatingCallback = - [](MlirOperation op, void *userData) { - callBackData *data = static_cast(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation() - .getContext() - ->setOperationInvalid(op); - else - data->rootSeen = true; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast(&data), MlirWalkPreOrder); + op.getOperation().getContext()->clearOperationsInside(op); } // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index e7f79ddc75113..0face028b73ff 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -176,6 +176,14 @@ def testRunPipelineError(): @run def testPostPassOpInvalidation(): with Context() as ctx: + log_op_count = lambda: log("live ops:", ctx._get_live_operation_count()) + + # CHECK: invalidate_ops=False + log("invalidate_ops=False") + + # CHECK: live ops: 0 + log_op_count() + module = ModuleOp.parse( """ module { @@ -188,8 +196,8 @@ def testPostPassOpInvalidation(): """ ) - # CHECK: invalidate_ops=False - log("invalidate_ops=False") + # CHECK: live ops: 1 + log_op_count() outer_const_op = module.body.operations[0] # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64 @@ -206,6 +214,9 @@ def testPostPassOpInvalidation(): # CHECK: %[[VAL1]] = arith.constant 10 : i64 log(inner_const_op) + # CHECK: live ops: 4 + log_op_count() + PassManager.parse("builtin.module(canonicalize)").run( module, invalidate_ops=False ) @@ -222,6 +233,9 @@ def testPostPassOpInvalidation(): # CHECK: invalidate_ops=True log("invalidate_ops=True") + # CHECK: live ops: 4 + log_op_count() + module = ModuleOp.parse( """ module { @@ -237,7 +251,14 @@ def testPostPassOpInvalidation(): func_op = module.body.operations[1] inner_const_op = func_op.body.blocks[0].operations[0] + # CHECK: live ops: 4 + log_op_count() + PassManager.parse("builtin.module(canonicalize)").run(module) + + # CHECK: live ops: 1 + log_op_count() + try: log(func_op) except RuntimeError as e: