@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 512
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
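The new CUDA_SOFT_MAX_BLOCK_SIZE constant sizes the shared-memory buffer of the reworked soft_max_f32 kernel below, and the tree reductions in that kernel only combine all partial results if the block size is a power of two within CUDA's 1024-thread block limit. A compile-time guard along these lines (my sketch, not part of the patch) would make those assumptions explicit:

    // hypothetical guards for the assumptions soft_max_f32 makes about this constant
    static_assert((CUDA_SOFT_MAX_BLOCK_SIZE & (CUDA_SOFT_MAX_BLOCK_SIZE - 1)) == 0,
                  "CUDA_SOFT_MAX_BLOCK_SIZE must be a power of two for the tree reduction");
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE <= 1024,
                  "CUDA_SOFT_MAX_BLOCK_SIZE must fit in a single CUDA thread block");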
@@ -4717,45 +4718,59 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
+// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
 
-    float max_val = -INFINITY;
+    const int block_size = blockDim.x;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
+
+    buf[tid] = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
+    __syncthreads();
+
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] = max(buf[tid], buf[tid + i]);
+        }
+        __syncthreads();
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
         tmp += val;
         dst[ix] = val;
     }
 
+    __syncthreads();
+
+    buf[tid] = tmp;
+
+    __syncthreads();
+
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] += buf[tid + i];
+        }
+        __syncthreads();
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_tmp = 1.f / buf[0];
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
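The reworked kernel assigns one thread block per row: each thread strides over the row's columns, and partial maxima and partial sums are combined through the shared buf array with two block-wide tree reductions, so buf[0] ends up holding the row maximum after the first reduction and the row sum after the second. Subtracting the row maximum before expf is the standard numerical-stability trick. As a cross-check, here is a single-threaded CPU sketch of the same math (reference_soft_max_f32 is a hypothetical test helper, not part of the patch; it assumes the same row-major layout and mask broadcast):

    #include <math.h>
    #include <float.h>

    // computes dst = softmax(x*scale + y) row by row, matching soft_max_f32
    static void reference_soft_max_f32(const float * x, const float * y, float * dst,
                                       int ncols, int nrows_x, int nrows_y, float scale) {
        for (int rowx = 0; rowx < nrows_x; ++rowx) {
            const int rowy = rowx % nrows_y; // same mask broadcast as the kernel

            // pass 1: row maximum of the scaled, masked logits
            float max_val = -FLT_MAX;
            for (int col = 0; col < ncols; ++col) {
                const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
                if (v > max_val) { max_val = v; }
            }

            // pass 2: exponentiate relative to the maximum, accumulate the sum
            float sum = 0.0f;
            for (int col = 0; col < ncols; ++col) {
                const float v = expf((x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f)) - max_val);
                dst[rowx*ncols + col] = v;
                sum += v;
            }

            // pass 3: normalize so each row sums to 1
            for (int col = 0; col < ncols; ++col) {
                dst[rowx*ncols + col] /= sum;
            }
        }
    }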
@@ -5796,7 +5811,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
 }
 
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
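The launcher now starts from one warp and doubles the block size until it covers ncols_x or reaches CUDA_SOFT_MAX_BLOCK_SIZE, so nth is always a power of two (which the kernel's reductions require) and never exceeds the shared buffer. A small host-side sketch of that rounding behavior (soft_max_block_size is a hypothetical test helper, assuming WARP_SIZE == 32):

    #include <assert.h>

    static int soft_max_block_size(int ncols_x) {
        int nth = 32;                                // WARP_SIZE
        while (nth < ncols_x && nth < 512) nth *= 2; // CUDA_SOFT_MAX_BLOCK_SIZE
        return nth;
    }

    int main(void) {
        assert(soft_max_block_size(10)   == 32);  // never below one warp
        assert(soft_max_block_size(100)  == 128); // next power of two covering the row
        assert(soft_max_block_size(4096) == 512); // clamped to the shared buffer size
        return 0;
    }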
@@ -6853,7 +6870,7 @@ inline void ggml_cuda_op_soft_max(
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
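The 0 -> 1 change matters because soft_max_f32 computes rowx % nrows_y unconditionally, before the null check on the mask pointer, and a modulo by zero is undefined behavior even though y is never dereferenced when src1 is absent. A hypothetical illustration of the fixed broadcast index:

    #include <stdio.h>

    int main(void) {
        const int nrows_y = 1; // value now passed when src1 is absent
        for (int rowx = 0; rowx < 4; ++rowx) {
            // always 0: every row maps to the single (unused) mask row,
            // and the null mask pointer itself is never read
            printf("rowx=%d -> rowy=%d\n", rowx, rowx % nrows_y);
        }
        // with the old nrows_y == 0, rowx % nrows_y would divide by zero (UB)
        return 0;
    }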