Skip to content

Commit 835d702

Browse files
committed
Fix more int overflow during quant.
1 parent ea1aeba commit 835d702

File tree

2 files changed

+85
-85
lines changed

2 files changed

+85
-85
lines changed

ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12251225

12261226
// the main device has a larger memory buffer to hold the results from all GPUs
12271227
// ldc == nrows of the matrix that cuBLAS writes into
1228-
int64_t ldc = id == ctx.device ? ne0 : row_diff;
1228+
int ldc = id == ctx.device ? ne0 : row_diff;
12291229

12301230
const int compute_capability = ggml_cuda_info().devices[id].cc;
12311231

ggml-cuda/convert.cu

Lines changed: 84 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55

66
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
77
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
8-
const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
8+
const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
99

1010
if (i >= k) {
1111
return;
1212
}
1313

1414
const int64_t ib = i/qk; // block index
15-
const int iqs = (i%qk)/qr; // quant index
16-
const int iybs = i - i%qk; // y block start index
17-
const int y_offset = qr == 1 ? 1 : qk/2;
15+
const int64_t iqs = (i%qk)/qr; // quant index
16+
const int64_t iybs = i - i%qk; // y block start index
17+
const int64_t y_offset = qr == 1 ? 1 : qk/2;
1818

1919
// dequantize
2020
dfloat2 v;
@@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
2929
#if __CUDA_ARCH__ >= CC_PASCAL
3030
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
3131

32-
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
32+
const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
3333
const int * x0 = ((int *) vx) + blockIdx.x * nint;
3434
half2 * y2 = (half2 *) (y + i0);
3535

@@ -71,9 +71,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
7171
const int64_t i = blockIdx.x;
7272

