@@ -443,7 +443,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
-#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -503,31 +502,6 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
-static __device__ __forceinline__ float warp_reduce_max(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-}
-
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
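
For reference, the three helpers deleted above all implement the same warp-level butterfly reduction: XOR-ing the lane index with the masks 16, 8, 4, 2, 1 pairs each lane with a partner at that distance, so after five __shfl_xor_sync exchanges every lane of the 32-thread warp holds the full result. A minimal, self-contained sketch of the pattern (the demo kernel and host code are illustrative, not part of ggml):

    #include <cstdio>

    // each lane contributes its lane id; after the butterfly exchange
    // all 32 lanes hold the same total (0 + 1 + ... + 31 = 496)
    __global__ void warp_sum_demo(float * out) {
        float x = (float) threadIdx.x;
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        if (threadIdx.x == 0) {
            *out = x;
        }
    }

    int main() {
        float * d_out;
        cudaMalloc(&d_out, sizeof(float));
        warp_sum_demo<<<1, 32>>>(d_out);
        float h_out = 0.0f;
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp sum = %g\n", h_out); // prints 496
        cudaFree(d_out);
        return 0;
    }
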
@@ -604,6 +578,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
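
The float2 overload is re-added here, just above norm_f32, because that kernel reduces two running quantities at once: it packs the sum and the sum of squares into a single float2, so one butterfly pass yields both the mean and the variance. A host-side sketch of that arithmetic, assuming the usual identity Var(x) = E[x^2] - E[x]^2:

    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
        // one pass accumulates both components, like the float2 reduction
        float sum = 0.0f, sum2 = 0.0f;
        for (const float v : x) {
            sum  += v;
            sum2 += v*v;
        }
        const float mean = sum  / x.size();
        const float var  = sum2 / x.size() - mean*mean; // E[x^2] - E[x]^2
        printf("mean = %g, var = %g\n", mean, var);     // mean = 2.5, var = 1.25
        return 0;
    }
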
@@ -642,6 +625,14 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
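
Likewise the scalar overload lands directly above rms_norm_f32, which needs only one reduced quantity, the row's sum of squares, and then scales every element by 1/sqrt(mean(x^2) + eps). A CPU sketch of that computation (the eps value here is chosen arbitrarily for the example):

    #include <cmath>
    #include <cstdio>

    int main() {
        float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float sum2 = 0.0f;
        for (const float v : x) {
            sum2 += v*v; // the quantity warp_reduce_sum(float) reduces per row
        }
        const float inv_rms = 1.0f / sqrtf(sum2/4 + 1e-6f);
        for (float & v : x) {
            v *= inv_rms; // normalize in place
        }
        printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]);
        return 0;
    }
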
@@ -4727,74 +4718,45 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int tid  = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-    const int block_size = blockDim.x;
-
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+// the CUDA soft max implementation differs from the CPU implementation:
+// floats are used instead of doubles
+static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+        const int i = row*ncols + col;
+        max_val = max(max_val, x[i]);
     }
 
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf[warp_id] = max_val;
-        }
-        __syncthreads();
-
-        max_val = buf[lane_id];
-        max_val = warp_reduce_max(max_val);
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+        const int i = row*ncols + col;
+        const float val = expf(x[i] - max_val);
         tmp += val;
-        dst[ix] = val;
+        dst[i] = val;
     }
 
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf[lane_id] = 0.f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf[warp_id] = tmp;
-        }
-        __syncthreads();
-
-        tmp = buf[lane_id];
-        tmp = warp_reduce_sum(tmp);
+    // sum up partial sums
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
+        const int i = row*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
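
The restored kernel computes the standard numerically stable softmax per row: find the row maximum, subtract it before calling expf so the exponentials cannot overflow, then normalize by the sum; the deleted version additionally fused in a scale factor and an optional additive mask (y). A plain C++ reference of what the restored kernel produces per row, as a sketch for comparison:

    #include <cmath>
    #include <cstdio>

    // softmax of one row, numerically stable in float
    void softmax_row(const float * x, float * dst, const int ncols) {
        float max_val = -INFINITY;
        for (int i = 0; i < ncols; i++) {
            max_val = fmaxf(max_val, x[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < ncols; i++) {
            dst[i] = expf(x[i] - max_val); // shifting by the max avoids overflow
            sum += dst[i];
        }
        for (int i = 0; i < ncols; i++) {
            dst[i] /= sum; // row now sums to 1
        }
    }

    int main() {
        const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float y[4];
        softmax_row(x, y, 4);
        printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
        return 0;
    }
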
@@ -5831,12 +5793,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    int nth = WARP_SIZE;
-    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-    const dim3 block_dims(nth, 1, 1);
+static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+    const dim3 block_dims(1, WARP_SIZE, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
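
With block_dims(1, WARP_SIZE, 1) the restored launcher assigns exactly one 32-thread warp to each row: blockIdx.x selects the row and threadIdx.y strides across its columns, which is why the kernel can reduce with warp shuffles alone and needs no shared memory or __syncthreads(). A hypothetical driver, assuming it sits in the same translation unit as soft_max_f32_cuda (the buffer handling is illustrative, not ggml's):

    #include <cstdio>
    #include <vector>

    // runs the row-wise softmax on a 4x8 matrix of ones;
    // every output element should come back as 1/8 = 0.125
    static void soft_max_demo() {
        const int nrows = 4, ncols = 8;
        std::vector<float> h(nrows*ncols, 1.0f);
        float * d_x = nullptr; float * d_dst = nullptr;
        cudaMalloc(&d_x,   nrows*ncols*sizeof(float));
        cudaMalloc(&d_dst, nrows*ncols*sizeof(float));
        cudaMemcpy(d_x, h.data(), nrows*ncols*sizeof(float), cudaMemcpyHostToDevice);
        soft_max_f32_cuda(d_x, d_dst, ncols, nrows, /*stream=*/0); // default stream
        cudaMemcpy(h.data(), d_dst, nrows*ncols*sizeof(float), cudaMemcpyDeviceToHost);
        printf("dst[0] = %f\n", h[0]);
        cudaFree(d_x); cudaFree(d_dst);
    }
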
@@ -6875,18 +6835,14 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains the mask and it is optional
-
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
-
-    float scale = 1.0f;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    const int64_t nrows = ggml_nrows(src0);
 
-    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
+    (void) src1;
     (void) dst;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
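
The added (void) casts exist only to silence unused-parameter warnings: the ggml_cuda_op_* functions share a common signature so they can be dispatched uniformly, and after this change the soft max op no longer touches the mask operand src1 or its device buffer src1_dd. The idiom in isolation (a hypothetical function, not from the file):

    // parameters kept for a uniform dispatch signature, explicitly
    // voided when a particular op does not use them
    static void some_op(const float * src0_dd, const float * src1_dd, float * dst_dd) {
        dst_dd[0] = src0_dd[0];
        (void) src1_dd; // suppresses -Wunused-parameter, no runtime effect
    }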