-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Transforms][NFC] GreedyPatternRewriteDriver
: Use composition instead of inheritance
#92785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] GreedyPatternRewriteDriver
: Use composition instead of inheritance
#92785
Conversation
… instead of inheritance This commit simplifies the design of the `GreedyPatternRewriterDriver` class. This class used to inherit from both `PatternRewriter` and `RewriterBase::Listener` and then attached itself as a listener. In the new design, the class has a `PatternRewriter` field instead of inheriting from `PatternRewriter`, which is generally perferred in object-oriented programming.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit simplifies the design of the In the new design, the class has a Full diff: https://github.com/llvm/llvm-project/pull/92785.diff 2 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2562301e499dd..eb8297b65bf60 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -784,6 +784,7 @@ class IRRewriter : public RewriterBase {
/// place.
class PatternRewriter : public RewriterBase {
public:
+ PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
using RewriterBase::RewriterBase;
/// A hook used to indicate if the pattern rewriter can recover from failure
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff..597cb29ce911b 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter,
- public RewriterBase::Listener {
+class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
- void notifyOperationInserted(Operation *op, InsertPoint previous) override;
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override;
/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// reached. Return `true` if any IR was changed.
bool processWorklist();
+ /// The pattern rewriter that is used for making IR modifications and is
+ /// passed to rewrite patterns.
+ PatternRewriter rewriter;
+
/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), config(config), matcher(patterns)
+ : rewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
- setListener(&expensiveChecks);
+ rewriter.setListener(&expensiveChecks);
#else
- setListener(this);
+ rewriter.setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
- eraseOp(op);
+ rewriter.eraseOp(op);
changed = true;
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
- OpBuilder::InsertionGuard g(*this);
- setInsertionPoint(op);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
SmallVector<Value> replacements;
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
- *this, ofr.get<Attribute>(), resultType, op->getLoc());
+ rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
if (!constOp) {
// If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
replacementOps.insert(replacement.getDefiningOp());
}
for (Operation *op : replacementOps) {
- eraseOp(op);
+ rewriter.eraseOp(op);
}
materializationSucceeded = false;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
if (materializationSucceeded) {
- replaceOp(op, replacements);
+ rewriter.replaceOp(op, replacements);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
LogicalResult matchResult =
- matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+ matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
config.listener->notifyBlockErased(block);
}
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
- InsertPoint previous) {
+void GreedyPatternRewriteDriver::notifyOperationInserted(
+ Operation *op, OpBuilder::InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
@@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
bool continueRewrites = false;
int64_t iteration = 0;
- MLIRContext *ctx = getContext();
+ MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// `OperationFolder` CSE's constant ops (and may move them into parents
// regions to enable more aggressive CSE'ing).
- OperationFolder folder(getContext(), this);
+ OperationFolder folder(ctx, this);
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
- continueRewrites |= succeeded(simplifyRegions(*this, region));
+ continueRewrites |= succeeded(simplifyRegions(rewriter, region));
},
{®ion}, iteration);
} while (continueRewrites);
|
Co-authored-by: Markus Böck <[email protected]>
… instead of inheritance (llvm#92785) This commit simplifies the design of the `GreedyPatternRewriterDriver` class. This class used to inherit from both `PatternRewriter` and `RewriterBase::Listener` and then attached itself as a listener. In the new design, the class has a `PatternRewriter` field instead of inheriting from `PatternRewriter`, which is generally perferred in object-oriented programming. --------- Co-authored-by: Markus Böck <[email protected]> Signed-off-by: Hafidz Muzakky <[email protected]>
This commit simplifies the design of the
GreedyPatternRewriterDriver
class. This class used to inherit from bothPatternRewriter
andRewriterBase::Listener
and then attached itself as a listener.In the new design, the class has a
PatternRewriter
field instead of inheriting fromPatternRewriter
, which is generally perferred in object-oriented programming.