Skip to content

Commit 927559d

Browse files
authored
[mlir][vector] Fix a crash in VectorToGPU (#113454)
This PR fixes a crash in `VectorToGPU` when the operand of `extOp` is a function argument, which cannot be retrieved using `getDefiningOp`. Fixes #107967.
1 parent 47c1abf commit 927559d

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
200200
/// Return true if this integer extend op can be folded into a contract op.
201201
template <typename ExtOpTy>
202202
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203-
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
203+
auto transferReadOp =
204+
extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205+
if (!transferReadOp)
204206
return false;
205207
return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
206208
}

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,22 @@ func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf1
517517
vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
518518
return
519519
}
520+
521+
// -----
522+
523+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
524+
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
525+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
526+
527+
// Ensure that no crash occurs when the predecessor operation
528+
// of `ext` is not `transfer_read`.
529+
530+
// CHECK-LABEL: func @test_unsupported
531+
// CHECK: vector.contract
532+
func.func @test_unsupported(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi64>) -> vector<4x4xi64 > {
533+
%0 = arith.extui %arg0 : vector<4x4xi32> to vector<4x4xi64>
534+
%1 = arith.extui %arg1 : vector<4x4xi32> to vector<4x4xi64>
535+
%2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
536+
%0, %1, %arg2 : vector<4x4xi64>, vector<4x4xi64> into vector<4x4xi64>
537+
return %2 : vector<4x4xi64>
538+
}

0 commit comments

Comments
 (0)