Skip to content

[mlir][Transforms][NFC] GreedyPatternRewriteDriver: Use composition instead of inheritance #92785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ class IRRewriter : public RewriterBase {
/// place.
class PatternRewriter : public RewriterBase {
public:
explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
using RewriterBase::RewriterBase;

/// A hook used to indicate if the pattern rewriter can recover from failure
Expand Down
40 changes: 22 additions & 18 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public PatternRewriter,
public RewriterBase::Listener {
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
Expand All @@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override;

/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
Expand All @@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// reached. Return `true` if any IR was changed.
bool processWorklist();

/// The pattern rewriter that is used for making IR modifications and is
/// passed to rewrite patterns.
PatternRewriter rewriter;

/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
Expand Down Expand Up @@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: PatternRewriter(ctx), config(config), matcher(patterns)
: rewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
Expand All @@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
setListener(&expensiveChecks);
rewriter.setListener(&expensiveChecks);
#else
setListener(this);
rewriter.setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}

Expand Down Expand Up @@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {

// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
eraseOp(op);
rewriter.eraseOp(op);
changed = true;

LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
Expand Down Expand Up @@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
SmallVector<Value> replacements;
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
Expand All @@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
rewriter, ofr.get<Attribute>(), resultType, op->getLoc());

if (!constOp) {
// If materialization fails, cleanup any operations generated for
Expand All @@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
replacementOps.insert(replacement.getDefiningOp());
}
for (Operation *op : replacementOps) {
eraseOp(op);
rewriter.eraseOp(op);
}

materializationSucceeded = false;
Expand All @@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}

if (materializationSucceeded) {
replaceOp(op, replacements);
rewriter.replaceOp(op, replacements);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Expand Down Expand Up @@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);

if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
Expand Down Expand Up @@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
config.listener->notifyBlockErased(block);
}

void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
InsertPoint previous) {
void GreedyPatternRewriteDriver::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
Expand Down Expand Up @@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
if (++iteration > config.maxIterations &&
Expand All @@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {

// `OperationFolder` CSE's constant ops (and may move them into parents
// regions to enable more aggressive CSE'ing).
OperationFolder folder(getContext(), this);
OperationFolder folder(ctx, this);
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Expand Down Expand Up @@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
continueRewrites |= succeeded(simplifyRegions(*this, region));
continueRewrites |= succeeded(simplifyRegions(rewriter, region));
},
{&region}, iteration);
} while (continueRewrites);
Expand Down
Loading