Skip to content

[mlir][Transforms][NFC] GreedyPatternRewriteDriver: Use composition instead of inheritance #92785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

This commit simplifies the design of the GreedyPatternRewriterDriver class. This class used to inherit from both PatternRewriter and RewriterBase::Listener and then attached itself as a listener.

In the new design, the class has a PatternRewriter field instead of inheriting from PatternRewriter, which is generally perferred in object-oriented programming.

… instead of inheritance

This commit simplifies the design of the `GreedyPatternRewriterDriver` class. This class used to inherit from both `PatternRewriter` and `RewriterBase::Listener` and then attached itself as a listener.

In the new design, the class has a `PatternRewriter` field instead of inheriting from `PatternRewriter`, which is generally perferred in object-oriented programming.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 20, 2024
@llvmbot
Copy link
Member

llvmbot commented May 20, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit simplifies the design of the GreedyPatternRewriterDriver class. This class used to inherit from both PatternRewriter and RewriterBase::Listener and then attached itself as a listener.

In the new design, the class has a PatternRewriter field instead of inheriting from PatternRewriter, which is generally perferred in object-oriented programming.


Full diff: https://github.com/llvm/llvm-project/pull/92785.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+1)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+22-18)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2562301e499dd..eb8297b65bf60 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -784,6 +784,7 @@ class IRRewriter : public RewriterBase {
 /// place.
 class PatternRewriter : public RewriterBase {
 public:
+  PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
   using RewriterBase::RewriterBase;
 
   /// A hook used to indicate if the pattern rewriter can recover from failure
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff..597cb29ce911b 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
 /// This abstract class manages the worklist and contains helper methods for
 /// rewriting ops on the worklist. Derived classes specify how ops are added
 /// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter,
-                                   public RewriterBase::Listener {
+class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 protected:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the specified operation was inserted. Update the
   /// worklist as needed: The operation is enqueued depending on scope and
   /// strict mode.
-  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
 
   /// Notify the driver that the specified operation was removed. Update the
   /// worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// reached. Return `true` if any IR was changed.
   bool processWorklist();
 
+  /// The pattern rewriter that is used for making IR modifications and is
+  /// passed to rewrite patterns.
+  PatternRewriter rewriter;
+
   /// The worklist for this transformation keeps track of the operations that
   /// need to be (re)visited.
 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), config(config), matcher(patterns)
+    : rewriter(ctx), config(config), matcher(patterns)
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       // clang-format off
       , expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // Send IR notifications to the debug handler. This handler will then forward
   // all notifications to this GreedyPatternRewriteDriver.
-  setListener(&expensiveChecks);
+  rewriter.setListener(&expensiveChecks);
 #else
-  setListener(this);
+  rewriter.setListener(this);
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 
     // If the operation is trivially dead - remove it.
     if (isOpTriviallyDead(op)) {
-      eraseOp(op);
+      rewriter.eraseOp(op);
       changed = true;
 
       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         // Op results can be replaced with `foldResults`.
         assert(foldResults.size() == op->getNumResults() &&
                "folder produced incorrect number of results");
-        OpBuilder::InsertionGuard g(*this);
-        setInsertionPoint(op);
+        OpBuilder::InsertionGuard g(rewriter);
+        rewriter.setInsertionPoint(op);
         SmallVector<Value> replacements;
         bool materializationSucceeded = true;
         for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           }
           // Materialize Attributes as SSA values.
           Operation *constOp = op->getDialect()->materializeConstant(
-              *this, ofr.get<Attribute>(), resultType, op->getLoc());
+              rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
 
           if (!constOp) {
             // If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
               replacementOps.insert(replacement.getDefiningOp());
             }
             for (Operation *op : replacementOps) {
-              eraseOp(op);
+              rewriter.eraseOp(op);
             }
 
             materializationSucceeded = false;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         }
 
         if (materializationSucceeded) {
-          replaceOp(op, replacements);
+          rewriter.replaceOp(op, replacements);
           changed = true;
           LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
     LogicalResult matchResult =
-        matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+        matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
 
     if (succeeded(matchResult)) {
       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
     config.listener->notifyBlockErased(block);
 }
 
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
-                                                         InsertPoint previous) {
+void GreedyPatternRewriteDriver::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
@@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
   bool continueRewrites = false;
   int64_t iteration = 0;
-  MLIRContext *ctx = getContext();
+  MLIRContext *ctx = rewriter.getContext();
   do {
     // Check if the iteration limit was reached.
     if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 
     // `OperationFolder` CSE's constant ops (and may move them into parents
     // regions to enable more aggressive CSE'ing).
-    OperationFolder folder(getContext(), this);
+    OperationFolder folder(ctx, this);
     auto insertKnownConstant = [&](Operation *op) {
       // Check for existing constants when populating the worklist. This avoids
       // accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
           if (config.enableRegionSimplification)
-            continueRewrites |= succeeded(simplifyRegions(*this, region));
+            continueRewrites |= succeeded(simplifyRegions(rewriter, region));
         },
         {&region}, iteration);
   } while (continueRewrites);

@matthias-springer matthias-springer merged commit 6b3e000 into main Jun 8, 2024
5 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/greedy_pattern_composition branch June 8, 2024 08:26
nekoshirro pushed a commit to nekoshirro/Alchemist-LLVM that referenced this pull request Jun 9, 2024
… instead of inheritance (llvm#92785)

This commit simplifies the design of the `GreedyPatternRewriterDriver`
class. This class used to inherit from both `PatternRewriter` and
`RewriterBase::Listener` and then attached itself as a listener.

In the new design, the class has a `PatternRewriter` field instead of
inheriting from `PatternRewriter`, which is generally perferred in
object-oriented programming.

---------

Co-authored-by: Markus Böck <[email protected]>
Signed-off-by: Hafidz Muzakky <[email protected]>
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants