Commit 7a3f7e9

cuda : use amd wave sharing intrinsics for warp_reduce functions
1 parent fed0108 commit 7a3f7e9
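
The commit adds AMD-specific implementations of the warp_reduce_sum and warp_reduce_max helpers in ggml-cuda/common.cuh, built on the ds_swizzle and DPP (data parallel primitives) cross-lane instructions that HIP exposes on AMD GPUs, and routes HIP/AMD builds to them instead of the generic __shfl_xor_sync butterfly loop.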

File tree: 1 file changed

ggml-cuda/common.cuh

Lines changed: 63 additions & 0 deletions
@@ -315,6 +315,57 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
+
+#ifdef __HIP_PLATFORM_AMD__
+#define AMD_SWIZZLE_MASK(and_mask, or_mask, xor_mask) ((and_mask) | ((or_mask)<<5) | ((xor_mask)<<10)) // 5-bit masks applied sequentially to the thread id
+#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
+#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
+    hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+
+template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
+static __device__ __forceinline__ float hip_move_dppf_N(float x) {
+    typedef union float_b32 {
+        float val;
+        int   b32;
+    } float_b32_t;
+    float_b32_t tmp;
+    tmp.val = x;
+    tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
+    return tmp.val;
+}
+
+static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
+    x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
+    a.x += __hip_ds_swizzlef(a.x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.y += __hip_ds_swizzlef(a.y, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
+    x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, false));
+    return x;
+}
+#endif // __HIP_PLATFORM_AMD__
 #endif // defined(GGML_USE_HIPBLAS)
 
 #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
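
The reduction runs in two stages. AMD_SWIZZLE_MASK(0x1F, 0, 0x10) encodes the swizzle pattern 0x401F, whose source lane is ((id & 0x1F) | 0) ^ 0x10, so each lane first adds the value held by its partner 16 lanes away. The four DPP row-rotates then fold the partial sums within each 16-lane row; because 8, 4, 2 and 1 are cyclic rotations of the row, every lane (not just lane 0) ends up holding the full 32-lane result, matching the __shfl_xor_sync butterfly these functions replace. Note that bound_ctrl is true in the sum variants (an inactive source lane reads 0, the identity for addition) but false in the max variant, presumably so a disabled source keeps its old value instead of injecting 0 into a maximum over possibly negative inputs.

A minimal host-side sketch of the lane movement, for illustration only: swizzle_xor16_add and row_ror_add are made-up helper names that model the swizzle and row_ror DPP semantics on a plain array, not HIP APIs.

    #include <cstdio>

    // Model of ds_swizzle with mask 0x401F: lane i reads lane i ^ 0x10.
    static void swizzle_xor16_add(float v[32]) {
        float src[32];
        for (int i = 0; i < 32; ++i) src[i] = v[i ^ 0x10];
        for (int i = 0; i < 32; ++i) v[i] += src[i];
    }

    // Model of DPP row_ror:n: rotate each 16-lane row right by n lanes, then add.
    static void row_ror_add(float v[32], int n) {
        float src[32];
        for (int i = 0; i < 32; ++i) {
            const int row = i & ~15; // start of this lane's 16-wide row
            src[i] = v[row + (((i & 15) - n) & 15)];
        }
        for (int i = 0; i < 32; ++i) v[i] += src[i];
    }

    int main() {
        float v[32];
        for (int i = 0; i < 32; ++i) v[i] = (float) i; // total is 496
        swizzle_xor16_add(v);                          // fold the two 16-lane halves together
        row_ror_add(v, 8);
        row_ror_add(v, 4);
        row_ror_add(v, 2);
        row_ror_add(v, 1);
        printf("lane 0: %g, lane 31: %g\n", v[0], v[31]); // both print 496
        return 0;
    }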
@@ -349,20 +400,28 @@ static __device__ void no_device_code(
 #endif // __CUDA_ARCH__
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_sum_impl_amd(x);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x += __shfl_xor_sync(0xffffffff, x, mask, 32);
     }
     return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_sum_impl_amd(a);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
         a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
     }
     return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
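
The pre-existing __shfl_xor_sync loop is kept as the fallback for NVIDIA builds and any non-AMD HIP platform; like the AMD path, it leaves the reduced value in every lane of the warp, so callers that read the result from an arbitrary lane behave the same on both paths.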
@@ -391,11 +450,15 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 }
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_max_impl_amd(x);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
     }
     return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
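
For context, a hypothetical caller in the style of the existing ggml-cuda kernels. sum_rows_sketch is an illustrative name, not part of this commit; it assumes the WARP_SIZE constant defined earlier in common.cuh and a launch of one warp-wide block per row.

    // Hypothetical usage sketch: per-row sums via warp_reduce_sum().
    static __global__ void sum_rows_sketch(const float * x, float * dst, const int ncols) {
        const int row = blockIdx.x;
        float sum = 0.0f;
        for (int col = threadIdx.x; col < ncols; col += WARP_SIZE) {
            sum += x[row*ncols + col];  // each lane accumulates a strided slice of the row
        }
        sum = warp_reduce_sum(sum);     // every lane now holds the row total
        if (threadIdx.x == 0) {
            dst[row] = sum;             // lane 0 writes the result
        }
    }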
