diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index e059d31ca5842..034f3e2d16e94 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -200,7 +200,9 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
 /// Return true if this integer extend op can be folded into a contract op.
 template <typename ExtOpTy>
 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
-  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
+  auto transferReadOp =
+      extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
+  if (!transferReadOp)
     return false;
   return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 8526ff1392599..b8ac63f89af33 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -517,3 +517,22 @@ func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf1
   vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
   return
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// Ensure that no crash occurs when the operation feeding the `ext` op
+// is not a `transfer_read`.
+
+// CHECK-LABEL: func @test_unsupported
+// CHECK: vector.contract
+func.func @test_unsupported(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi64>) -> vector<4x4xi64> {
+  %0 = arith.extui %arg0 : vector<4x4xi32> to vector<4x4xi64>
+  %1 = arith.extui %arg1 : vector<4x4xi32> to vector<4x4xi64>
+  %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %0, %1, %arg2 : vector<4x4xi64>, vector<4x4xi64> into vector<4x4xi64>
+  return %2 : vector<4x4xi64>
+}
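
For context, the crash this patch avoids comes from calling `isa<>` on the result of `Value::getDefiningOp()`, which is null when the value has no defining operation (for example a block argument, as with the function arguments fed to `arith.extui` in the new test). Below is a minimal, hypothetical C++ sketch of the pattern the fix adopts; the helper name `isProducedByTransferRead` is invented for illustration and is not part of the patch:

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Value.h"

// Hypothetical helper, not part of the patch. Value::getDefiningOp() returns
// nullptr for block arguments, so isa<vector::TransferReadOp>(v.getDefiningOp())
// would dereference null in that case. The templated overload
// getDefiningOp<OpTy>() folds the null check and the cast into one call.
static bool isProducedByTransferRead(mlir::Value value) {
  // Null if `value` is a block argument or its defining op is not a
  // vector.transfer_read.
  auto readOp = value.getDefiningOp<mlir::vector::TransferReadOp>();
  return static_cast<bool>(readOp);
}
```

In the patch itself the same idea is applied directly on `extOp.getOperand()`, with the `template` keyword required because the operand's type is dependent on the `ExtOpTy` template parameter.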