Skip to content

Commit e238ee5

Browse files
authored
Merge pull request #8175 from fhahn/matrix-hoist-transpose-fix
Pick fix for matrix transpose hoisting.
2 parents a241f56 + 4a6c695 commit e238ee5

File tree

3 files changed

+329
-162
lines changed

3 files changed

+329
-162
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 118 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
199199
return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
200200
}
201201

202+
namespace {
203+
struct ShapeInfo {
204+
unsigned NumRows;
205+
unsigned NumColumns;
206+
207+
bool IsColumnMajor;
208+
209+
ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
210+
: NumRows(NumRows), NumColumns(NumColumns),
211+
IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
212+
213+
ShapeInfo(Value *NumRows, Value *NumColumns)
214+
: ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
215+
cast<ConstantInt>(NumColumns)->getZExtValue()) {}
216+
217+
bool operator==(const ShapeInfo &other) {
218+
return NumRows == other.NumRows && NumColumns == other.NumColumns;
219+
}
220+
bool operator!=(const ShapeInfo &other) { return !(*this == other); }
221+
222+
/// Returns true if shape-information is defined, meaning both dimensions
223+
/// are != 0.
224+
operator bool() const {
225+
assert(NumRows == 0 || NumColumns != 0);
226+
return NumRows != 0;
227+
}
228+
229+
unsigned getStride() const {
230+
if (IsColumnMajor)
231+
return NumRows;
232+
return NumColumns;
233+
}
234+
235+
unsigned getNumVectors() const {
236+
if (IsColumnMajor)
237+
return NumColumns;
238+
return NumRows;
239+
}
240+
241+
/// Returns the transposed shape.
242+
ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
243+
};
244+
} // namespace
245+
246+
static bool isUniformShape(Value *V) {
247+
Instruction *I = dyn_cast<Instruction>(V);
248+
if (!I)
249+
return true;
250+
251+
switch (I->getOpcode()) {
252+
case Instruction::FAdd:
253+
case Instruction::FSub:
254+
case Instruction::FMul: // Scalar multiply.
255+
case Instruction::FNeg:
256+
case Instruction::Add:
257+
case Instruction::Mul:
258+
case Instruction::Sub:
259+
return true;
260+
default:
261+
return false;
262+
}
263+
}
264+
265+
/// Return the ShapeInfo for the result of \p I, it it can be determined.
266+
static std::optional<ShapeInfo>
267+
computeShapeInfoForInst(Instruction *I,
268+
const ValueMap<Value *, ShapeInfo> &ShapeMap) {
269+
Value *M;
270+
Value *N;
271+
Value *K;
272+
if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
273+
m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
274+
return ShapeInfo(M, K);
275+
if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
276+
m_Value(N)))) {
277+
// Flip dimensions.
278+
return ShapeInfo(N, M);
279+
}
280+
if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
281+
m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
282+
m_Value(N))))
283+
return ShapeInfo(N, M);
284+
if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
285+
m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
286+
return ShapeInfo(M, N);
287+
Value *MatrixA;
288+
if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
289+
auto OpShape = ShapeMap.find(MatrixA);
290+
if (OpShape != ShapeMap.end())
291+
return OpShape->second;
292+
}
293+
294+
if (isUniformShape(I)) {
295+
// Find the first operand that has a known shape and use that.
296+
for (auto &Op : I->operands()) {
297+
auto OpShape = ShapeMap.find(Op.get());
298+
if (OpShape != ShapeMap.end())
299+
return OpShape->second;
300+
}
301+
}
302+
return std::nullopt;
303+
}
304+
202305
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
203306
///
204307
/// Currently, the lowering for each matrix intrinsic is done as follows:
@@ -390,48 +493,6 @@ class LowerMatrixIntrinsics {
390493
}
391494
};
392495

393-
struct ShapeInfo {
394-
unsigned NumRows;
395-
unsigned NumColumns;
396-
397-
bool IsColumnMajor;
398-
399-
ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
400-
: NumRows(NumRows), NumColumns(NumColumns),
401-
IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
402-
403-
ShapeInfo(Value *NumRows, Value *NumColumns)
404-
: ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
405-
cast<ConstantInt>(NumColumns)->getZExtValue()) {}
406-
407-
bool operator==(const ShapeInfo &other) {
408-
return NumRows == other.NumRows && NumColumns == other.NumColumns;
409-
}
410-
bool operator!=(const ShapeInfo &other) { return !(*this == other); }
411-
412-
/// Returns true if shape-information is defined, meaning both dimensions
413-
/// are != 0.
414-
operator bool() const {
415-
assert(NumRows == 0 || NumColumns != 0);
416-
return NumRows != 0;
417-
}
418-
419-
unsigned getStride() const {
420-
if (IsColumnMajor)
421-
return NumRows;
422-
return NumColumns;
423-
}
424-
425-
unsigned getNumVectors() const {
426-
if (IsColumnMajor)
427-
return NumColumns;
428-
return NumRows;
429-
}
430-
431-
/// Returns the transposed shape.
432-
ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
433-
};
434-
435496
/// Maps instructions to their shape information. The shape information
436497
/// describes the shape to be used while lowering. This matches the shape of
437498
/// the result value of the instruction, with the only exceptions being store
@@ -561,25 +622,6 @@ class LowerMatrixIntrinsics {
561622
return true;
562623
}
563624

