Skip to content

Commit c6efcc9

Browse files
authored
[MLIR][Mem2Reg] Improve performance by avoiding recomputations (#91444)
This commit ensures that Mem2Reg reuses the `DominanceInfo` as well as block index maps to avoid expensive recomputations. Due to the recent migration to `OpBuilder`, the promotion of a slot does no longer replace blocks. Having stable blocks makes the `DominanceInfo` preservable and additionally allows to cache block index maps between different promotions. Performance measurements on very large functions show an up to 4x speedup by these changes.
1 parent 9c09b08 commit c6efcc9

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

mlir/include/mlir/Transforms/Mem2Reg.h

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct Mem2RegStatistics {
2828
LogicalResult
2929
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
3030
OpBuilder &builder, const DataLayout &dataLayout,
31+
DominanceInfo &dominance,
3132
Mem2RegStatistics statistics = {});
3233

3334
} // namespace mlir

mlir/lib/Transforms/Mem2Reg.cpp

+46-19
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "mlir/Transforms/Passes.h"
1919
#include "mlir/Transforms/RegionUtils.h"
2020
#include "llvm/ADT/STLExtras.h"
21-
#include "llvm/Support/Casting.h"
2221
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
2322

2423
namespace mlir {
@@ -158,6 +157,8 @@ class MemorySlotPromotionAnalyzer {
158157
const DataLayout &dataLayout;
159158
};
160159

160+
using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;
161+
161162
/// The MemorySlotPromoter handles the state of promoting a memory slot. It
162163
/// wraps a slot and its associated allocator. This will perform the mutation of
163164
/// IR.
@@ -166,7 +167,8 @@ class MemorySlotPromoter {
166167
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
167168
OpBuilder &builder, DominanceInfo &dominance,
168169
const DataLayout &dataLayout, MemorySlotPromotionInfo info,
169-
const Mem2RegStatistics &statistics);
170+
const Mem2RegStatistics &statistics,
171+
BlockIndexCache &blockIndexCache);
170172

171173
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
172174
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
@@ -207,16 +209,21 @@ class MemorySlotPromoter {
207209
const DataLayout &dataLayout;
208210
MemorySlotPromotionInfo info;
209211
const Mem2RegStatistics &statistics;
212+
213+
/// Shared cache of block indices of specific regions.
214+
BlockIndexCache &blockIndexCache;
210215
};
211216

212217
} // namespace
213218

214219
MemorySlotPromoter::MemorySlotPromoter(
215220
MemorySlot slot, PromotableAllocationOpInterface allocator,
216221
OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
217-
MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
222+
MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
223+
BlockIndexCache &blockIndexCache)
218224
: slot(slot), allocator(allocator), builder(builder), dominance(dominance),
219-
dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
225+
dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
226+
blockIndexCache(blockIndexCache) {
220227
#ifndef NDEBUG
221228
auto isResultOrNewBlockArgument = [&]() {
222229
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -500,15 +507,29 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
500507
}
501508
}
502509

510+
/// Gets or creates a block index mapping for `region`.
511+
static const DenseMap<Block *, size_t> &
512+
getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
513+
auto [it, inserted] = blockIndexCache.try_emplace(region);
514+
if (!inserted)
515+
return it->second;
516+
517+
DenseMap<Block *, size_t> &blockIndices = it->second;
518+
SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(*region);
519+
for (auto [index, block] : llvm::enumerate(topologicalOrder))
520+
blockIndices[block] = index;
521+
return blockIndices;
522+
}
523+
503524
/// Sorts `ops` according to dominance. Relies on the topological order of basic
504-
/// blocks to get a deterministic ordering.
505-
static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
525+
/// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
526+
/// potentially expensive recomputation of a block index map.
527+
static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
528+
BlockIndexCache &blockIndexCache) {
506529
// Produce a topological block order and construct a map to lookup the indices
507530
// of blocks.
508-
DenseMap<Block *, size_t> topoBlockIndices;
509-
SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
510-
for (auto [index, block] : llvm::enumerate(topologicalOrder))
511-
topoBlockIndices[block] = index;
531+
const DenseMap<Block *, size_t> &topoBlockIndices =
532+
getOrCreateBlockIndices(blockIndexCache, &region);
512533

513534
// Combining the topological order of the basic blocks together with block
514535
// internal operation order guarantees a deterministic, dominance respecting
@@ -527,7 +548,8 @@ void MemorySlotPromoter::removeBlockingUses() {
527548
llvm::make_first_range(info.userToBlockingUses));
528549

529550
// Sort according to dominance.
530-
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
551+
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
552+
blockIndexCache);
531553

532554
llvm::SmallVector<Operation *> toErase;
533555
// List of all replaced values in the slot.
@@ -605,20 +627,25 @@ void MemorySlotPromoter::promoteSlot() {
605627

606628
LogicalResult mlir::tryToPromoteMemorySlots(
607629
ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
608-
const DataLayout &dataLayout, Mem2RegStatistics statistics) {
630+
const DataLayout &dataLayout, DominanceInfo &dominance,
631+
Mem2RegStatistics statistics) {
609632
bool promotedAny = false;
610633

634+
// A cache that stores deterministic block indices which are used to determine
635+
// a valid operation modification order. The block index maps are computed
636+
// lazily and cached to avoid expensive recomputation.
637+
BlockIndexCache blockIndexCache;
638+
611639
for (PromotableAllocationOpInterface allocator : allocators) {
612640
for (MemorySlot slot : allocator.getPromotableSlots()) {
613641
if (slot.ptr.use_empty())
614642
continue;
615643

616-
DominanceInfo dominance;
617644
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
618645
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
619646
if (info) {
620647
MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
621-
std::move(*info), statistics)
648+
std::move(*info), statistics, blockIndexCache)
622649
.promoteSlot();
623650
promotedAny = true;
624651
}
@@ -640,6 +667,10 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
640667

641668
bool changed = false;
642669

670+
auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
671+
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
672+
auto &dominance = getAnalysis<DominanceInfo>();
673+
643674
for (Region &region : scopeOp->getRegions()) {
644675
if (region.getBlocks().empty())
645676
continue;
@@ -655,16 +686,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
655686
allocators.emplace_back(allocator);
656687
});
657688

658-
auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
659-
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
660-
661689
// Attempt promoting until no promotion succeeds.
662690
if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
663-
statistics)))
691+
dominance, statistics)))
664692
break;
665693

666694
changed = true;
667-
getAnalysisManager().invalidate({});
668695
}
669696
}
670697
if (!changed)

0 commit comments

Comments
 (0)