
Commit 9e6f2e2

cuda : add amd dpp version of warp_reduce_sum for half2
1 parent 7a3f7e9 commit 9e6f2e2

File tree

1 file changed: +37 -7 lines changed


ggml-cuda/common.cuh

Lines changed: 37 additions & 7 deletions
@@ -321,6 +321,9 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
 #define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
     hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+#define hip_move_dpph2(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
+    hip_move_dpph2_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+#define hip_ds_swizzleh2(src, pattern) hip_ds_swizzleh2_N<(pattern)>((src))
 
 template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
 static __device__ __forceinline__ float hip_move_dppf_N(float x) {
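
The new forwarding macros mirror the existing hip_move_dppf one: they lift the DPP control arguments into template parameters, since __builtin_amdgcn_mov_dpp and __builtin_amdgcn_ds_swizzle expect compile-time-constant operands. A sketch of the expansion (illustrative only, not part of the diff):

// Illustrative expansion, assuming `x` is a half2 in scope:
//   hip_move_dpph2(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true)
// becomes
//   hip_move_dpph2_N<AMD_DPP_ROW_RR(8), 0xF, 0xF, true>(x)
// so the control values reach the builtin as template constants.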
@@ -334,6 +337,30 @@ static __device__ __forceinline__ float hip_move_dppf_N(float x) {
     return tmp.val;
 }
 
+template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
+static __device__ __forceinline__ half2 hip_move_dpph2_N(half2 x) {
+    typedef union half2_b32 {
+        half2 val;
+        int b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = x;
+    tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
+    return tmp.val;
+}
+
+template <int pattern>
+static __device__ __forceinline__ half2 hip_ds_swizzleh2_N(half2 src) {
+    typedef union half2_b32 {
+        half2 val;
+        int b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = src;
+    tmp.b32 = __builtin_amdgcn_ds_swizzle(tmp.b32, pattern);
+    return tmp.val;
+}
+
 static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
     x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
     x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
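
Both half2 helpers follow the float variant's pattern: a union reinterprets the 32-bit half2 payload as an int, because the amdgcn builtins move raw 32-bit lane values. A standalone sketch of that bit-cast idiom (hypothetical helper names, not from the commit):

// Hypothetical helpers showing the same reinterpretation idiom: half2 is
// 32 bits wide, so it can pass through any builtin that moves a b32 value.
static __device__ __forceinline__ int half2_as_b32(half2 v) {
    union { half2 val; int b32; } u;
    u.val = v;
    return u.b32;
}
static __device__ __forceinline__ half2 b32_as_half2(int b) {
    union { half2 val; int b32; } u;
    u.b32 = b;
    return u.val;
}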
@@ -357,6 +384,15 @@ static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 x) {
+    x += hip_ds_swizzleh2(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return x;
+}
+
 static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
     x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
     x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
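
The half2 overload uses the same five-step schedule as the float one: the swizzle adds the value from the lane 16 apart (swapping neighbouring groups of 16 lanes), then the row rotates by 8, 4, 2, and 1 fold each 16-lane row, leaving every lane with the sum across 32 lanes. It computes the same result as the __shfl_xor_sync loop removed in the final hunk below; for reference, that loop as a standalone device function (adapted from the removed lines):

// Reference semantics the DPP version replaces: a shuffle-based butterfly
// reduction over 32 lanes, summing low and high halves separately.
static __device__ __forceinline__ half2 warp_reduce_sum_shfl(half2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
        reinterpret_cast<half&>(a.x) += __low2half(a_other);
        reinterpret_cast<half&>(a.y) += __high2half(a_other);
    }
    return a;
}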
@@ -428,13 +464,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
-        reinterpret_cast<half&>(a.x) += __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
-    }
-    return a;
+    return warp_reduce_sum_impl_amd(a);
 #else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
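
With this change, the AMD branch of warp_reduce_sum(half2) simply dispatches to the DPP implementation above. A minimal usage sketch (hypothetical kernel, not part of the commit; assumes common.cuh is included and FP16 is available on the target):

// Each block reduces one row of ncols2 half2 elements with a single warp.
__global__ void sum_rows_f16(const half2 * x, half2 * dst, const int ncols2) {
    half2 acc = make_half2(0.0f, 0.0f);
    for (int i = threadIdx.x; i < ncols2; i += WARP_SIZE) {
        acc += x[blockIdx.x*ncols2 + i];
    }
    acc = warp_reduce_sum(acc); // DPP path on AMD, shuffle path elsewhere
    if (threadIdx.x == 0) {
        dst[blockIdx.x] = acc;
    }
}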
