diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 6a9ec48864b2c..29844c4630751 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) { return cast<DILocalScope>(Scope)->getSubprogram(); } -/// Erase \p V from \p BB and move \II forward to avoid invalidating -/// iterators. -static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, - BasicBlock &BB) { - auto *Inst = cast<Instruction>(V); - // Still used, don't erase. - if (!Inst->use_empty()) - return; - if (II != BB.rend() && Inst == &*II) - ++II; - Inst->eraseFromParent(); -} - /// Return true if V is a splat of a value (which is used when multiplying a /// matrix with a scalar). static bool isSplat(Value *V) { @@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) { /// Return the ShapeInfo for the result of \p I, it it can be determined. static std::optional<ShapeInfo> computeShapeInfoForInst(Instruction *I, - const ValueMap<Value *, ShapeInfo> &ShapeMap) { + const DenseMap<Value *, ShapeInfo> &ShapeMap) { Value *M; Value *N; Value *K; @@ -493,10 +480,16 @@ class LowerMatrixIntrinsics { /// the result value of the instruction, with the only exceptions being store /// instructions and the matrix_column_major_store intrinsics. For those, the /// shape information indicates that those instructions should be lowered - /// using shape information as well. A ValueMap is used so that when - /// sub-passes like optimizeTransposes performs RAUW the map stays - /// up-to-date. - ValueMap<Value *, ShapeInfo> ShapeMap; + /// using shape information as well. Note that extra care is needed when + /// erasing or RAUW'ing a value that is present in ShapeMap. If the + /// replacement is also a matrix operation, use + /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to + /// ShapeMap. 
We don't use ValueMap, as there are also cases where we do not + /// want to add shape information for a replacement instruction. When directly + /// erasing a value with an entry in ShapeMap, use + /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated + /// accordingly. + DenseMap<Value *, ShapeInfo> ShapeMap; /// List of instructions to remove. While lowering, we are not replacing all /// users of a lowered instruction, if shape information is available and @@ -758,6 +751,30 @@ class LowerMatrixIntrinsics { return Operation(T0, Shape0.t(), T1, Shape1.t()); } + /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst + /// itself. + void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) { + auto Iter = ShapeMap.find(Inst); + if (Iter != ShapeMap.end()) + ShapeMap.erase(Iter); + Inst->eraseFromParent(); + } + + /// Erase \p V from \p BB and move \II forward to avoid invalidating + /// iterators. + void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, + BasicBlock &BB) { + auto *Inst = cast<Instruction>(V); + // Still used, don't erase. + if (!Inst->use_empty()) + return; + if (II != BB.rend() && Inst == &*II) + ++II; + eraseFromParentAndRemoveFromShapeMap(Inst); + } + + /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the + /// entry for \p Old and replace all uses of \p Old with \p New. void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { // We need to remove Old from the ShapeMap otherwise RAUW will replace it // with New. We should only add New it it supportsShapeInfo so we insert @@ -871,13 +888,13 @@ class LowerMatrixIntrinsics { void liftTranspose(Instruction &I) { // Erase dead Instructions after lifting transposes from binops. 
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { + auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) { if (T.use_empty()) - T.eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(&T); if (A->use_empty()) - cast<Instruction>(A)->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A)); if (A != B && B->use_empty()) - cast<Instruction>(B)->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B)); }; Value *A, *B, *AT, *BT; @@ -1484,7 +1501,7 @@ class LowerMatrixIntrinsics { m_Value(Arg)))) { auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); Op->replaceAllUsesWith(NewLoad); - cast<Instruction>(Op)->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op)); return; } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( m_Value(Arg)))) { @@ -1853,15 +1870,15 @@ // Mark eliminated instructions as fused and remove them. FusedInsts.insert(Store); FusedInsts.insert(MatMul); - Store->eraseFromParent(); - MatMul->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(Store); + eraseFromParentAndRemoveFromShapeMap(MatMul); if (LoadOp0->hasNUses(0)) { FusedInsts.insert(LoadOp0); - LoadOp0->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(LoadOp0); } if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { FusedInsts.insert(LoadOp1); - LoadOp1->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(LoadOp1); } } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll index 2fd77e245a34e..aadaf1ffffb23 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll @@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>, declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1 declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, 
i32 immarg, i32 immarg) #0 + +define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) { +; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1 +; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]] +; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]]) +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0 +; CHECK-NEXT: ret <1 x i32> [[TMP11]] +; +entry: + %t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2) + %shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1> + %t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1) + %m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1) + ret <1 x i32> %m +} + +declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg) diff --git 
a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll index fcf83b03bc3d2..1b3b41d8cfe1f 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll @@ -144,8 +144,28 @@ entry: ret <6 x double> %mul } +define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) { +; CHECK-LABEL: define void @test_remove_entries_from_shape_map( +; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2) +; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]] +; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2) +; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4 +; CHECK-NEXT: ret void +; +entry: + %m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2) + %add = fadd <6 x float> %c, %m + %t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2) + store <6 x float> %t, ptr %dst, align 4 + ret void +} + declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32) declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32) declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32) declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32) declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32) +declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg) +declare <6 x float> 
@llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)