
[Matrix] Use DenseMap for ShapeMap instead of ValueMap. #118282


Merged: 4 commits, Dec 4, 2024
71 changes: 44 additions & 27 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
BasicBlock &BB) {
auto *Inst = cast<Instruction>(V);
// Still used, don't erase.
if (!Inst->use_empty())
return;
if (II != BB.rend() && Inst == &*II)
++II;
Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
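
Aside: the helper removed above (it is re-added as a class member further down) guards against a classic iterator-invalidation bug. A minimal sketch of the idiom, not the pass's actual loop; `eraseDeadInReverse` is a made-up name for illustration:

```cpp
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"

// Walk a block bottom-up and erase dead instructions. Advancing the reverse
// iterator past I *before* erasing keeps the loop position valid, which is
// exactly what eraseFromParentAndMove does for its caller's iterator.
static void eraseDeadInReverse(llvm::BasicBlock &BB) {
  for (llvm::BasicBlock::reverse_iterator II = BB.rbegin(); II != BB.rend();) {
    llvm::Instruction &I = *II;
    ++II;                    // step past I first...
    if (I.use_empty() && !I.isTerminator())
      I.eraseFromParent();   // ...so erasing I cannot invalidate II
  }
}
```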
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
/// Return the ShapeInfo for the result of \p I, if it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
const ValueMap<Value *, ShapeInfo> &ShapeMap) {
const DenseMap<Value *, ShapeInfo> &ShapeMap) {
Value *M;
Value *N;
Value *K;
@@ -493,10 +480,16 @@ class LowerMatrixIntrinsics {
/// the result value of the instruction, with the only exceptions being store
/// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
/// using shape information as well. A ValueMap is used so that when
/// sub-passes like optimizeTransposes performs RAUW the map stays
/// up-to-date.
ValueMap<Value *, ShapeInfo> ShapeMap;
/// using shape information as well. Note that extra care is needed when
/// erasing or RAUW'ing a value that is present in ShapeMap. If the
/// replacement is also a matrix operation, use
/// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
/// ShapeMap. We don't use ValueMap, as there are also cases where we do not
/// want to add shape information for a replacement instruction. When directly
/// erasing a value with an entry in ShapeMap, use
/// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
/// accordingly.
DenseMap<Value *, ShapeInfo> ShapeMap;

/// List of instructions to remove. While lowering, we are not replacing all
/// users of a lowered instruction, if shape information is available and
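
Aside on the comment above: ValueMap keys are ValueHandles that follow RAUW, while DenseMap keys are plain pointers that do not. A minimal sketch demonstrating the difference, assuming an LLVM development setup (the module and function names are made up):

```cpp
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx);
  auto *I32 = Type::getInt32Ty(Ctx);
  auto *F = Function::Create(FunctionType::get(I32, {I32}, false),
                             Function::ExternalLinkage, "f", M);
  auto *BB = BasicBlock::Create(Ctx, "entry", F);
  IRBuilder<> B(BB);
  Value *Old = B.CreateAdd(F->getArg(0), F->getArg(0), "old");
  Value *New = B.CreateMul(F->getArg(0), F->getArg(0), "new");
  B.CreateRet(Old);

  ValueMap<Value *, int> VM; // ValueHandle keys: follow RAUW by default
  DenseMap<Value *, int> DM; // raw pointer keys: no tracking
  VM[Old] = 1;
  DM[Old] = 1;

  Old->replaceAllUsesWith(New);

  outs() << VM.count(New) << "\n"; // 1: entry migrated to New
  outs() << DM.count(New) << "\n"; // 0: entry still keyed on Old
}
```

Hence the extra care described in the comment: with DenseMap the pass must maintain ShapeMap by hand, in exchange for control over when a replacement gets shape information at all.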
@@ -758,6 +751,30 @@ class LowerMatrixIntrinsics {
return Operation(T0, Shape0.t(), T1, Shape1.t());
}

/// Remove \p Inst from ShapeMap (if an entry exists) and erase \p Inst
/// itself.
void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
auto Iter = ShapeMap.find(Inst);
if (Iter != ShapeMap.end())
ShapeMap.erase(Iter);
Inst->eraseFromParent();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
BasicBlock &BB) {
auto *Inst = cast<Instruction>(V);
// Still used, don't erase.
if (!Inst->use_empty())
return;
if (II != BB.rend() && Inst == &*II)
++II;
eraseFromParentAndRemoveFromShapeMap(Inst);
}

/// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
/// entry for \p Old and replace all uses of \p Old with \p New.
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
// with New. We should only add New if it supportsShapeInfo so we insert
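
The helper's body is cut off in this view. Based purely on the contract documented above (move the shape entry from Old to New only when the replacement supports shape info, then RAUW), a sketch of how it plausibly reads inside the class; this is not the verbatim committed code, and supportsShapeInfo stands for the pass's existing predicate referenced in the comment:

```cpp
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
  auto It = ShapeMap.find(&Old);
  if (It != ShapeMap.end()) {
    ShapeInfo Shape = It->second; // copy: insert/erase may invalidate It
    ShapeMap.erase(It);
    if (supportsShapeInfo(New))   // only matrix operations get an entry
      ShapeMap.insert({New, Shape});
  }
  Old.replaceAllUsesWith(New);
}
```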
@@ -871,13 +888,13 @@ class LowerMatrixIntrinsics {

void liftTranspose(Instruction &I) {
// Erase dead Instructions after lifting transposes from binops.
auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
if (T.use_empty())
T.eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(&T);
if (A->use_empty())
cast<Instruction>(A)->eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
if (A != B && B->use_empty())
cast<Instruction>(B)->eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
};

Value *A, *B, *AT, *BT;
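
Note the lambda now captures `this` so cleanup goes through eraseFromParentAndRemoveFromShapeMap rather than a bare eraseFromParent. The hazard being avoided is a stale pointer key: the self-contained analogy below uses std::unordered_map and a stand-in Instr type in place of llvm::DenseMap and llvm::Instruction:

```cpp
#include <cstdio>
#include <unordered_map>

struct Instr { int Opcode; }; // stand-in for llvm::Instruction

int main() {
  std::unordered_map<Instr *, int> ShapeMap; // pointer keys, like DenseMap
  Instr *I = new Instr{1};
  ShapeMap[I] = 42; // shape recorded for I
  delete I;         // erased "from parent" but the entry is left behind
  Instr *J = new Instr{2};
  // The allocator may reuse I's address for J; if it does, J appears to
  // have a shape that was never computed for it. Either way the map now
  // holds a dangling key, so the entry must be removed at erase time.
  std::printf("possible stale hit for J: %zu\n", ShapeMap.count(J));
  delete J;
}
```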
@@ -1484,7 +1501,7 @@ class LowerMatrixIntrinsics {
m_Value(Arg)))) {
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
Op->replaceAllUsesWith(NewLoad);
cast<Instruction>(Op)->eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
return;
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(Arg)))) {
@@ -1853,15 +1870,15 @@
// Mark eliminated instructions as fused and remove them.
FusedInsts.insert(Store);
FusedInsts.insert(MatMul);
Store->eraseFromParent();
MatMul->eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(Store);
eraseFromParentAndRemoveFromShapeMap(MatMul);
if (LoadOp0->hasNUses(0)) {
FusedInsts.insert(LoadOp0);
LoadOp0->eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(LoadOp0);
}
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
FusedInsts.insert(LoadOp1);
LoadOp1->eraseFromParent();
eraseFromParentAndRemoveFromShapeMap(LoadOp1);
}
}

@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1

declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0

define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
; CHECK-NEXT: ret <1 x i32> [[TMP11]]
;
entry:
%t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
%shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
%t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
%m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
ret <1 x i32> %m
}

declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)
@@ -144,8 +144,28 @@ entry:
ret <6 x double> %mul
}

define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) {
; CHECK-LABEL: define void @test_remove_entries_from_shape_map(
; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2)
; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]]
; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2)
; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4
; CHECK-NEXT: ret void
;
entry:
%m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2)
%add = fadd <6 x float> %c, %m
%t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2)
store <6 x float> %t, ptr %dst, align 4
ret void
}

declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg)
declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)