[MLIR][NVVM] Fix the datatype error for nvvm.mma.sync when the operand is bf16 (#122664)
This PR fixes a datatype error in `nvvm.mma.sync` for `bf16` operands.
The operation previously required the A/B operand type to be `f16x2`
for the `bf16` MMA. That contradicts the NVVM intrinsic definition
[[here](https://github.com/xiaoleis-nv/llvm-project/blob/372044ee09d39942925824f8f335aef40bfe92f0/llvm/include/llvm/IR/IntrinsicsNVVM.td#L119)],
which specifies `i32` as the A/B operand type. This is a bug, and no
existing MLIR test covers this datatype.
```
// mma bf16 -> s32 @ m16n8k16/m16n8k8
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
```
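The register counts in the excerpt above can be cross-checked with a little arithmetic. The sketch below is only an illustration, assuming a 32-thread warp that shares each fragment evenly and two `bf16` values packed into each `i32` register. The helper names `i32_regs_per_thread` and `fragment_regs` are hypothetical and are not part of LLVM or MLIR.

```python
# Hypothetical helpers for illustration only; not an LLVM or MLIR API.
WARP_SIZE = 32     # assumption: one mma.sync is executed by a 32-thread warp
BF16_PER_I32 = 2   # two 16-bit bf16 values fit in one 32-bit register


def i32_regs_per_thread(rows: int, cols: int) -> int:
    """Number of i32 registers each thread holds for a rows x cols bf16 fragment."""
    elements_per_thread = (rows * cols) // WARP_SIZE
    return elements_per_thread // BF16_PER_I32


def fragment_regs(m: int, n: int, k: int) -> dict:
    """Register counts for the A (m x k) and B (k x n) fragments."""
    return {"a": i32_regs_per_thread(m, k), "b": i32_regs_per_thread(k, n)}


if __name__ == "__main__":
    for shape in [(16, 8, 16), (16, 8, 8)]:
        print(shape, fragment_regs(*shape))
```

Under these assumptions, the results for `m16n8k16` and `m16n8k8` line up with the `!listsplat(llvm_i32_ty, N)` entries in the intrinsic definition.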
This PR fixes the bug and adds tests that cover the `bf16` operand type.
Co-authored-by: Xiaolei Shi <[email protected]>