Skip to content

Commit 20d7740

Browse files
authored
ggml : sync (abort callback, mul / add broadcast, fix alibi) (#2183)
1 parent 5bf2a27 commit 20d7740

File tree

5 files changed

+173
-72
lines changed

5 files changed

+173
-72
lines changed

ggml-cuda.cu

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
239239
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
240240
};
241241

242-
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
242+
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
243243
const int i = blockDim.x*blockIdx.x + threadIdx.x;
244244

245-
if (i >= k) {
245+
if (i >= kx) {
246246
return;
247247
}
248-
dst[i] = x[i] + y[i];
248+
dst[i] = x[i] + y[i%ky];
249249
}
250250

251251
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
275275
dst[i] = x[i] / (1.0f + expf(-x[i]));
276276
}
277277

278+
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
279+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
280+
const int tid = threadIdx.x;
281+
282+
const float eps = 1e-5f;
283+
284+
float mean = 0.0f;
285+
float var = 0.0f;
286+
287+
for (int col = tid; col < ncols; col += WARP_SIZE) {
288+
const float xi = x[row*ncols + col];
289+
mean += xi;
290+
var += xi * xi;
291+
}
292+
293+
// sum up partial sums
294+
#pragma unroll
295+
for (int mask = 16; mask > 0; mask >>= 1) {
296+
mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
297+
var += __shfl_xor_sync(0xffffffff, var, mask, 32);
298+
}
299+
300+
mean /= ncols;
301+
var = var / ncols - mean * mean;
302+
const float inv_var = rsqrtf(var + eps);
303+
304+
for (int col = tid; col < ncols; col += WARP_SIZE) {
305+
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
306+
}
307+
}
308+
278309
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
279310
const int row = blockIdx.x*blockDim.y + threadIdx.y;
280311
const int tid = threadIdx.x;
281312

282-
const float eps = 1e-6;
313+
const float eps = 1e-6f;
283314

284315
float tmp = 0.0f; // partial sum for thread in warp
285316

286-
for (int i = 0; i < ncols; i += WARP_SIZE) {
287-
const int col = i + tid;
317+
for (int col = tid; col < ncols; col += WARP_SIZE) {
288318
const float xi = x[row*ncols + col];
289319
tmp += xi * xi;
290320
}
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
296326
}
297327

298328
const float mean = tmp / ncols;
299-
const float scale = 1.0f / sqrtf(mean + eps);
329+
const float scale = rsqrtf(mean + eps);
300330

301-
for (int i = 0; i < ncols; i += WARP_SIZE) {
302-
const int col = i + tid;
331+
for (int col = tid; col < ncols; col += WARP_SIZE) {
303332
dst[row*ncols + col] = scale * x[row*ncols + col];
304333
}
305334
}
@@ -1689,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
16891718
dst[i] = scale * x[i];
16901719
}
16911720

1692-
static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
1693-
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
1694-
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
1721+
static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
1722+
const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
1723+
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
16951724
}
16961725

16971726
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
17091738
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
17101739
}
17111740

1741+
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1742+
GGML_ASSERT(ncols % WARP_SIZE == 0);
1743+
const dim3 block_dims(WARP_SIZE, 1, 1);
1744+
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
1745+
}
1746+
17121747
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
17131748
GGML_ASSERT(ncols % WARP_SIZE == 0);
17141749
const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -2239,14 +2274,16 @@ inline void ggml_cuda_op_add(
22392274
GGML_ASSERT(src1_ddf_i != nullptr);
22402275
GGML_ASSERT(dst_ddf_i != nullptr);
22412276

2242-
const int64_t ne0 = src0->ne[0];
2277+
const int64_t ne00 = src0->ne[0];
22432278
const int64_t i01_diff = i01_high - i01_low;
22442279

2280+
const int64_t ne10 = src1->ne[0];
2281+
22452282
// compute
22462283
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2247-
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
2284+
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
22482285
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
2249-
add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
2286+
add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
22502287
} else {
22512288
GGML_ASSERT(false);
22522289
}
@@ -2268,20 +2305,11 @@ inline void ggml_cuda_op_mul(
22682305
GGML_ASSERT(dst_ddf_i != nullptr);
22692306

22702307
const int64_t ne00 = src0->ne[0];
2308+
const int64_t i01_diff = i01_high - i01_low;
22712309

22722310
const int64_t ne10 = src1->ne[0];
2273-
const int64_t ne11 = src1->ne[1];
2274-
2275-
for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
2276-
const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
22772311

2278-
float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
2279-
float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
2280-
float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
2281-
2282-
// compute
2283-
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
2284-
}
2312+
mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
22852313

22862314
(void) dst;
22872315
(void) src0_ddq_i;
@@ -2310,6 +2338,28 @@ inline void ggml_cuda_op_silu(
23102338
(void) i1;
23112339
}
23122340

2341+
inline void ggml_cuda_op_norm(
2342+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
2343+
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
2344+
cudaStream_t & cudaStream_main){
2345+
2346+
GGML_ASSERT(src0_ddf_i != nullptr);
2347+
GGML_ASSERT(dst_ddf_i != nullptr);
2348+
2349+
const int64_t ne00 = src0->ne[0];
2350+
const int64_t i01_diff = i01_high - i01_low;
2351+
2352+
// compute
2353+
norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
2354+
2355+
(void) src1;
2356+
(void) dst;
2357+
(void) src0_ddq_i;
2358+
(void) src1_ddf_i;
2359+
(void) i02;
2360+
(void) i1;
2361+
}
2362+
23132363
inline void ggml_cuda_op_rms_norm(
23142364
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
23152365
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2930,6 +2980,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
29302980
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
29312981
}
29322982

2983+
void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2984+
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
2985+
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
2986+
}
2987+
29332988
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
29342989
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
29352990
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3215,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
31603215
}
31613216

31623217

3163-
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
3218+
CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
31643219

31653220
extra->data_device[id] = buf;
31663221

@@ -3322,6 +3377,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
33223377
}
33233378
func = ggml_cuda_silu;
33243379
break;
3380+
case GGML_OP_NORM:
3381+
if (!any_on_device) {
3382+
return false;
3383+
}
3384+
func = ggml_cuda_norm;
3385+
break;
33253386
case GGML_OP_RMS_NORM:
33263387
if (!any_on_device) {
33273388
return false;

0 commit comments

Comments
 (0)