Skip to content

Commit 3452149

Browse files
authored
[mlir][AMDGPU] Support vector<2xbf16> packed atomic fadd (#113929)
Now that we use LLVM's native bfloat types in the AMDGPU lowering, enable vector<2xbf16> for AMDGPU.
1 parent 9ce0a61 commit 3452149

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def AMDGPU_RawBufferAtomicCmpswapOp :
254254
def AMDGPU_RawBufferAtomicFaddOp :
255255
AMDGPU_Op<"raw_buffer_atomic_fadd", [AllElementTypesMatch<["value", "memref"]>,
256256
AttrSizedOperandSegments]>,
257-
Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16]>]>:$value,
257+
Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$value,
258258
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
259259
Variadic<I32>:$indices,
260260
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
9898
// bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
9999
// and the total load size is >= 32, use a vector load of N / (bitsize(T) /
100100
// 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
101-
// so bitcast any floats to integers. On top of all this, cast bfloat
102-
// (vectors) to i16 since the backend doesn't currently support bfloat on
103-
// these operations.
101+
// so bitcast any floats to integers.
104102
Type llvmBufferValType = llvmWantedDataType;
105-
if (wantedDataType.isBF16())
106-
llvmBufferValType = rewriter.getI16Type();
107-
if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
108-
if (wantedVecType.getElementType().isBF16())
109-
llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
110103
if (atomicCmpData) {
111104
if (auto floatType = dyn_cast<FloatType>(wantedDataType))
112105
llvmBufferValType = this->getTypeConverter()->convertType(

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,17 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: mem
163163
func.return
164164
}
165165

166+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16
167+
func.func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16(%value: vector<2xbf16>, %buf: memref<64xbf16>, %idx: i32) {
168+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
169+
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
170+
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
171+
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
172+
// CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : vector<2xbf16>
173+
amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xbf16> -> memref<64xbf16>, i32
174+
func.return
175+
}
176+
166177
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fmax_f32
167178
func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
168179
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)

0 commit comments

Comments
 (0)