Skip to content

Commit d03f35f

Browse files
[MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (#122664)
This PR fixes the operand datatype check for `nvvm.mma.sync` when the multiplicand type is `bf16`. The operation originally required the A/B operand type to be `f16x2` for the `bf16` MMA. However, this contradicts the NVVM intrinsic definition ([here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119)), where the A/B operand type is `i32`. This is a bug, and there were no tests in MLIR covering this datatype.

```
// mma bf16 -> s32 @ m16n8k16/m16n8k8
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```

This PR addresses the bug and adds tests to guarantee correctness.

Co-authored-by: Xiaolei Shi <[email protected]>
1 parent 1b199d1 commit d03f35f

File tree

4 files changed

+43
-3
lines changed

4 files changed

+43
-3
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
16991699
| f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
17001700
| | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
17011701
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
1702-
| bf16 | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
1703-
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
1702+
| bf16 | .m16n8k8 | row | col | 2x i32 | 1x i32 | 4x f32 |
1703+
| | .m16n8k16 | row | col | 4x i32 | 2x i32 | 4x f32 |
17041704
| tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 |
17051705
| | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 |
17061706
| u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 |

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
445445
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
446446
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
447447
break;
448-
case MMATypes::f16:
449448
case MMATypes::bf16:
449+
kFactor = 8;
450+
multiplicandFragType = i32Ty;
451+
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
452+
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
453+
break;
454+
case MMATypes::f16:
450455
kFactor = 8;
451456
multiplicandFragType = f16x2Ty;
452457
expectedResult.push_back(f16x2x2StructTy);

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
163163
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
164164
}
165165

166+
// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
167+
func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
168+
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
169+
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
170+
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
171+
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
172+
multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
173+
shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
174+
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
175+
}
176+
177+
// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
178+
func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
179+
%b0 : i32, %b1 : i32,
180+
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
181+
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
182+
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
183+
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
184+
multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
185+
shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
186+
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
187+
}
188+
166189
// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
167190
func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
168191
%c0 : i32, %c1 : i32) {

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
291291
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
292292
}
293293

294+
// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
295+
llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
296+
%b0 : i32, %b1 : i32,
297+
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
298+
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
299+
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
300+
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
301+
multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
302+
shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
303+
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
304+
}
305+
294306
// f32 return type, f16 accumulate type
295307
// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
296308
llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,

0 commit comments

Comments
 (0)