3434#include " mlir/IR/Visitors.h"
3535#include " mlir/Support/LLVM.h"
3636#include " llvm/ADT/APInt.h"
37+ #include " llvm/ADT/Bitset.h"
3738#include " llvm/ADT/DenseMap.h"
3839#include " llvm/ADT/MapVector.h"
3940#include " llvm/ADT/STLExtras.h"
@@ -483,38 +484,72 @@ void CutSet::addCut(Cut cut) {
483484
484485ArrayRef<Cut> CutSet::getCuts () const { return cuts; }
485486
486- void CutSet::finalize (
487- const CutRewriterOptions &options,
488- llvm::function_ref<std::optional<MatchedPattern>(Cut &)> matchCut) {
489- DenseSet<std::pair<ArrayRef<Value>, Operation *>> uniqueCuts;
490- unsigned uniqueCount = 0 ;
487+ // Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
488+ // exists another cut that is a subset of it. We use a bitset to represent the
489+ // inputs of each cut for efficient subset checking.
490+ static void removeDuplicateAndNonMinimalCuts (SmallVectorImpl<Cut> &cuts) {
491+ // First sort the cuts by input size (ascending). This ensures that when we
492+ // iterate through the cuts, we always encounter smaller cuts first, allowing
493+ // us to efficiently check for non-minimality. Stable sort to maintain
494+ // relative order of cuts with the same input size.
495+ std::stable_sort (cuts.begin (), cuts.end (), [](const Cut &a, const Cut &b) {
496+ return a.getInputSize () < b.getInputSize ();
497+ });
498+
499+ llvm::SmallVector<llvm::Bitset<64 >, 4 > inputBitMasks;
500+ DenseMap<Value, unsigned > inputIndices;
501+ auto getIndex = [&](Value v) -> unsigned {
502+ auto it = inputIndices.find (v);
503+ if (it != inputIndices.end ())
504+ return it->second ;
505+ unsigned index = inputIndices.size ();
506+ if (LLVM_UNLIKELY (index >= 64 ))
507+ llvm::report_fatal_error (
508+ " Too many unique inputs across cuts. Max 64 supported. Consider "
509+ " increasing the compile-time constant." );
510+ inputIndices[v] = index;
511+ return index;
512+ };
513+
491514 for (unsigned i = 0 ; i < cuts.size (); ++i) {
492515 auto &cut = cuts[i];
493516 // Create a unique identifier for the cut based on its inputs.
494- auto inputs = cut.inputs .getArrayRef ();
495-
496- // If the cut is a duplicate, skip it.
497- if (uniqueCuts.contains ({inputs, cut.getRoot ()}))
517+ llvm::Bitset<64 > inputsMask;
518+ for (auto input : cut.inputs .getArrayRef ())
519+ inputsMask.set (getIndex (input));
520+
521+ bool isUnique = llvm::all_of (
522+ inputBitMasks, [&](const llvm::Bitset<64 > &existingCutInputMask) {
523+ // If the bitset is a subset of the current inputsMask, it is not
524+ // unique
525+ return (existingCutInputMask & inputsMask) != existingCutInputMask;
526+ });
527+
528+ if (!isUnique)
498529 continue ;
499530
500- if (i != uniqueCount) {
501- // Move the unique cut to the front of the vector
502- // This maintains the order of cuts while removing duplicates
503- // by swapping with the last unique cut found.
504- cuts[uniqueCount] = std::move (cuts[i]);
505- }
506-
507- // Beaware of lifetime of ArrayRef. `cuts[uniqueCount]` is always valid
508- // after this point.
509- uniqueCuts.insert (
510- {cuts[uniqueCount].inputs .getArrayRef (), cuts[uniqueCount].getRoot ()});
511- ++uniqueCount;
531+ // If the cut is unique, keep it
532+ size_t uniqueCount = inputBitMasks.size ();
533+ if (i != uniqueCount)
534+ cuts[uniqueCount] = std::move (cut);
535+ inputBitMasks.push_back (inputsMask);
512536 }
513537
538+ unsigned uniqueCount = inputBitMasks.size ();
539+
514540 LLVM_DEBUG (llvm::dbgs () << " Original cuts: " << cuts.size ()
515541 << " Unique cuts: " << uniqueCount << " \n " );
542+
516543 // Resize the cuts vector to the number of unique cuts found
517544 cuts.resize (uniqueCount);
545+ }
546+
547+ void CutSet::finalize (
548+ const CutRewriterOptions &options,
549+ llvm::function_ref<std::optional<MatchedPattern>(Cut &)> matchCut) {
550+
551+ // First, remove duplicate and non-minimal cuts.
552+ removeDuplicateAndNonMinimalCuts (cuts);
518553
519554 // Maintain size limit by removing worst cuts
520555 if (cuts.size () > options.maxCutSizePerRoot ) {
@@ -523,7 +558,7 @@ void CutSet::finalize(
523558 // TODO: Make this configurable.
524559 // TODO: Implement pruning based on dominance.
525560
526- std::sort (cuts.begin (), cuts.end (), [](const Cut &a, const Cut &b) {
561+ std::stable_sort (cuts.begin (), cuts.end (), [](const Cut &a, const Cut &b) {
527562 if (a.getDepth () == b.getDepth ())
528563 return a.getInputSize () < b.getInputSize ();
529564 return a.getDepth () < b.getDepth ();
0 commit comments