Skip to content

Commit 6e61126

Browse files
authored
[mlir] AMDGPUToROCDL: handle 1-element vectors (#128266)
Buffer intrinsics doesn't support 1-element vectors, cast them to scalars.
1 parent 0bd66c4 commit 6e61126

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
167167
}
168168
}
169169
}
170+
if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
171+
// Buffer intrinsics doesn't support 1-element vectors, cast them to
172+
// scalars.
173+
if (vecType.getNumElements() == 1)
174+
llvmBufferValType = vecType.getElementType();
175+
}
170176

171177
SmallVector<Value, 6> args;
172178
if (storeData) {

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ func.func @gpu_gcn_raw_buffer_load_i32_oob_off(%buf: memref<64xi32>, %idx: i32)
7676
func.return %0 : i32
7777
}
7878

79+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_1xi32
80+
func.func @gpu_gcn_raw_buffer_load_1xi32(%buf: memref<64xi32>, %idx: i32) -> vector<1xi32> {
81+
// CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
82+
// CHECK: %[[cast:.*]] = llvm.bitcast %[[ret]] : i32 to vector<1xi32>
83+
// CHECK: return %[[cast]]
84+
%0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> vector<1xi32>
85+
func.return %0 : vector<1xi32>
86+
}
87+
7988
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi32
8089
func.func @gpu_gcn_raw_buffer_load_2xi32(%buf: memref<64xi32>, %idx: i32) -> vector<2xi32> {
8190
// CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi32>
@@ -159,6 +168,14 @@ func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx:
159168
func.return
160169
}
161170

171+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_1xf32
172+
func.func @gpu_gcn_raw_buffer_store_1xf32(%value: vector<1xf32>, %buf: memref<64xf32>, %idx: i32) {
173+
// CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<1xf32> to f32
174+
// CHECK: rocdl.raw.ptr.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
175+
amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<1xf32> -> memref<64xf32>, i32
176+
func.return
177+
}
178+
162179
// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_2xi8
163180
func.func @gpu_gcn_raw_buffer_store_2xi8(%value: vector<2xi8>, %buf: memref<64xi8>, %idx: i32) {
164181
// CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<2xi8> to i16

0 commit comments

Comments
 (0)