@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
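
The extra ky parameter turns the plain element-wise add into a modulo broadcast: y is read as a buffer of ky floats that repeats along x (kx elements), so a single row of src1 can be added to every row of src0. A minimal CPU sketch of the same semantics (names are illustrative, not part of the patch):

    // y (ky elements) is tiled across x (kx elements), mirroring add_f32 above
    static void add_f32_ref(const float * x, const float * y, float * dst, const int kx, const int ky) {
        for (int i = 0; i < kx; ++i) {
            dst[i] = x[i] + y[i % ky];
        }
    }
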
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const float eps = 1e-5f;
+
+    float mean = 0.0f;
+    float var = 0.0f;
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        const float xi = x[row*ncols + col];
+        mean += xi;
+        var += xi * xi;
+    }
+
+    // sum up partial sums
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    }
+
+    mean /= ncols;
+    var = var / ncols - mean * mean;
+    const float inv_var = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    }
+}
+
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6;
+    const float eps = 1e-6f;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
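
The reduction above is a warp-level butterfly: each __shfl_xor_sync step folds together lanes whose indices differ in one bit of the mask, so after the five steps (16, 8, 4, 2, 1) every lane of the 32-lane warp holds the full row sums, and var/ncols - mean*mean is just E[x^2] - E[x]^2. The same pattern as a standalone helper, shown here only for reference (not part of the patch):

    // Butterfly sum across one 32-lane warp; after the loop every lane holds the total.
    static __device__ float warp_reduce_sum(float val) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            val += __shfl_xor_sync(0xffffffff, val, mask, 32);
        }
        return val;
    }
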
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 
     const float mean = tmp / ncols;
-    const float scale = 1.0f / sqrtf(mean + eps);
+    const float scale = rsqrtf(mean + eps);
 
-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
@@ -1689,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
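
Both launchers use the same shape: a grid of nrows blocks, each a single 32-lane warp, so blockIdx.x selects the row while the lanes stride over the columns (ncols is asserted to be a multiple of WARP_SIZE). A hypothetical call for an 8 x 4096 activation buffer, assuming device pointers d_x, d_dst and a stream s:

    // 8 rows of 4096 floats -> 8 blocks of one warp each
    norm_f32_cuda(d_x, d_dst, /*ncols=*/4096, /*nrows=*/8, s);
    // equivalent to: norm_f32<<<8, dim3(WARP_SIZE, 1, 1), 0, s>>>(d_x, d_dst, 4096);
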
@@ -2239,14 +2274,16 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
-    const int64_t ne0 = src0->ne[0];
+    const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    const int64_t ne10 = src1->ne[0];
+
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
         GGML_ASSERT(false);
     }
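
With the row slab passed in a single call, the kernel covers all of the rows at once. For hypothetical shapes (not taken from the patch) of ne00 = 4096 columns, i01_diff = 32 rows, and a single-row src1 with ne10 = 4096, the F32 branch amounts to:

    // kx = 4096*32 = 131072 output elements; the same 4096 src1 values are re-read for every row
    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, /*kx=*/4096*32, /*ky=*/4096, cudaStream_main);
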
@@ -2268,20 +2305,11 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
 
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2310,6 +2338,28 @@ inline void ggml_cuda_op_silu(
     (void) i1;
 }
 
+inline void ggml_cuda_op_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2930,6 +2980,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
 }
 
+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3215,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }
 
 
-        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
 
         extra->data_device[id] = buf;
 
@@ -3322,6 +3377,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_silu;
             break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_norm;
+            break;
         case GGML_OP_RMS_NORM:
             if (!any_on_device) {
                 return false;