[MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 #122664
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir

Author: None (xiaoleis-nv)

Changes

The PR fixes the datatype error for `nvvm.mma.sync` when the operand is `bf16`. The operation originally required the A/B operands to be `f16x2` for the `bf16` MMA. However, this violates the NVVM intrinsic, where the A/B operand type should be `i32`. This is a bug, and there were no MLIR tests covering this datatype. This PR addresses the bug and adds tests to guarantee correctness.

Full diff: https://github.com/llvm/llvm-project/pull/122664.diff (4 files affected; the changed files follow the sketch below):
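As a quick illustration, here is a minimal sketch of the form the op accepts after this change. It mirrors the m16n8k8 test added in this PR; the SSA value names are placeholders.

```mlir
// bf16 m16n8k8: A is 2x i32, B is 1x i32, and the accumulator/result is 4x f32.
// Each i32 operand carries a pair of packed bf16 values.
%d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
       {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
        multiplicandAPtxType = #nvvm.mma_type<bf16>,
        multiplicandBPtxType = #nvvm.mma_type<bf16>,
        shape = #nvvm.shape<m = 16, n = 8, k = 8>}
       : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
```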
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b9097e9bbca2c..04042903e343ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
| f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
| | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
- | bf16 | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
- | | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
+ | bf16 | .m16n8k8 | row | col | 2x i32 | 1x i32 | 4x f32 |
+ | | .m16n8k16 | row | col | 4x i32 | 2x i32 | 4x f32 |
| tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 |
| | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 |
| u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 |
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 838159d676545d..d8fde3e765ac49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
break;
- case MMATypes::f16:
case MMATypes::bf16:
+ kFactor = 8;
+ multiplicandFragType = i32Ty;
+ expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+ break;
+ case MMATypes::f16:
kFactor = 8;
multiplicandFragType = f16x2Ty;
expectedResult.push_back(f16x2x2StructTy);
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..4c3b6648a41c00 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
}
+// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
+func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
%c0 : i32, %c1 : i32) {
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 2d7710e7cbf279..09e98765413f0c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
// f32 return type, f16 accumulate type
// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
@christopherbate @joker-eph, could you please review this PR as well? Thank you.
Looks good to me. I will wait for others to review. If there is no objection, feel free to land it.
LGTM
@schwarzschild-radius, please take a look when you get a chance.
Submitting this as per Xiaolei's request.
[MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (llvm#122664)

The PR fixes the datatype error for `nvvm.mma.sync` when the operand is `bf16`. This operation originally requires the A/B type to be `f16x2` for the `bf16` MMA. However, it violates the NVVM intrinsic [here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119), where the A/B operand type should be `i32`. This is a bug, and there are no tests in MLIR that cover this datatype.

```
// mma bf16 -> s32 @ m16n8k16/m16n8k8
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```

This PR addresses this bug and adds tests to guarantee correctness.

Co-authored-by: Xiaolei Shi <[email protected]>
The PR fixes the datatype error for `nvvm.mma.sync` when the operand is `bf16`. This operation originally requires the A/B type to be `f16x2` for the `bf16` MMA. However, it violates the NVVM intrinsic [here], where the A/B operand type should be `i32`. This is a bug, and there are no tests in MLIR that cover this datatype. This PR addresses this bug and adds tests to guarantee correctness.
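For context, each `i32` operand in the bf16 variants carries a pair of packed bf16 values, matching the PTX convention of passing bf16 multiplicand fragments in 32-bit registers. Below is a small, purely illustrative sketch, not part of this PR: the function name and the use of `llvm.bitcast` are assumptions about how a caller might produce such a fragment in the LLVM dialect.

```mlir
// Hypothetical helper: reinterpret two bf16 values held in a vector<2xbf16>
// as the single i32 fragment that the bf16 nvvm.mma.sync variants expect.
llvm.func @pack_bf16x2(%pair : vector<2xbf16>) -> i32 {
  %frag = llvm.bitcast %pair : vector<2xbf16> to i32
  llvm.return %frag : i32
}
```

Fragments produced this way can then be passed directly in the A/B operand lists shown in the tests above.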