564-
bool isUniformShape(Value *V) {
565-
Instruction *I = dyn_cast<Instruction>(V);
566-
if (!I)
567-
return true;
568-
569-
switch (I->getOpcode()) {
570-
case Instruction::FAdd:
571-
case Instruction::FSub:
572-
case Instruction::FMul: // Scalar multiply.
573-
case Instruction::FNeg:
574-
case Instruction::Add:
575-
case Instruction::Mul:
576-
case Instruction::Sub:
577-
return true;
578-
default:
579-
return false;
580-
}
581-
}
582-
583625
/// Returns true if shape information can be used for \p V. The supported
584626
/// instructions must match the instructions that can be lowered by this pass.
585627
bool supportsShapeInfo(Value *V) {
@@ -617,43 +659,8 @@ class LowerMatrixIntrinsics {
617659

618660
// New entry, set the value and insert operands
619661
bool Propagate = false;
620-
621-
Value *MatrixA;
622-
Value *MatrixB;
623-
Value *M;
624-
Value *N;
625-
Value *K;
626-
if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
627-
m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
628-
m_Value(N), m_Value(K)))) {
629-
Propagate = setShapeInfo(Inst, {M, K});
630-
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
631-
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
632-
// Flip dimensions.
633-
Propagate = setShapeInfo(Inst, {N, M});
634-
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
635-
m_Value(MatrixA), m_Value(), m_Value(),
636-
m_Value(), m_Value(M), m_Value(N)))) {
637-
Propagate = setShapeInfo(Inst, {N, M});
638-
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
639-
m_Value(), m_Value(), m_Value(), m_Value(M),
640-
m_Value(N)))) {
641-
Propagate = setShapeInfo(Inst, {M, N});
642-
} else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
643-
auto OpShape = ShapeMap.find(MatrixA);
644-
if (OpShape != ShapeMap.end())
645-
setShapeInfo(Inst, OpShape->second);
646-
continue;
647-
} else if (isUniformShape(Inst)) {
648-
// Find the first operand that has a known shape and use that.
649-
for (auto &Op : Inst->operands()) {
650-
auto OpShape = ShapeMap.find(Op.get());
651-
if (OpShape != ShapeMap.end()) {
652-
Propagate |= setShapeInfo(Inst, OpShape->second);
653-
break;
654-
}
655-
}
656-
}
662+
if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
663+
Propagate = setShapeInfo(Inst, *SI);
657664

658665
if (Propagate) {
659666
NewWorkList.push_back(Inst);
@@ -898,20 +905,28 @@ class LowerMatrixIntrinsics {
898905
updateShapeAndReplaceAllUsesWith(I, NewInst);
899906
CleanupBinOp(I, A, B);
900907
}
901-
// A^t + B ^t -> (A + B)^t
908+
// A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
909+
// the shape of the second transpose is different, there's a shape conflict
910+
// which gets resolved by picking the shape of the first operand.
902911
else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
903912
match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
904913
m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
905914
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
906-
m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
915+
m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
907916
IRBuilder<> Builder(&I);
908-
Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
909-
setShapeInfo(Add, {C, R});
917+
auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
918+
setShapeInfo(Add, {R, C});
910919
MatrixBuilder MBuilder(Builder);
911920
Instruction *NewInst = MBuilder.CreateMatrixTranspose(
912-
Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
921+
Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
913922
updateShapeAndReplaceAllUsesWith(I, NewInst);
923+
assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
924+
computeShapeInfoForInst(&I, ShapeMap) &&
925+
"Shape of new instruction doesn't match original shape.");
914926
CleanupBinOp(I, A, B);
927+
assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
928+
ShapeMap[Add] &&
929+
"Shape of updated addition doesn't match cached shape.");
915930
}
916931
}
917932

0 commit comments

Comments
 (0)