7373
// assume 32 threads
74-
const int tid = threadIdx.x;
75-
const int il = tid/8;
76-
const int ir = tid%8;
74+
const int64_t tid = threadIdx.x;
75+
const int64_t il = tid/8;
76+
const int64_t ir = tid%8;
7777
const int64_t ib = 8*i + ir;
7878
if (ib >= nb32) {
7979
return;
@@ -99,9 +99,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
9999
const int64_t i = blockIdx.x;
100100

101101
// assume 32 threads
102-
const int tid = threadIdx.x;
103-
const int il = tid/8;
104-
const int ir = tid%8;
102+
const int64_t tid = threadIdx.x;
103+
const int64_t il = tid/8;
104+
const int64_t ir = tid%8;
105105
const int64_t ib = 8*i + ir;
106106
if (ib >= nb32) {
107107
return;
@@ -125,14 +125,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
125125
template<typename dst_t>
126126
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
127127

128-
const int i = blockIdx.x;
128+
const int64_t i = blockIdx.x;
129129
const block_q2_K * x = (const block_q2_K *) vx;
130130

131-
const int tid = threadIdx.x;
131+
const int64_t tid = threadIdx.x;
132132
#if QK_K == 256
133-
const int n = tid/32;
134-
const int l = tid - 32*n;
135-
const int is = 8*n + l/16;
133+
const int64_t n = tid/32;
134+
const int64_t l = tid - 32*n;
135+
const int64_t is = 8*n + l/16;
136136

137137
const uint8_t q = x[i].qs[32*n + l];
138138
dst_t * y = yy + i*QK_K + 128*n;
@@ -144,8 +144,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
144144
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
145145
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
146146
#else
147-
const int is = tid/16; // 0 or 1
148-
const int il = tid%16; // 0...15
147+
const int64_t is = tid/16; // 0 or 1
148+
const int64_t il = tid%16; // 0...15
149149
const uint8_t q = x[i].qs[il] >> (2*is);
150150
dst_t * y = yy + i*QK_K + 16*is + il;
151151
float dall = __low2half(x[i].dm);
@@ -159,19 +159,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
159159
template<typename dst_t>
160160
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
161161

162-
const int i = blockIdx.x;
162+
const int64_t i = blockIdx.x;
163163
const block_q3_K * x = (const block_q3_K *) vx;
164164

165165
#if QK_K == 256
166-
const int r = threadIdx.x/4;
167-
const int tid = r/2;
168-
const int is0 = r%2;
169-
const int l0 = 16*is0 + 4*(threadIdx.x%4);
170-
const int n = tid / 4;
171-
const int j = tid - 4*n;
166+
const int64_t r = threadIdx.x/4;
167+
const int64_t tid = r/2;
168+
const int64_t is0 = r%2;
169+
const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
170+
const int64_t n = tid / 4;
171+
const int64_t j = tid - 4*n;
172172

173173
uint8_t m = 1 << (4*n + j);
174-
int is = 8*n + 2*j + is0;
174+
int64_t is = 8*n + 2*j + is0;
175175
int shift = 2*j;
176176

177177
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@@ -187,11 +187,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
187187

188188
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
189189
#else
190-
const int tid = threadIdx.x;
191-
const int is = tid/16; // 0 or 1
192-
const int il = tid%16; // 0...15
193-
const int im = il/8; // 0...1
194-
const int in = il%8; // 0...7
190+
const int64_t tid = threadIdx.x;
191+
const int64_t is = tid/16; // 0 or 1
192+
const int64_t il = tid%16; // 0...15
193+
const int64_t im = il/8; // 0...1
194+
const int64_t in = il%8; // 0...7
195195

196196
dst_t * y = yy + i*QK_K + 16*is + il;
197197

@@ -225,15 +225,15 @@ template<typename dst_t>
225225
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
226226
const block_q4_K * x = (const block_q4_K *) vx;
227227

228-
const int i = blockIdx.x;
228+
const int64_t i = blockIdx.x;
229229

230230
#if QK_K == 256
231231
// assume 32 threads
232-
const int tid = threadIdx.x;
233-
const int il = tid/8;
234-
const int ir = tid%8;
235-
const int is = 2*il;
236-
const int n = 4;
232+
const int64_t tid = threadIdx.x;
233+
const int64_t il = tid/8;
234+
const int64_t ir = tid%8;
235+
const int64_t is = 2*il;
236+
const int64_t n = 4;
237237

238238
dst_t * y = yy + i*QK_K + 64*il + n*ir;
239239

@@ -252,7 +252,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
252252
y[l +32] = d2 * (q[l] >> 4) - m2;
253253
}
254254
#else
255-
const int tid = threadIdx.x;
255+
const int64_t tid = threadIdx.x;
256256
const uint8_t * q = x[i].qs;
257257
dst_t * y = yy + i*QK_K;
258258
const float d = (float)x[i].dm[0];
@@ -266,14 +266,14 @@ template<typename dst_t>
266266
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
267267
const block_q5_K * x = (const block_q5_K *) vx;
268268

269-
const int i = blockIdx.x;
269+
const int64_t i = blockIdx.x;
270270

271271
#if QK_K == 256
272272
// assume 64 threads - this is very slightly better than the one below
273-
const int tid = threadIdx.x;
274-
const int il = tid/16; // il is in 0...3
275-
const int ir = tid%16; // ir is in 0...15
276-
const int is = 2*il; // is is in 0...6
273+
const int64_t tid = threadIdx.x;
274+
const int64_t il = tid/16; // il is in 0...3
275+
const int64_t ir = tid%16; // ir is in 0...15
276+
const int64_t is = 2*il; // is is in 0...6
277277

278278
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
279279

@@ -296,11 +296,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
296296
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
297297
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
298298
#else
299-
const int tid = threadIdx.x;
299+
const int64_t tid = threadIdx.x;
300300
const uint8_t q = x[i].qs[tid];
301-
const int im = tid/8; // 0...3
302-
const int in = tid%8; // 0...7
303-
const int is = tid/16; // 0 or 1
301+
const int64_t im = tid/8; // 0...3
302+
const int64_t in = tid%8; // 0...7
303+
const int64_t is = tid/16; // 0 or 1
304304
const uint8_t h = x[i].qh[in] >> im;
305305
const float d = x[i].d;
306306
dst_t * y = yy + i*QK_K + tid;
@@ -357,13 +357,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
357357
template<typename dst_t>
358358
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
359359

360-
const int i = blockIdx.x;
360+
const int64_t i = blockIdx.x;
361361
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
362362

363-
const int tid = threadIdx.x;
363+
const int64_t tid = threadIdx.x;
364364
#if QK_K == 256
365-
const int il = tid/8; // 0...3
366-
const int ib = tid%8; // 0...7
365+
const int64_t il = tid/8; // 0...3
366+
const int64_t ib = tid%8; // 0...7
367367
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
368368
const uint16_t * q2 = x[i].qs + 4*ib;
369369
const uint8_t * aux8 = (const uint8_t *)q2;
@@ -381,13 +381,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
381381
template<typename dst_t>
382382
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
383383

384-
const int i = blockIdx.x;
384+
const int64_t i = blockIdx.x;
385385
const block_iq2_xs * x = (const block_iq2_xs *) vx;
386386

387-
const int tid = threadIdx.x;
387+
const int64_t tid = threadIdx.x;
388388
#if QK_K == 256
389-
const int il = tid/8; // 0...3
390-
const int ib = tid%8; // 0...7
389+
const int64_t il = tid/8; // 0...3
390+
const int64_t ib = tid%8; // 0...7
391391
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
392392
const uint16_t * q2 = x[i].qs + 4*ib;
393393
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@@ -403,13 +403,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
403403
template<typename dst_t>
404404
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
405405

406-
const int i = blockIdx.x;
406+
const int64_t i = blockIdx.x;
407407
const block_iq2_s * x = (const block_iq2_s *) vx;
408408

409-
const int tid = threadIdx.x;
409+
const int64_t tid = threadIdx.x;
410410
#if QK_K == 256
411-
const int il = tid/8; // 0...3
412-
const int ib = tid%8; // 0...7
411+
const int64_t il = tid/8; // 0...3
412+
const int64_t ib = tid%8; // 0...7
413413
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
414414
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
415415
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@@ -424,13 +424,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
424424
template<typename dst_t>
425425
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
426426

427-
const int i = blockIdx.x;
427+
const int64_t i = blockIdx.x;
428428
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
429429

430-
const int tid = threadIdx.x;
430+
const int64_t tid = threadIdx.x;
431431
#if QK_K == 256
432-
const int il = tid/8; // 0...3
433-
const int ib = tid%8; // 0...7
432+
const int64_t il = tid/8; // 0...3
433+
const int64_t ib = tid%8; // 0...7
434434
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
435435
const uint8_t * q3 = x[i].qs + 8*ib;
436436
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@@ -452,13 +452,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
452452
template<typename dst_t>
453453
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
454454

455-
const int i = blockIdx.x;
455+
const int64_t i = blockIdx.x;
456456
const block_iq3_s * x = (const block_iq3_s *) vx;
457457

458-
const int tid = threadIdx.x;
458+
const int64_t tid = threadIdx.x;
459459
#if QK_K == 256
460-
const int il = tid/8; // 0...3
461-
const int ib = tid%8; // 0...7
460+
const int64_t il = tid/8; // 0...3
461+
const int64_t ib = tid%8; // 0...7
462462
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
463463
const uint8_t * qs = x[i].qs + 8*ib;
464464
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@@ -478,13 +478,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
478478
template<typename dst_t>
479479
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
480480

481-
const int i = blockIdx.x;
481+
const int64_t i = blockIdx.x;
482482
const block_iq1_s * x = (const block_iq1_s *) vx;
483483

484-
const int tid = threadIdx.x;
484+
const int64_t tid = threadIdx.x;
485485
#if QK_K == 256
486-
const int il = tid/8; // 0...3
487-
const int ib = tid%8; // 0...7
486+
const int64_t il = tid/8; // 0...3
487+
const int64_t ib = tid%8; // 0...7
488488
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
489489
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
490490
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@@ -504,18 +504,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
504504
template<typename dst_t>
505505
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
506506

507-
const int i = blockIdx.x;
507+
const int64_t i = blockIdx.x;
508508
const block_iq1_m * x = (const block_iq1_m *) vx;
509509

510-
const int tid = threadIdx.x;
510+
const int64_t tid = threadIdx.x;
511511
#if QK_K == 256
512-
const int il = tid/8; // 0...3
513-
const int ib = tid%8; // 0...7
512+
const int64_t il = tid/8; // 0...3
513+
const int64_t ib = tid%8; // 0...7
514514
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
515515
const uint16_t * sc = (const uint16_t *)x[i].scales;
516516
iq1m_scale_t scale;
517517
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
518-
const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
518+
const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
519519
const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
520520
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
521521
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
@@ -535,12 +535,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
535535
template<typename dst_t>
536536
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
537537

538-
const int i = blockIdx.x;
538+
const int64_t i = blockIdx.x;
539539
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
540540

541-
const int tid = threadIdx.x;
542-
const int il = tid/8; // 0...3
543-
const int ib = tid%8; // 0...7
541+
const int64_t tid = threadIdx.x;
542+
const int64_t il = tid/8; // 0...3
543+
const int64_t ib = tid%8; // 0...7
544544
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
545545
const uint8_t * q4 = x[ib].qs + 4*il;
546546
const float d = (float)x[ib].d;
@@ -554,12 +554,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
554554
#if QK_K != 64
555555
template<typename dst_t>
556556
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
557-
const int i = blockIdx.x;
557+
const int64_t i = blockIdx.x;
558558
const block_iq4_xs * x = (const block_iq4_xs *)vx;
559559

560-
const int tid = threadIdx.x;
561-
const int il = tid/8; // 0...3
562-
const int ib = tid%8; // 0...7
560+
const int64_t tid = threadIdx.x;
561+
const int64_t il = tid/8; // 0...3
562+
const int64_t ib = tid%8; // 0...7
563563
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
564564
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
565565
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);

0 commit comments

Comments
 (0)