18
18
#include " mlir/Transforms/Passes.h"
19
19
#include " mlir/Transforms/RegionUtils.h"
20
20
#include " llvm/ADT/STLExtras.h"
21
- #include " llvm/Support/Casting.h"
22
21
#include " llvm/Support/GenericIteratedDominanceFrontier.h"
23
22
24
23
namespace mlir {
@@ -158,6 +157,8 @@ class MemorySlotPromotionAnalyzer {
158
157
const DataLayout &dataLayout;
159
158
};
160
159
160
+ using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t >>;
161
+
161
162
// / The MemorySlotPromoter handles the state of promoting a memory slot. It
162
163
// / wraps a slot and its associated allocator. This will perform the mutation of
163
164
// / IR.
@@ -166,7 +167,8 @@ class MemorySlotPromoter {
166
167
MemorySlotPromoter (MemorySlot slot, PromotableAllocationOpInterface allocator,
167
168
OpBuilder &builder, DominanceInfo &dominance,
168
169
const DataLayout &dataLayout, MemorySlotPromotionInfo info,
169
- const Mem2RegStatistics &statistics);
170
+ const Mem2RegStatistics &statistics,
171
+ BlockIndexCache &blockIndexCache);
170
172
171
173
// / Actually promotes the slot by mutating IR. Promoting a slot DOES
172
174
// / invalidate the MemorySlotPromotionInfo of other slots. Preparation of
@@ -207,16 +209,21 @@ class MemorySlotPromoter {
207
209
const DataLayout &dataLayout;
208
210
MemorySlotPromotionInfo info;
209
211
const Mem2RegStatistics &statistics;
212
+
213
+ // / Shared cache of block indices of specific regions.
214
+ BlockIndexCache &blockIndexCache;
210
215
};
211
216
212
217
} // namespace
213
218
214
219
MemorySlotPromoter::MemorySlotPromoter (
215
220
MemorySlot slot, PromotableAllocationOpInterface allocator,
216
221
OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
217
- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
222
+ MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
223
+ BlockIndexCache &blockIndexCache)
218
224
: slot(slot), allocator(allocator), builder(builder), dominance(dominance),
219
- dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
225
+ dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
226
+ blockIndexCache(blockIndexCache) {
220
227
#ifndef NDEBUG
221
228
auto isResultOrNewBlockArgument = [&]() {
222
229
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr ))
@@ -500,15 +507,29 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
500
507
}
501
508
}
502
509
510
+ // / Gets or creates a block index mapping for `region`.
511
+ static const DenseMap<Block *, size_t > &
512
+ getOrCreateBlockIndices (BlockIndexCache &blockIndexCache, Region *region) {
513
+ auto [it, inserted] = blockIndexCache.try_emplace (region);
514
+ if (!inserted)
515
+ return it->second ;
516
+
517
+ DenseMap<Block *, size_t > &blockIndices = it->second ;
518
+ SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (*region);
519
+ for (auto [index , block] : llvm::enumerate (topologicalOrder))
520
+ blockIndices[block] = index ;
521
+ return blockIndices;
522
+ }
523
+
503
524
// / Sorts `ops` according to dominance. Relies on the topological order of basic
504
- // / blocks to get a deterministic ordering.
505
- static void dominanceSort (SmallVector<Operation *> &ops, Region ®ion) {
525
+ // / blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
526
+ // / potentially expensive recomputation of a block index map.
527
+ static void dominanceSort (SmallVector<Operation *> &ops, Region ®ion,
528
+ BlockIndexCache &blockIndexCache) {
506
529
// Produce a topological block order and construct a map to lookup the indices
507
530
// of blocks.
508
- DenseMap<Block *, size_t > topoBlockIndices;
509
- SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (region);
510
- for (auto [index , block] : llvm::enumerate (topologicalOrder))
511
- topoBlockIndices[block] = index ;
531
+ const DenseMap<Block *, size_t > &topoBlockIndices =
532
+ getOrCreateBlockIndices (blockIndexCache, ®ion);
512
533
513
534
// Combining the topological order of the basic blocks together with block
514
535
// internal operation order guarantees a deterministic, dominance respecting
@@ -527,7 +548,8 @@ void MemorySlotPromoter::removeBlockingUses() {
527
548
llvm::make_first_range (info.userToBlockingUses ));
528
549
529
550
// Sort according to dominance.
530
- dominanceSort (usersToRemoveUses, *slot.ptr .getParentBlock ()->getParent ());
551
+ dominanceSort (usersToRemoveUses, *slot.ptr .getParentBlock ()->getParent (),
552
+ blockIndexCache);
531
553
532
554
llvm::SmallVector<Operation *> toErase;
533
555
// List of all replaced values in the slot.
@@ -605,20 +627,25 @@ void MemorySlotPromoter::promoteSlot() {
605
627
606
628
LogicalResult mlir::tryToPromoteMemorySlots (
607
629
ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
608
- const DataLayout &dataLayout, Mem2RegStatistics statistics) {
630
+ const DataLayout &dataLayout, DominanceInfo &dominance,
631
+ Mem2RegStatistics statistics) {
609
632
bool promotedAny = false ;
610
633
634
+ // A cache that stores deterministic block indices which are used to determine
635
+ // a valid operation modification order. The block index maps are computed
636
+ // lazily and cached to avoid expensive recomputation.
637
+ BlockIndexCache blockIndexCache;
638
+
611
639
for (PromotableAllocationOpInterface allocator : allocators) {
612
640
for (MemorySlot slot : allocator.getPromotableSlots ()) {
613
641
if (slot.ptr .use_empty ())
614
642
continue ;
615
643
616
- DominanceInfo dominance;
617
644
MemorySlotPromotionAnalyzer analyzer (slot, dominance, dataLayout);
618
645
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo ();
619
646
if (info) {
620
647
MemorySlotPromoter (slot, allocator, builder, dominance, dataLayout,
621
- std::move (*info), statistics)
648
+ std::move (*info), statistics, blockIndexCache )
622
649
.promoteSlot ();
623
650
promotedAny = true ;
624
651
}
@@ -640,6 +667,10 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
640
667
641
668
bool changed = false ;
642
669
670
+ auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
671
+ const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove (scopeOp);
672
+ auto &dominance = getAnalysis<DominanceInfo>();
673
+
643
674
for (Region ®ion : scopeOp->getRegions ()) {
644
675
if (region.getBlocks ().empty ())
645
676
continue ;
@@ -655,16 +686,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
655
686
allocators.emplace_back (allocator);
656
687
});
657
688
658
- auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
659
- const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove (scopeOp);
660
-
661
689
// Attempt promoting until no promotion succeeds.
662
690
if (failed (tryToPromoteMemorySlots (allocators, builder, dataLayout,
663
- statistics)))
691
+ dominance, statistics)))
664
692
break ;
665
693
666
694
changed = true ;
667
- getAnalysisManager ().invalidate ({});
668
695
}
669
696
}
670
697
if (!changed)
0 commit comments