diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 0b9097e9bbca2..04042903e343e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> { | f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 | | | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 | | | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 | - | bf16 | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 | - | | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 | + | bf16 | .m16n8k8 | row | col | 2x i32 | 1x i32 | 4x f32 | + | | .m16n8k16 | row | col | 4x i32 | 2x i32 | 4x f32 | | tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 | | | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 | | u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 | diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 838159d676545..d8fde3e765ac4 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() { expectedResult.push_back(LLVM::LLVMStructType::getLiteral( context, {f32Ty, f32Ty, f32Ty, f32Ty})); break; - case MMATypes::f16: case MMATypes::bf16: + kFactor = 8; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + break; + case MMATypes::f16: kFactor = 8; multiplicandFragType = f16x2Ty; expectedResult.push_back(f16x2x2StructTy); diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index a7bdceba01c1e..4c3b6648a41c0 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> } +// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16 +func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + +// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16 +func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { + // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32, %c0 : i32, %c1 : i32) { diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 2d7710e7cbf27..09e98765413f0 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } +// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16 +llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> { + // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16 + %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> +} + // f32 return type, f16 accumulate type // CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,