Commit ebd062b

cuda : use 512 threads for soft_max instead of 32
1 parent 580fe20 commit ebd062b
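
The change replaces the warp-level soft max reduction with a block-wide one. Previously the kernel assigned one warp (32 threads, indexed through threadIdx.y) to each row and reduced with __shfl_xor_sync butterfly shuffles, which caps per-row parallelism at a single warp. The kernel now assigns one block of up to CUDA_SOFT_MAX_BLOCK_SIZE (512) threads per row and carries out both reductions, the row maximum and the sum of exponentials, through a shared-memory buffer.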

1 file changed: ggml-cuda.cu (+34 -17)

@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 512
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256

@@ -4717,45 +4718,59 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }

-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
+// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;

-    float max_val = -INFINITY;
+    const int block_size = blockDim.x;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
+
+    buf[tid] = -INFINITY;

     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
     }

+    __syncthreads();
+
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] = max(buf[tid], buf[tid + i]);
+        }
+        __syncthreads();
     }

     float tmp = 0.f;

     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
         tmp += val;
         dst[ix] = val;
     }

+    __syncthreads();
+
+    buf[tid] = tmp;
+
+    __syncthreads();
+
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] += buf[tid + i];
+        }
+        __syncthreads();
     }

-    const float inv_tmp = 1.f / tmp;
+    const float inv_tmp = 1.f / buf[0];

     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;

@@ -5796,7 +5811,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
 }

 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
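
The launcher now rounds the per-row thread count up in powers of two: nth starts at WARP_SIZE (32) and doubles while it is below both ncols_x and CUDA_SOFT_MAX_BLOCK_SIZE. A row of 100 columns therefore gets 128 threads, and any row of 512 or more columns gets the full 512.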

@@ -6853,7 +6870,7 @@ inline void ggml_cuda_op_soft_max(

     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;

     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
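
The last hunk changes the no-mask default for nrows_y from 0 to 1, presumably because the kernel computes rowx % nrows_y unconditionally and a zero divisor would be undefined behavior.

For reference, below is a minimal standalone sketch, not part of the commit, of the shared-memory tree reduction the kernel now uses, isolated to computing a single row maximum. The kernel name row_max_f32, the BLOCK_SIZE macro, and the host harness are illustrative; like soft_max_f32, it assumes the block size is a power of two no larger than the shared buffer, which is what soft_max_f32_cuda guarantees by doubling nth from WARP_SIZE.

// Illustrative sketch, not from the commit: the same block-wide
// shared-memory tree reduction as soft_max_f32, reduced to computing
// one row maximum. Assumes blockDim.x is a power of two <= BLOCK_SIZE.
#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512 // mirrors CUDA_SOFT_MAX_BLOCK_SIZE

static __global__ void row_max_f32(const float * x, float * dst, const int ncols) {
    const int tid = threadIdx.x;
    const int row = blockIdx.x;

    __shared__ float buf[BLOCK_SIZE];

    // each thread folds a strided slice of its row into its own slot
    buf[tid] = -INFINITY;
    for (int col = tid; col < ncols; col += blockDim.x) {
        buf[tid] = max(buf[tid], x[row*ncols + col]);
    }
    __syncthreads();

    // tree reduction: halve the active range until buf[0] holds the row max
    for (int i = blockDim.x/2; i > 0; i >>= 1) {
        if (tid < i) {
            buf[tid] = max(buf[tid], buf[tid + i]);
        }
        __syncthreads();
    }

    if (tid == 0) {
        dst[row] = buf[0];
    }
}

int main() {
    const int ncols = 1000;
    float h[ncols];
    for (int i = 0; i < ncols; i++) {
        h[i] = (float)(i % 97);
    }

    float * d_x = nullptr;
    float * d_max = nullptr;
    cudaMalloc((void **) &d_x, ncols*sizeof(float));
    cudaMalloc((void **) &d_max, sizeof(float));
    cudaMemcpy(d_x, h, ncols*sizeof(float), cudaMemcpyHostToDevice);

    // one block per row, as in soft_max_f32_cuda (here: a single row)
    row_max_f32<<<1, BLOCK_SIZE>>>(d_x, d_max, ncols);

    float result = 0.0f;
    cudaMemcpy(&result, d_max, sizeof(float), cudaMemcpyDeviceToHost);
    printf("row max = %.1f (expected 96.0)\n", result);

    cudaFree(d_x);
    cudaFree(d_max);
    return 0;
}

Compared with the old __shfl_xor_sync butterfly, each reduction step costs a __syncthreads(), but the reduction is no longer limited to one warp, which is what allows the 512-thread launch; the new TODO comment in the kernel notes that warp-based primitives might still improve this.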
