@@ -321,6 +321,9 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #define AMD_DPP_ROW_RR(x)   (0x120+(x))    // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
 #define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
     hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+#define hip_move_dpph2(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
+    hip_move_dpph2_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+#define hip_ds_swizzleh2(src, pattern) hip_ds_swizzleh2_N<(pattern)>((src))
 
 template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
 static __device__ __forceinline__ float hip_move_dppf_N(float x) {
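// Illustrative note (not part of the patch): the cross-lane builtins used below
// require their control operands to be compile-time constants, so the wrappers
// forward them as non-type template parameters while the macros restore a plain
// function-call spelling at the call site. For example,
//
//     x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
//
// expands (modulo the defensive parentheses) to
//
//     x += hip_move_dppf_N<AMD_DPP_ROW_RR(8), 0xF, 0xF, true>(x);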
@@ -334,6 +337,30 @@ static __device__ __forceinline__ float hip_move_dppf_N(float x) {
     return tmp.val;
 }
 
+template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
+static __device__ __forceinline__ half2 hip_move_dpph2_N(half2 x) {
+    typedef union half2_b32 {
+        half2 val;
+        int   b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = x;
+    tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
+    return tmp.val;
+}
+
+template <int pattern>
+static __device__ __forceinline__ half2 hip_ds_swizzleh2_N(half2 src) {
+    typedef union half2_b32 {
+        half2 val;
+        int   b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = src;
+    tmp.b32 = __builtin_amdgcn_ds_swizzle(tmp.b32, pattern);
+    return tmp.val;
+}
+
 static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
     x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));    // swap neighbouring groups of 16 lanes
     x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
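// Illustrative sketch (not part of the patch): half2 is punned through a union
// to int because mov_dpp and ds_swizzle move opaque 32-bit values. Assuming the
// usual ds_swizzle bit-mask encoding (and_mask in bits [4:0], or_mask in [9:5],
// xor_mask in [14:10]; each lane in a group of 32 reads from
// ((lane & and_mask) | or_mask) ^ xor_mask), a host-side model of the mapping is:

static int swizzle_src_lane(int lane, int pattern) {
    const int and_mask =  pattern        & 0x1F;
    const int or_mask  = (pattern >> 5)  & 0x1F;
    const int xor_mask = (pattern >> 10) & 0x1F;
    return (((lane & 0x1F) & and_mask) | or_mask) ^ xor_mask;
}

// With AMD_SWIZZLE_MASK(0x1F, 0, 0x10) this collapses to lane ^ 16, i.e. the
// "swap neighbouring groups of 16 lanes" step of the reductions below.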
@@ -357,6 +384,15 @@ static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 x) {
+    x += hip_ds_swizzleh2(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return x;
+}
+
 static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
     x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
     x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
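// Usage sketch (hypothetical kernel, not part of the patch): the reduction
// spans groups of 32 lanes, matching the width-32 shuffles it replaces. Each
// lane contributes one half2; afterwards every lane in the group holds the
// packed sum, so lane 0 can publish it.

static __global__ void wave_sum_h2(const half2 * x, half2 * dst) {
    half2 v = x[threadIdx.x];              // one half2 per lane
    v = warp_reduce_sum_impl_amd(v);       // swizzle + DPP row-rotate reduction
    if (threadIdx.x == 0) {
        *dst = v;                          // e.g. launched as <<<1, 32>>>
    }
}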
@@ -428,13 +464,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
-        reinterpret_cast<half&>(a.x) += __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
-    }
-    return a;
+    return warp_reduce_sum_impl_amd(a);
 #else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
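// Equivalence sketch (not part of the patch): the removed HIP branch added the
// low and high halves of a width-32 XOR-shuffle butterfly separately; per lane
// it computed the same packed result as
//
//     for (int mask = 16; mask > 0; mask >>= 1) {
//         a += __shfl_xor_sync(0xffffffff, a, mask, 32);
//     }
//
// warp_reduce_sum_impl_amd(a) yields those same per-lane sums using ds_swizzle
// and DPP row rotates instead of shuffles.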