Skip to content

[MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 #122664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 13, 2025

Conversation

xiaoleis-nv
Copy link
Contributor

The PR fixes the datatype error for nvvm.mma.sync when the operand is bf16. This operation originally requires the A/B type to be f16x2 for the bf16 MMA. However, it violates the NVVM intrinsic [here], where the A/B operand type should be i32. This is a bug, and there are no tests in MLIR that cover this datatype.

    // mma bf16 -> s32 @ m16n8k16/m16n8k8
    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],

This PR addresses this bug and adds tests to guarantee correctness.

@llvmbot
Copy link
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-mlir-llvm

Author: None (xiaoleis-nv)

Changes

The PR fixes the datatype error for nvvm.mma.sync when the operand is bf16. This operation originally requires the A/B type to be f16x2 for the bf16 MMA. However, it violates the NVVM intrinsic [here], where the A/B operand type should be i32. This is a bug, and there are no tests in MLIR that cover this datatype.

    // mma bf16 -> s32 @ m16n8k16/m16n8k8
    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],

This PR addresses this bug and adds tests to guarantee correctness.


Full diff: https://github.com/llvm/llvm-project/pull/122664.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+6-1)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+23)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+12)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b9097e9bbca2c..04042903e343ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
     | f16      | .m8n8k4   | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
     |          | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
     |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
-    | bf16     | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
-    |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
+    | bf16     | .m16n8k8  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
+    |          | .m16n8k16 | row     | col     | 4x i32   | 2x i32   | 4x f32            |
     | tf32     | .m16n8k4  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
     |          | .m16n8k8  | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4 f32 |
     | u8/s8    | .m8n8k16  | row     | col     | 1x i32   | 1x i32   | 2x i32            |
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 838159d676545d..d8fde3e765ac49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
       break;
-    case MMATypes::f16:
     case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+      break;
+    case MMATypes::f16:
       kFactor = 8;
       multiplicandFragType = f16x2Ty;
       expectedResult.push_back(f16x2x2StructTy);
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..4c3b6648a41c00 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
+func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
                              %c0 : i32, %c1 : i32) {
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 2d7710e7cbf279..09e98765413f0c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // f32 return type, f16 accumulate type
 // CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,

@llvmbot
Copy link
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-mlir

Author: None (xiaoleis-nv)

Changes

The PR fixes the datatype error for nvvm.mma.sync when the operand is bf16. This operation originally requires the A/B type to be f16x2 for the bf16 MMA. However, it violates the NVVM intrinsic [here], where the A/B operand type should be i32. This is a bug, and there are no tests in MLIR that cover this datatype.

    // mma bf16 -&gt; s32 @ m16n8k16/m16n8k8
    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],

This PR addresses this bug and adds tests to guarantee correctness.


Full diff: https://github.com/llvm/llvm-project/pull/122664.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+6-1)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+23)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+12)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b9097e9bbca2c..04042903e343ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1699,8 +1699,8 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
     | f16      | .m8n8k4   | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
     |          | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
     |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
-    | bf16     | .m16n8k8  | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
-    |          | .m16n8k16 | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
+    | bf16     | .m16n8k8  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
+    |          | .m16n8k16 | row     | col     | 4x i32   | 2x i32   | 4x f32            |
     | tf32     | .m16n8k4  | row     | col     | 2x i32   | 1x i32   | 4x f32            |
     |          | .m16n8k8  | row     | col     | 4x i32   | 2x i32   | 2x f16x2 or 4 f32 |
     | u8/s8    | .m8n8k16  | row     | col     | 1x i32   | 1x i32   | 2x i32            |
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 838159d676545d..d8fde3e765ac49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -445,8 +445,13 @@ LogicalResult MmaOp::verify() {
       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
       break;
-    case MMATypes::f16:
     case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+      break;
+    case MMATypes::f16:
       kFactor = 8;
       multiplicandFragType = f16x2Ty;
       expectedResult.push_back(f16x2x2StructTy);
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index a7bdceba01c1e8..4c3b6648a41c00 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -163,6 +163,29 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
+func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
                              %c0 : i32, %c1 : i32) {
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 2d7710e7cbf279..09e98765413f0c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -291,6 +291,18 @@ llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
+llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                              %b0 : i32, %b1 : i32,
+                              %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.bf16
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
 // f32 return type, f16 accumulate type
 // CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,

@xiaoleis-nv
Copy link
Contributor Author

@christopherbate @joker-eph , Could you please review this PR as well? Thank you.

@grypp
Copy link
Member

grypp commented Jan 13, 2025

Looks good to me. I will wait others to review. If there is no objection, feel free to land it,
I wonder, though, was the bug always there? I vaguely remember that mma.sync bf16 was working before.

Copy link
Contributor

@durga4github durga4github left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@durga4github
Copy link
Contributor

@schwarzschild-radius , Please take a look when you get a chance

@durga4github
Copy link
Contributor

Submitting this as per Xiaolei's request

@durga4github durga4github merged commit d03f35f into llvm:main Jan 13, 2025
11 checks passed
kazutakahirata pushed a commit to kazutakahirata/llvm-project that referenced this pull request Jan 13, 2025
…d is bf16 (llvm#122664)

The PR fixes the datatype error for `nvvm.mma.sync` when the operand is
`bf16`. This operation originally requires the A/B type to be `f16x2`
for the `bf16` MMA. However, it violates the NVVM intrinsic
[[here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119)],
where the A/B operand type should be `i32`. This is a bug, and there are
no tests in MLIR that cover this datatype.

```
    // mma bf16 -> s32 @ m16n8k16/m16n8k8
    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```

This PR addresses this bug and adds tests to guarantee correctness.

Co-authored-by: Xiaolei Shi <[email protected]>
@xiaoleis-nv xiaoleis-nv deleted the fix/fix_bf16_nvvm.mma.sync branch January 14, 2025 01:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants