diff --git a/include/circt/Reduce/GenericReductions.h b/include/circt/Reduce/GenericReductions.h index d7fb87341a03..0fb9a3908e43 100644 --- a/include/circt/Reduce/GenericReductions.h +++ b/include/circt/Reduce/GenericReductions.h @@ -10,13 +10,19 @@ #define CIRCT_REDUCE_GENERICREDUCTIONS_H #include "circt/Reduce/Reduction.h" +#include namespace circt { /// Populate reduction patterns that are not specific to certain operations or -/// dialects -void populateGenericReducePatterns(MLIRContext *context, - ReducePatternSet &patterns); +/// dialects. +/// +/// The optional `maxNumRewrites` parameter allows callers to override the +/// greedy rewrite budget used by reductions that rely on the canonicalizer +/// pass. +void populateGenericReducePatterns( + MLIRContext *context, ReducePatternSet &patterns, + std::optional maxNumRewrites = std::nullopt); } // namespace circt diff --git a/lib/Reduce/GenericReductions.cpp b/lib/Reduce/GenericReductions.cpp index 5723ecb9e39f..62f99fc1371d 100644 --- a/lib/Reduce/GenericReductions.cpp +++ b/lib/Reduce/GenericReductions.cpp @@ -96,18 +96,23 @@ struct MakeSymbolsPrivate : public Reduction { // Reduction Registration //===----------------------------------------------------------------------===// -static std::unique_ptr createSimpleCanonicalizerPass() { +static std::unique_ptr +createSimpleCanonicalizerPass(std::optional maxNumRewrites) { GreedyRewriteConfig config; config.setUseTopDownTraversal(true); config.setRegionSimplificationLevel( mlir::GreedySimplifyRegionLevel::Disabled); + if (maxNumRewrites) + config.setMaxNumRewrites(*maxNumRewrites); return createCanonicalizerPass(config); } -void circt::populateGenericReducePatterns(MLIRContext *context, - ReducePatternSet &patterns) { +void circt::populateGenericReducePatterns( + MLIRContext *context, ReducePatternSet &patterns, + std::optional maxNumRewrites) { patterns.add(context, createSymbolDCEPass()); - patterns.add(context, createSimpleCanonicalizerPass()); + patterns.add( + context, createSimpleCanonicalizerPass(maxNumRewrites)); patterns.add(context, createCSEPass()); patterns.add(); patterns.add(); diff --git a/tools/circt-reduce/circt-reduce.cpp b/tools/circt-reduce/circt-reduce.cpp index c01b7d577ef5..4e2f1580b016 100644 --- a/tools/circt-reduce/circt-reduce.cpp +++ b/tools/circt-reduce/circt-reduce.cpp @@ -104,6 +104,12 @@ static cl::opt verbose("v", cl::init(true), cl::desc("Print reduction progress to stderr"), cl::cat(mainCategory)); +cl::opt maxNumRewrites( + "max-num-rewrites", cl::init(-1), + cl::desc("Maximum number of rewrites GreedyPatternRewriteDriver may " + "apply (negative value keeps the default)"), + cl::cat(mainCategory)); + static cl::opt maxChunks("max-chunks", cl::init(0), cl::desc("Stop increasing granularity beyond this number of " @@ -193,7 +199,10 @@ static LogicalResult execute(MLIRContext &context) { // Gather a list of reduction patterns that we should try. ReducePatternSet patterns; - populateGenericReducePatterns(&context, patterns); + std::optional maxNumRewritesOpt; + if (maxNumRewrites >= 0) + maxNumRewritesOpt = maxNumRewrites; + populateGenericReducePatterns(&context, patterns, maxNumRewritesOpt); ReducePatternInterfaceCollection reducePatternCollection(&context); reducePatternCollection.populateReducePatterns(patterns); auto reductionFilter = [&](const Reduction &reduction) {