Commit 9ae50ae (parent 52ccefb)
[Matrix] Use DenseMap for ShapeMap instead of ValueMap. (llvm#118282)
ValueMap automatically updates entries with the new value when the old value is RAUW'd. This can cause instructions that are not expected to have shape info (e.g. the shufflevector in the added test case) to be added to the map, which leads to incorrect results. ValueMap was originally used for the transpose optimizations, but those now all go through updateShapeAndReplaceAllUsesWith, which takes care of updating the shape info as needed. This fixes a crash in the newly added test cases. PR: llvm#118282
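For illustration, the following minimal standalone sketch (not part of the commit) shows the behavioral difference the message describes: llvm::ValueMap re-keys an entry to the replacement value when its key is RAUW'd, while llvm::DenseMap leaves its keys untouched. The ShapeInfo struct, the function "f", and the add/mul instructions below are arbitrary stand-ins, not the pass's real code; the assumption is a program linked against LLVM's Core libraries.

// Sketch: ValueMap vs. DenseMap behavior under RAUW (stand-in types, not the pass's code).
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

namespace {
// Simplified stand-in for the pass's ShapeInfo.
struct ShapeInfo {
  unsigned NumRows = 0;
  unsigned NumColumns = 0;
};
} // namespace

int main() {
  LLVMContext Ctx;
  Module M("valuemap-vs-densemap", Ctx);
  IRBuilder<> Builder(Ctx);

  // Build a tiny function with two instructions so one can be RAUW'd with the other.
  auto *FTy = FunctionType::get(Builder.getVoidTy(), {Builder.getInt32Ty()}, false);
  auto *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
  auto *Entry = BasicBlock::Create(Ctx, "entry", F);
  Builder.SetInsertPoint(Entry);
  Value *Arg = F->getArg(0);
  auto *Old = cast<Instruction>(Builder.CreateAdd(Arg, Arg, "old"));
  auto *New = cast<Instruction>(Builder.CreateMul(Arg, Arg, "new"));
  Builder.CreateRetVoid();

  ValueMap<Value *, ShapeInfo> VM;
  DenseMap<Value *, ShapeInfo> DM;
  VM[Old] = ShapeInfo{2, 2};
  DM[Old] = ShapeInfo{2, 2};

  // RAUW triggers ValueMap's callback handle: the entry keyed by Old is moved
  // to New, so New silently "inherits" the shape info. DenseMap keys are plain
  // pointers and stay where they are.
  Old->replaceAllUsesWith(New);

  outs() << "ValueMap: count(Old)=" << VM.count(Old)
         << " count(New)=" << VM.count(New) << "\n"; // prints 0 and 1
  outs() << "DenseMap: count(Old)=" << DM.count(Old)
         << " count(New)=" << DM.count(New) << "\n"; // prints 1 and 0
  return 0;
}

In the pass, this is how a replacement such as a shufflevector could silently pick up shape info it should not have when ShapeMap was a ValueMap; with a DenseMap, entries are only added or removed through the explicit helpers in this change.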

File tree: 3 files changed, +94 −27 lines
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 44 additions & 27 deletions
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
   return cast<DILocalScope>(Scope)->getSubprogram();
 }
 
-/// Erase \p V from \p BB and move \II forward to avoid invalidating
-/// iterators.
-static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
-                                   BasicBlock &BB) {
-  auto *Inst = cast<Instruction>(V);
-  // Still used, don't erase.
-  if (!Inst->use_empty())
-    return;
-  if (II != BB.rend() && Inst == &*II)
-    ++II;
-  Inst->eraseFromParent();
-}
-
 /// Return true if V is a splat of a value (which is used when multiplying a
 /// matrix with a scalar).
 static bool isSplat(Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
 /// Return the ShapeInfo for the result of \p I, it it can be determined.
 static std::optional<ShapeInfo>
 computeShapeInfoForInst(Instruction *I,
-                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+                        const DenseMap<Value *, ShapeInfo> &ShapeMap) {
   Value *M;
   Value *N;
   Value *K;
@@ -492,10 +479,16 @@ class LowerMatrixIntrinsics {
   /// the result value of the instruction, with the only exceptions being store
   /// instructions and the matrix_column_major_store intrinsics. For those, the
   /// shape information indicates that those instructions should be lowered
-  /// using shape information as well. A ValueMap is used so that when
-  /// sub-passes like optimizeTransposes performs RAUW the map stays
-  /// up-to-date.
-  ValueMap<Value *, ShapeInfo> ShapeMap;
+  /// using shape information as well. Note that extra care is needed when
+  /// erasing or RAUW'ing a value that is present in ShapeMap. If the
+  /// replacement is also a matrix operation, use
+  /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
+  /// ShapeMap. We don't use ValueMap, as there are also cases where we do not
+  /// want to add shape information for a replacement instruction. When directly
+  /// erasing a value with an entry in ShapeMap, use
+  /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
+  /// accordingly.
+  DenseMap<Value *, ShapeInfo> ShapeMap;
 
   /// List of instructions to remove. While lowering, we are not replacing all
   /// users of a lowered instruction, if shape information is available and
@@ -759,6 +752,30 @@ class LowerMatrixIntrinsics {
     return Operation(T0, Shape0.t(), T1, Shape1.t());
   }
 
+  /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
+  /// itself.
+  void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
+    auto Iter = ShapeMap.find(Inst);
+    if (Iter != ShapeMap.end())
+      ShapeMap.erase(Iter);
+    Inst->eraseFromParent();
+  }
+
+  /// Erase \p V from \p BB and move \II forward to avoid invalidating
+  /// iterators.
+  void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
+                              BasicBlock &BB) {
+    auto *Inst = cast<Instruction>(V);
+    // Still used, don't erase.
+    if (!Inst->use_empty())
+      return;
+    if (II != BB.rend() && Inst == &*II)
+      ++II;
+    eraseFromParentAndRemoveFromShapeMap(Inst);
+  }
+
+  /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
+  /// entry for \p Old and replace all uses of \p Old with \p New.
   void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
     // We need to remove Old from the ShapeMap otherwise RAUW will replace it
     // with New. We should only add New it it supportsShapeInfo so we insert
@@ -872,13 +889,13 @@ class LowerMatrixIntrinsics {
 
   void liftTranspose(Instruction &I) {
     // Erase dead Instructions after lifting transposes from binops.
-    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
+    auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
       if (T.use_empty())
-        T.eraseFromParent();
+        eraseFromParentAndRemoveFromShapeMap(&T);
       if (A->use_empty())
-        cast<Instruction>(A)->eraseFromParent();
+        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
       if (A != B && B->use_empty())
-        cast<Instruction>(B)->eraseFromParent();
+        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
     };
 
     Value *A, *B, *AT, *BT;
@@ -1476,7 +1493,7 @@ class LowerMatrixIntrinsics {
                          m_Value(Arg)))) {
       auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
       Op->replaceAllUsesWith(NewLoad);
-      cast<Instruction>(Op)->eraseFromParent();
+      eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
       return;
     } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(Arg)))) {
@@ -1845,15 +1862,15 @@ class LowerMatrixIntrinsics {
       // Mark eliminated instructions as fused and remove them.
       FusedInsts.insert(Store);
       FusedInsts.insert(MatMul);
-      Store->eraseFromParent();
-      MatMul->eraseFromParent();
+      eraseFromParentAndRemoveFromShapeMap(Store);
+      eraseFromParentAndRemoveFromShapeMap(MatMul);
       if (LoadOp0->hasNUses(0)) {
         FusedInsts.insert(LoadOp0);
-        LoadOp0->eraseFromParent();
+        eraseFromParentAndRemoveFromShapeMap(LoadOp0);
       }
       if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
         FusedInsts.insert(LoadOp1);
-        LoadOp1->eraseFromParent();
+        eraseFromParentAndRemoveFromShapeMap(LoadOp1);
       }
     }

llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll

Lines changed: 30 additions & 0 deletions
@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
 declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1
 
 declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0
+
+define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
+; CHECK-NEXT:    ret <1 x i32> [[TMP11]]
+;
+entry:
+  %t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
+  %shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+  %t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
+  %m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
+  ret <1 x i32> %m
+}
+
+declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)

llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll

Lines changed: 20 additions & 0 deletions
@@ -144,8 +144,28 @@ entry:
   ret <6 x double> %mul
 }
 
+define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) {
+; CHECK-LABEL: define void @test_remove_entries_from_shape_map(
+; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2)
+; CHECK-NEXT:    [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]]
+; CHECK-NEXT:    [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2)
+; CHECK-NEXT:    store <6 x float> [[MFADD_T]], ptr [[DST]], align 4
+; CHECK-NEXT:    ret void
+;
+entry:
+  %m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2)
+  %add = fadd <6 x float> %c, %m
+  %t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2)
+  store <6 x float> %t, ptr %dst, align 4
+  ret void
+}
+
 declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
 declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
 declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
 declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
 declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32)
+declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg)
+declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)
