@@ -199,6 +199,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
199
199
return Builder.CreatePointerCast (VecStart, VecPtrType, " vec.cast" );
200
200
}
201
201
202
+ namespace {
203
+ struct ShapeInfo {
204
+ unsigned NumRows;
205
+ unsigned NumColumns;
206
+
207
+ bool IsColumnMajor;
208
+
209
+ ShapeInfo (unsigned NumRows = 0 , unsigned NumColumns = 0 )
210
+ : NumRows(NumRows), NumColumns(NumColumns),
211
+ IsColumnMajor (MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
212
+
213
+ ShapeInfo (Value *NumRows, Value *NumColumns)
214
+ : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
215
+ cast<ConstantInt>(NumColumns)->getZExtValue()) {}
216
+
217
+ bool operator ==(const ShapeInfo &other) {
218
+ return NumRows == other.NumRows && NumColumns == other.NumColumns ;
219
+ }
220
+ bool operator !=(const ShapeInfo &other) { return !(*this == other); }
221
+
222
+ // / Returns true if shape-information is defined, meaning both dimensions
223
+ // / are != 0.
224
+ operator bool () const {
225
+ assert (NumRows == 0 || NumColumns != 0 );
226
+ return NumRows != 0 ;
227
+ }
228
+
229
+ unsigned getStride () const {
230
+ if (IsColumnMajor)
231
+ return NumRows;
232
+ return NumColumns;
233
+ }
234
+
235
+ unsigned getNumVectors () const {
236
+ if (IsColumnMajor)
237
+ return NumColumns;
238
+ return NumRows;
239
+ }
240
+
241
+ // / Returns the transposed shape.
242
+ ShapeInfo t () const { return ShapeInfo (NumColumns, NumRows); }
243
+ };
244
+ } // namespace
245
+
246
+ static bool isUniformShape (Value *V) {
247
+ Instruction *I = dyn_cast<Instruction>(V);
248
+ if (!I)
249
+ return true ;
250
+
251
+ switch (I->getOpcode ()) {
252
+ case Instruction::FAdd:
253
+ case Instruction::FSub:
254
+ case Instruction::FMul: // Scalar multiply.
255
+ case Instruction::FNeg:
256
+ case Instruction::Add:
257
+ case Instruction::Mul:
258
+ case Instruction::Sub:
259
+ return true ;
260
+ default :
261
+ return false ;
262
+ }
263
+ }
264
+
265
+ // / Return the ShapeInfo for the result of \p I, it it can be determined.
266
+ static std::optional<ShapeInfo>
267
+ computeShapeInfoForInst (Instruction *I,
268
+ const ValueMap<Value *, ShapeInfo> &ShapeMap) {
269
+ Value *M;
270
+ Value *N;
271
+ Value *K;
272
+ if (match (I, m_Intrinsic<Intrinsic::matrix_multiply>(
273
+ m_Value (), m_Value (), m_Value (M), m_Value (N), m_Value (K))))
274
+ return ShapeInfo (M, K);
275
+ if (match (I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value (), m_Value (M),
276
+ m_Value (N)))) {
277
+ // Flip dimensions.
278
+ return ShapeInfo (N, M);
279
+ }
280
+ if (match (I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
281
+ m_Value (), m_Value (), m_Value (), m_Value (), m_Value (M),
282
+ m_Value (N))))
283
+ return ShapeInfo (N, M);
284
+ if (match (I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
285
+ m_Value (), m_Value (), m_Value (), m_Value (M), m_Value (N))))
286
+ return ShapeInfo (M, N);
287
+ Value *MatrixA;
288
+ if (match (I, m_Store (m_Value (MatrixA), m_Value ()))) {
289
+ auto OpShape = ShapeMap.find (MatrixA);
290
+ if (OpShape != ShapeMap.end ())
291
+ return OpShape->second ;
292
+ }
293
+
294
+ if (isUniformShape (I)) {
295
+ // Find the first operand that has a known shape and use that.
296
+ for (auto &Op : I->operands ()) {
297
+ auto OpShape = ShapeMap.find (Op.get ());
298
+ if (OpShape != ShapeMap.end ())
299
+ return OpShape->second ;
300
+ }
301
+ }
302
+ return std::nullopt;
303
+ }
304
+
202
305
// / LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
203
306
// /
204
307
// / Currently, the lowering for each matrix intrinsic is done as follows:
@@ -390,48 +493,6 @@ class LowerMatrixIntrinsics {
390
493
}
391
494
};
392
495
393
- struct ShapeInfo {
394
- unsigned NumRows;
395
- unsigned NumColumns;
396
-
397
- bool IsColumnMajor;
398
-
399
- ShapeInfo (unsigned NumRows = 0 , unsigned NumColumns = 0 )
400
- : NumRows(NumRows), NumColumns(NumColumns),
401
- IsColumnMajor (MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
402
-
403
- ShapeInfo (Value *NumRows, Value *NumColumns)
404
- : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
405
- cast<ConstantInt>(NumColumns)->getZExtValue()) {}
406
-
407
- bool operator ==(const ShapeInfo &other) {
408
- return NumRows == other.NumRows && NumColumns == other.NumColumns ;
409
- }
410
- bool operator !=(const ShapeInfo &other) { return !(*this == other); }
411
-
412
- // / Returns true if shape-information is defined, meaning both dimensions
413
- // / are != 0.
414
- operator bool () const {
415
- assert (NumRows == 0 || NumColumns != 0 );
416
- return NumRows != 0 ;
417
- }
418
-
419
- unsigned getStride () const {
420
- if (IsColumnMajor)
421
- return NumRows;
422
- return NumColumns;
423
- }
424
-
425
- unsigned getNumVectors () const {
426
- if (IsColumnMajor)
427
- return NumColumns;
428
- return NumRows;
429
- }
430
-
431
- // / Returns the transposed shape.
432
- ShapeInfo t () const { return ShapeInfo (NumColumns, NumRows); }
433
- };
434
-
435
496
// / Maps instructions to their shape information. The shape information
436
497
// / describes the shape to be used while lowering. This matches the shape of
437
498
// / the result value of the instruction, with the only exceptions being store
@@ -561,25 +622,6 @@ class LowerMatrixIntrinsics {
561
622
return true ;
562
623
}
563
624
564
- bool isUniformShape (Value *V) {
565
- Instruction *I = dyn_cast<Instruction>(V);
566
- if (!I)
567
- return true ;
568
-
569
- switch (I->getOpcode ()) {
570
- case Instruction::FAdd:
571
- case Instruction::FSub:
572
- case Instruction::FMul: // Scalar multiply.
573
- case Instruction::FNeg:
574
- case Instruction::Add:
575
- case Instruction::Mul:
576
- case Instruction::Sub:
577
- return true ;
578
- default :
579
- return false ;
580
- }
581
- }
582
-
583
625
// / Returns true if shape information can be used for \p V. The supported
584
626
// / instructions must match the instructions that can be lowered by this pass.
585
627
bool supportsShapeInfo (Value *V) {
@@ -617,43 +659,8 @@ class LowerMatrixIntrinsics {
617
659
618
660
// New entry, set the value and insert operands
619
661
bool Propagate = false ;
620
-
621
- Value *MatrixA;
622
- Value *MatrixB;
623
- Value *M;
624
- Value *N;
625
- Value *K;
626
- if (match (Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
627
- m_Value (MatrixA), m_Value (MatrixB), m_Value (M),
628
- m_Value (N), m_Value (K)))) {
629
- Propagate = setShapeInfo (Inst, {M, K});
630
- } else if (match (Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
631
- m_Value (MatrixA), m_Value (M), m_Value (N)))) {
632
- // Flip dimensions.
633
- Propagate = setShapeInfo (Inst, {N, M});
634
- } else if (match (Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
635
- m_Value (MatrixA), m_Value (), m_Value (),
636
- m_Value (), m_Value (M), m_Value (N)))) {
637
- Propagate = setShapeInfo (Inst, {N, M});
638
- } else if (match (Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
639
- m_Value (), m_Value (), m_Value (), m_Value (M),
640
- m_Value (N)))) {
641
- Propagate = setShapeInfo (Inst, {M, N});
642
- } else if (match (Inst, m_Store (m_Value (MatrixA), m_Value ()))) {
643
- auto OpShape = ShapeMap.find (MatrixA);
644
- if (OpShape != ShapeMap.end ())
645
- setShapeInfo (Inst, OpShape->second );
646
- continue ;
647
- } else if (isUniformShape (Inst)) {
648
- // Find the first operand that has a known shape and use that.
649
- for (auto &Op : Inst->operands ()) {
650
- auto OpShape = ShapeMap.find (Op.get ());
651
- if (OpShape != ShapeMap.end ()) {
652
- Propagate |= setShapeInfo (Inst, OpShape->second );
653
- break ;
654
- }
655
- }
656
- }
662
+ if (auto SI = computeShapeInfoForInst (Inst, ShapeMap))
663
+ Propagate = setShapeInfo (Inst, *SI);
657
664
658
665
if (Propagate) {
659
666
NewWorkList.push_back (Inst);
@@ -898,20 +905,28 @@ class LowerMatrixIntrinsics {
898
905
updateShapeAndReplaceAllUsesWith (I, NewInst);
899
906
CleanupBinOp (I, A, B);
900
907
}
901
- // A^t + B ^t -> (A + B)^t
908
+ // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
909
+ // the shape of the second transpose is different, there's a shape conflict
910
+ // which gets resolved by picking the shape of the first operand.
902
911
else if (match (&I, m_FAdd (m_Value (A), m_Value (B))) &&
903
912
match (A, m_Intrinsic<Intrinsic::matrix_transpose>(
904
913
m_Value (AT), m_ConstantInt (R), m_ConstantInt (C))) &&
905
914
match (B, m_Intrinsic<Intrinsic::matrix_transpose>(
906
- m_Value (BT), m_ConstantInt (R ), m_ConstantInt (C )))) {
915
+ m_Value (BT), m_ConstantInt (), m_ConstantInt ()))) {
907
916
IRBuilder<> Builder (&I);
908
- Value *Add = cast<Instruction>(Builder.CreateFAdd (AT, BT, " mfadd" ));
909
- setShapeInfo (Add, {C, R });
917
+ auto *Add = cast<Instruction>(Builder.CreateFAdd (AT, BT, " mfadd" ));
918
+ setShapeInfo (Add, {R, C });
910
919
MatrixBuilder MBuilder (Builder);
911
920
Instruction *NewInst = MBuilder.CreateMatrixTranspose (
912
- Add, C ->getZExtValue (), R ->getZExtValue (), " mfadd_t" );
921
+ Add, R ->getZExtValue (), C ->getZExtValue (), " mfadd_t" );
913
922
updateShapeAndReplaceAllUsesWith (I, NewInst);
923
+ assert (computeShapeInfoForInst (NewInst, ShapeMap) ==
924
+ computeShapeInfoForInst (&I, ShapeMap) &&
925
+ " Shape of new instruction doesn't match original shape." );
914
926
CleanupBinOp (I, A, B);
927
+ assert (computeShapeInfoForInst (Add, ShapeMap).value_or (ShapeMap[Add]) ==
928
+ ShapeMap[Add] &&
929
+ " Shape of updated addition doesn't match cached shape." );
915
930
}
916
931
}
917
932
0 commit comments