@@ -502,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
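The three `+` helpers above consolidate the warp-level reductions in one place near the top of the file: the `float` and `float2` overloads of `warp_reduce_sum` replace the duplicate definitions deleted in the next two hunks, and `warp_reduce_max` is new, added for the `soft_max_f32` rewrite further down. Each `__shfl_xor_sync` exchanges a register value between lanes whose IDs differ in one bit, so after the five butterfly steps (masks 16, 8, 4, 2, 1) every lane of the warp holds the full sum or max, with no shared memory and no `__syncthreads()`. A minimal, self-contained sketch of the pattern (the `sum32` kernel and the host harness are illustrative assumptions, not code from this commit):

```cuda
#include <cstdio>

// Same butterfly reduction as the warp_reduce_sum added in this commit.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

// Illustrative kernel: one warp sums 32 floats; after the reduction every
// lane holds the total, so lane 0 can write it out directly.
static __global__ void sum32(const float * x, float * out) {
    const float v = warp_reduce_sum(x[threadIdx.x]);
    if (threadIdx.x == 0) {
        *out = v;
    }
}

int main() {
    float h_x[32], h_out = 0.0f;
    for (int i = 0; i < 32; ++i) h_x[i] = (float) i;

    float * d_x; float * d_out;
    cudaMalloc(&d_x,  32*sizeof(float));
    cudaMalloc(&d_out,    sizeof(float));
    cudaMemcpy(d_x, h_x, 32*sizeof(float), cudaMemcpyHostToDevice);

    sum32<<<1, 32>>>(d_x, d_out);

    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.1f\n", h_out); // 0 + 1 + ... + 31 = 496.0
    cudaFree(d_x); cudaFree(d_out);
    return 0;
}
```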
@@ -578,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -625,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4718,59 +4726,60 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
     const int block_size = blockDim.x;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
-
-    buf[tid] = -INFINITY;
+    float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
-    __syncthreads();
-
     // find the max value in the block
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] = max(buf[tid], buf[tid + i]);
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
         }
         __syncthreads();
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
         dst[ix] = val;
     }
 
-    __syncthreads();
-
-    buf[tid] = tmp;
-
-    __syncthreads();
-
-    // sum up partial sums
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] += buf[tid + i];
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
         }
         __syncthreads();
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
-    const float inv_tmp = 1.f / buf[0];
+    const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
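In `soft_max_f32` the same two-level reduction now appears twice, once for the row max and once for the sum of exponentials: each warp reduces its partial value in registers, lane 0 of every warp publishes that partial to a shared buffer with one slot per warp (shrinking the buffer by a factor of `WARP_SIZE` compared with the old one-slot-per-thread `buf[CUDA_SOFT_MAX_BLOCK_SIZE]`), and after a single `__syncthreads()` a second warp-level pass reduces the per-warp partials, replacing the old tree loop that needed a barrier at every step. A sketch of the pattern factored into a standalone helper (the `block_reduce_max` name, the 1024-thread cap, the host harness, and the `-INFINITY` pre-fill of unused slots are my assumptions for a self-contained example, not code from this commit; the pre-fill guards block sizes whose warp count does not fill the buffer):

```cuda
#include <cmath>
#include <cstdio>

#define WARP_SIZE      32    // matches the width used by the shuffles above
#define MAX_BLOCK_SIZE 1024  // assumed cap, standing in for CUDA_SOFT_MAX_BLOCK_SIZE

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

// Hypothetical helper factoring out the two-level pattern used in soft_max_f32.
static __device__ float block_reduce_max(float x) {
    x = warp_reduce_max(x); // step 1: every lane now holds its warp's max

    if (blockDim.x > WARP_SIZE) {
        __shared__ float buf[MAX_BLOCK_SIZE/WARP_SIZE]; // one slot per warp

        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;

        // Assumed extra guard: neutralize slots that no warp will write, so
        // partially filled buffers cannot feed garbage into the second pass.
        if (warp_id == 0) {
            buf[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf[warp_id] = x; // step 2: lane 0 publishes its warp's partial
        }
        __syncthreads();

        x = warp_reduce_max(buf[lane_id]); // step 3: reduce the partials
    }
    return x;
}

// Illustrative check: 128 threads contribute 0..127; the block max is 127.
static __global__ void block_max_test(float * out) {
    const float m = block_reduce_max((float) threadIdx.x);
    if (threadIdx.x == 0) {
        *out = m;
    }
}

int main() {
    float * d_out; float h_out = 0.0f;
    cudaMalloc(&d_out, sizeof(float));
    block_max_test<<<1, 128>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block max = %.1f\n", h_out); // expected: 127.0
    cudaFree(d_out);
    return 0;
}
```

The sum-of-exps reduction in the hunk follows the identical shape with `warp_reduce_sum` and `+=` in place of `warp_reduce_max` and `fmaxf`, which is why the removed `// TODO: maybe can be improved with some warp-based primitives` comment is resolved by this change.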