@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
97
97
return cast<DILocalScope>(Scope)->getSubprogram ();
98
98
}
99
99
100
- // / Erase \p V from \p BB and move \II forward to avoid invalidating
101
- // / iterators.
102
- static void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
103
- BasicBlock &BB) {
104
- auto *Inst = cast<Instruction>(V);
105
- // Still used, don't erase.
106
- if (!Inst->use_empty ())
107
- return ;
108
- if (II != BB.rend () && Inst == &*II)
109
- ++II;
110
- Inst->eraseFromParent ();
111
- }
112
-
113
100
// / Return true if V is a splat of a value (which is used when multiplying a
114
101
// / matrix with a scalar).
115
102
static bool isSplat (Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
259
246
// / Return the ShapeInfo for the result of \p I, if it can be determined.
260
247
static std::optional<ShapeInfo>
261
248
computeShapeInfoForInst (Instruction *I,
262
- const ValueMap <Value *, ShapeInfo> &ShapeMap) {
249
+ const DenseMap <Value *, ShapeInfo> &ShapeMap) {
263
250
Value *M;
264
251
Value *N;
265
252
Value *K;
@@ -492,10 +479,16 @@ class LowerMatrixIntrinsics {
492
479
// / the result value of the instruction, with the only exceptions being store
493
480
// / instructions and the matrix_column_major_store intrinsics. For those, the
494
481
// / shape information indicates that those instructions should be lowered
495
- // / using shape information as well. A ValueMap is used so that when
496
- // / sub-passes like optimizeTransposes performs RAUW the map stays
497
- // / up-to-date.
498
- ValueMap<Value *, ShapeInfo> ShapeMap;
482
+ // / using shape information as well. Note that extra care is needed when
483
+ // / erasing or RAUW'ing a value that is present in ShapeMap. If the
484
+ // / replacement is also a matrix operation, use
485
+ // / updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
486
+ // / ShapeMap. We don't use ValueMap, as there are also cases where we do not
487
+ // / want to add shape information for a replacement instruction. When directly
488
+ // / erasing a value with an entry in ShapeMap, use
489
+ // / eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
490
+ // / accordingly.
491
+ DenseMap<Value *, ShapeInfo> ShapeMap;
499
492
500
493
// / List of instructions to remove. While lowering, we are not replacing all
501
494
// / users of a lowered instruction, if shape information is available and
@@ -759,6 +752,30 @@ class LowerMatrixIntrinsics {
759
752
return Operation (T0, Shape0.t (), T1, Shape1.t ());
760
753
}
761
754
755
+ // / Remove the ShapeMap entry for \p Inst (if one exists) and then erase
756
+ // / \p Inst itself from its parent.
757
+ void eraseFromParentAndRemoveFromShapeMap (Instruction *Inst) {
758
+ auto Iter = ShapeMap.find (Inst); // ShapeMap is a plain DenseMap keyed on Value *, so the entry must be dropped by hand.
759
+ if (Iter != ShapeMap.end ())
760
+ ShapeMap.erase (Iter);
761
+ Inst->eraseFromParent ();
762
+ }
763
+
764
+ // / Erase \p V from \p BB (and from ShapeMap) if it has no remaining uses; do
765
+ // / nothing otherwise. If \p II currently points at \p V, move \p II forward
+ // / first to avoid invalidating the iterator.
766
+ void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
767
+ BasicBlock &BB) {
768
+ auto *Inst = cast<Instruction>(V); // \p V is unconditionally treated as an Instruction.
769
+ // Still used, don't erase.
770
+ if (!Inst->use_empty ())
771
+ return ;
772
+ if (II != BB.rend () && Inst == &*II) // Step past Inst before it is erased.
773
+ ++II;
774
+ eraseFromParentAndRemoveFromShapeMap (Inst);
775
+ }
776
+
777
+ // / Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
778
+ // / entry for \p Old and replace all uses of \p Old with \p New.
762
779
void updateShapeAndReplaceAllUsesWith (Instruction &Old, Value *New) {
763
780
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
764
781
// with New. We should only add New if it supportsShapeInfo so we insert
@@ -872,13 +889,13 @@ class LowerMatrixIntrinsics {
872
889
873
890
void liftTranspose (Instruction &I) {
874
891
// Erase dead Instructions after lifting transposes from binops.
875
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
892
+ auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
876
893
if (T.use_empty ())
877
- T. eraseFromParent ( );
894
+ eraseFromParentAndRemoveFromShapeMap (&T );
878
895
if (A->use_empty ())
879
- cast<Instruction>(A)-> eraseFromParent ( );
896
+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(A));
880
897
if (A != B && B->use_empty ())
881
- cast<Instruction>(B)-> eraseFromParent ( );
898
+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(B));
882
899
};
883
900
884
901
Value *A, *B, *AT, *BT;
@@ -1476,7 +1493,7 @@ class LowerMatrixIntrinsics {
1476
1493
m_Value (Arg)))) {
1477
1494
auto *NewLoad = Builder.CreateLoad (Op->getType (), Arg);
1478
1495
Op->replaceAllUsesWith (NewLoad);
1479
- cast<Instruction>(Op)-> eraseFromParent ( );
1496
+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(Op));
1480
1497
return ;
1481
1498
} else if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1482
1499
m_Value (Arg)))) {
@@ -1845,15 +1862,15 @@ class LowerMatrixIntrinsics {
1845
1862
// Mark eliminated instructions as fused and remove them.
1846
1863
FusedInsts.insert (Store);
1847
1864
FusedInsts.insert (MatMul);
1848
- Store-> eraseFromParent ( );
1849
- MatMul-> eraseFromParent ( );
1865
+ eraseFromParentAndRemoveFromShapeMap (Store );
1866
+ eraseFromParentAndRemoveFromShapeMap (MatMul );
1850
1867
if (LoadOp0->hasNUses (0 )) {
1851
1868
FusedInsts.insert (LoadOp0);
1852
- LoadOp0-> eraseFromParent ( );
1869
+ eraseFromParentAndRemoveFromShapeMap (LoadOp0 );
1853
1870
}
1854
1871
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses (0 )) {
1855
1872
FusedInsts.insert (LoadOp1);
1856
- LoadOp1-> eraseFromParent ( );
1873
+ eraseFromParentAndRemoveFromShapeMap (LoadOp1 );
1857
1874
}
1858
1875
}
1859
1876
0 commit comments