Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions include/circt/Reduce/GenericReductions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
#define CIRCT_REDUCE_GENERICREDUCTIONS_H

#include "circt/Reduce/Reduction.h"
#include <optional>

namespace circt {

/// Populate reduction patterns that are not specific to certain operations or
/// dialects
void populateGenericReducePatterns(MLIRContext *context,
ReducePatternSet &patterns);
/// dialects.
///
/// The optional `maxNumRewrites` parameter allows callers to override the
/// greedy rewrite budget used by reductions that rely on the canonicalizer
/// pass.
void populateGenericReducePatterns(
MLIRContext *context, ReducePatternSet &patterns,
std::optional<int64_t> maxNumRewrites = std::nullopt);

} // namespace circt

Expand Down
13 changes: 9 additions & 4 deletions lib/Reduce/GenericReductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,23 @@ struct MakeSymbolsPrivate : public Reduction {
// Reduction Registration
//===----------------------------------------------------------------------===//

static std::unique_ptr<Pass> createSimpleCanonicalizerPass() {
static std::unique_ptr<Pass>
createSimpleCanonicalizerPass(std::optional<int64_t> maxNumRewrites) {
GreedyRewriteConfig config;
config.setUseTopDownTraversal(true);
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
if (maxNumRewrites)
config.setMaxNumRewrites(*maxNumRewrites);
return createCanonicalizerPass(config);
}

void circt::populateGenericReducePatterns(MLIRContext *context,
ReducePatternSet &patterns) {
void circt::populateGenericReducePatterns(
MLIRContext *context, ReducePatternSet &patterns,
std::optional<int64_t> maxNumRewrites) {
patterns.add<PassReduction, 103>(context, createSymbolDCEPass());
patterns.add<PassReduction, 102>(context, createSimpleCanonicalizerPass());
patterns.add<PassReduction, 102>(
context, createSimpleCanonicalizerPass(maxNumRewrites));
patterns.add<PassReduction, 101>(context, createCSEPass());
patterns.add<MakeSymbolsPrivate, 100>();
patterns.add<UnusedSymbolPruner, 99>();
Expand Down
11 changes: 10 additions & 1 deletion tools/circt-reduce/circt-reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ static cl::opt<bool> verbose("v", cl::init(true),
cl::desc("Print reduction progress to stderr"),
cl::cat(mainCategory));

cl::opt<int64_t> maxNumRewrites(
"max-num-rewrites", cl::init(-1),
cl::desc("Maximum number of rewrites GreedyPatternRewriteDriver may "
"apply (negative value keeps the default)"),
cl::cat(mainCategory));

static cl::opt<unsigned>
maxChunks("max-chunks", cl::init(0),
cl::desc("Stop increasing granularity beyond this number of "
Expand Down Expand Up @@ -193,7 +199,10 @@ static LogicalResult execute(MLIRContext &context) {

// Gather a list of reduction patterns that we should try.
ReducePatternSet patterns;
populateGenericReducePatterns(&context, patterns);
std::optional<int64_t> maxNumRewritesOpt;
if (maxNumRewrites >= 0)
maxNumRewritesOpt = maxNumRewrites;
populateGenericReducePatterns(&context, patterns, maxNumRewritesOpt);
ReducePatternInterfaceCollection reducePatternCollection(&context);
reducePatternCollection.populateReducePatterns(patterns);
auto reductionFilter = [&](const Reduction &reduction) {
Expand Down