diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index b29228ef87ea7..b8574bbbee345 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -167,6 +167,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { } } } + if (auto vecType = dyn_cast(llvmBufferValType)) { + // Buffer intrinsics don't support 1-element vectors, cast them to + // scalars. + if (vecType.getNumElements() == 1) + llvmBufferValType = vecType.getElementType(); + } SmallVector args; if (storeData) { diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 062b63c076c3c..8b2f5788721a1 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -76,6 +76,15 @@ func.func @gpu_gcn_raw_buffer_load_i32_oob_off(%buf: memref<64xi32>, %idx: i32) func.return %0 : i32 } +// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_1xi32 +func.func @gpu_gcn_raw_buffer_load_1xi32(%buf: memref<64xi32>, %idx: i32) -> vector<1xi32> { + // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 + // CHECK: %[[cast:.*]] = llvm.bitcast %[[ret]] : i32 to vector<1xi32> + // CHECK: return %[[cast]] + %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> vector<1xi32> + func.return %0 : vector<1xi32> +} + // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi32 func.func @gpu_gcn_raw_buffer_load_2xi32(%buf: memref<64xi32>, %idx: i32) -> vector<2xi32> { // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi32> @@ -159,6 +168,14 @@ func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx: func.return } +// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_1xf32 +func.func @gpu_gcn_raw_buffer_store_1xf32(%value: 
vector<1xf32>, %buf: memref<64xf32>, %idx: i32) { + // CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<1xf32> to f32 + // CHECK: rocdl.raw.ptr.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 + amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<1xf32> -> memref<64xf32>, i32 + func.return +} + // CHECK-LABEL: func @gpu_gcn_raw_buffer_store_2xi8 func.func @gpu_gcn_raw_buffer_store_2xi8(%value: vector<2xi8>, %buf: memref<64xi8>, %idx: i32) { // CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<2xi8> to i16