Commit 62532c0

cuda : do warp-based block reduce
1 parent c7c8dab commit 62532c0

File tree

1 file changed: +48 additions, -39 deletions

ggml-cuda.cu

Lines changed: 48 additions & 39 deletions
@@ -502,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -578,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
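Note: the float2 overload deleted here is the same one re-added near the top of the file in the first hunk. It packs two running sums into one value, so a kernel in the style of norm_f32 can reduce, for example, a sum and a sum of squares in a single pass and derive mean and variance from them. A rough sketch of that usage, assuming one warp (32 threads) per row and reusing the warp_reduce_sum(float2) helper shown above; row_mean_var is a hypothetical name, not code from this file:

// Hypothetical device helper: two accumulators reduced at once with warp_reduce_sum(float2).
static __device__ void row_mean_var(const float * row, const int ncols, float * mean, float * var) {
    float2 partial = make_float2(0.0f, 0.0f);
    for (int col = threadIdx.x; col < ncols; col += 32) { // assumes blockDim.x == 32 (one warp)
        const float xi = row[col];
        partial.x += xi;        // running sum
        partial.y += xi * xi;   // running sum of squares
    }
    partial = warp_reduce_sum(partial);            // both components reduced in one set of shuffles
    *mean = partial.x / ncols;
    *var  = partial.y / ncols - (*mean) * (*mean); // E[x^2] - E[x]^2
}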
@@ -625,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
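With both local copies removed here and the helpers defined once near the top of the file, the scalar reduction is easy to check in isolation. Below is a minimal standalone sketch, not part of this commit, that runs warp_reduce_sum on a single warp; the kernel name test_warp_reduce_sum and the host driver are made up for illustration. Each lane contributes its lane id, so the expected result is 0 + 1 + ... + 31 = 496.

// Hypothetical test harness for the XOR-shuffle warp reduction.
#include <cstdio>
#include <cuda_runtime.h>

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

static __global__ void test_warp_reduce_sum(float * out) {
    const float sum = warp_reduce_sum((float) threadIdx.x); // each lane feeds its lane id
    if (threadIdx.x == 0) {
        *out = sum; // after the butterfly exchange every lane holds the full total
    }
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float));
    test_warp_reduce_sum<<<1, 32>>>(d_out); // exactly one warp
    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost); // implicit sync
    printf("warp_reduce_sum over lane ids: %.1f (expected 496.0)\n", h_out);
    cudaFree(d_out);
    return 0;
}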
@@ -4718,59 +4726,60 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
     const int block_size = blockDim.x;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
-
-    buf[tid] = -INFINITY;
+    float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
-    __syncthreads();
-
     // find the max value in the block
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] = max(buf[tid], buf[tid + i]);
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
         }
         __syncthreads();
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
         dst[ix] = val;
     }
 
-    __syncthreads();
-
-    buf[tid] = tmp;
-
-    __syncthreads();
-
-    // sum up partial sums
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] += buf[tid + i];
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
         }
         __syncthreads();
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
-    const float inv_tmp = 1.f / buf[0];
+    const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
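The pattern inlined twice above is a two-stage block reduction: each warp first reduces its own values with shuffles, lane 0 of every warp writes its partial result to shared memory, and after one __syncthreads() the per-warp partials are reduced by one more warp-level pass. Compared with the old tree reduction it uses a WARP_SIZE-times smaller shared buffer (CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE instead of CUDA_SOFT_MAX_BLOCK_SIZE floats) and a single barrier per reduction instead of one per tree step. Below is a hedged sketch of the same pattern factored into a reusable helper; the name block_reduce_sum and the bounds guard on buf are additions for illustration, and WARP_SIZE = 32, CUDA_SOFT_MAX_BLOCK_SIZE = 1024 are assumed to match the constants in ggml-cuda.cu.

#define WARP_SIZE 32
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

// Hypothetical helper: reduce one float per thread to a block-wide sum.
// Assumes blockDim.x is a multiple of WARP_SIZE and at most CUDA_SOFT_MAX_BLOCK_SIZE.
static __device__ float block_reduce_sum(float x) {
    x = warp_reduce_sum(x); // stage 1: reduce within each warp
    if (blockDim.x > WARP_SIZE) {
        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        const int n_warps = blockDim.x / WARP_SIZE;
        if (lane_id == 0) {
            buf[warp_id] = x; // one partial sum per warp
        }
        __syncthreads();
        // stage 2: every warp re-reduces the per-warp partials; lanes past the
        // last warp contribute a neutral 0.0f (guard added for this sketch)
        x = lane_id < n_warps ? buf[lane_id] : 0.0f;
        x = warp_reduce_sum(x);
    }
    return x; // every thread in the block now holds the same total
}

The same shape works for the max reduction by swapping warp_reduce_sum for warp_reduce_max and using -INFINITY as the neutral value.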
