5
5
6
6
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
7
7
static __global__ void dequantize_block (const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
8
- const int64_t i = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
8
+ const int64_t i = ( int64_t ) 2 *(blockDim .x *blockIdx .x + threadIdx .x );
9
9
10
10
if (i >= k) {
11
11
return ;
12
12
}
13
13
14
14
const int64_t ib = i/qk; // block index
15
- const int iqs = (i%qk)/qr; // quant index
16
- const int iybs = i - i%qk; // y block start index
17
- const int y_offset = qr == 1 ? 1 : qk/2 ;
15
+ const int64_t iqs = (i%qk)/qr; // quant index
16
+ const int64_t iybs = i - i%qk; // y block start index
17
+ const int64_t y_offset = qr == 1 ? 1 : qk/2 ;
18
18
19
19
// dequantize
20
20
dfloat2 v;
@@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
29
29
#if __CUDA_ARCH__ >= CC_PASCAL
30
30
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof (int ) + WARP_SIZE;
31
31
32
- const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx .x ;
32
+ const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx .x ;
33
33
const int * x0 = ((int *) vx) + blockIdx .x * nint;
34
34
half2 * y2 = (half2 *) (y + i0);
35
35
@@ -71,9 +71,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
71
71
const int64_t i = blockIdx .x ;
72
72
73
73
// assume 32 threads
74
- const int tid = threadIdx .x ;
75
- const int il = tid/8 ;
76
- const int ir = tid%8 ;
74
+ const int64_t tid = threadIdx .x ;
75
+ const int64_t il = tid/8 ;
76
+ const int64_t ir = tid%8 ;
77
77
const int64_t ib = 8 *i + ir;
78
78
if (ib >= nb32) {
79
79
return ;
@@ -99,9 +99,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
99
99
const int64_t i = blockIdx .x ;
100
100
101
101
// assume 32 threads
102
- const int tid = threadIdx .x ;
103
- const int il = tid/8 ;
104
- const int ir = tid%8 ;
102
+ const int64_t tid = threadIdx .x ;
103
+ const int64_t il = tid/8 ;
104
+ const int64_t ir = tid%8 ;
105
105
const int64_t ib = 8 *i + ir;
106
106
if (ib >= nb32) {
107
107
return ;
@@ -125,14 +125,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
125
125
template <typename dst_t >
126
126
static __global__ void dequantize_block_q2_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
127
127
128
- const int i = blockIdx .x ;
128
+ const int64_t i = blockIdx .x ;
129
129
const block_q2_K * x = (const block_q2_K *) vx;
130
130
131
- const int tid = threadIdx .x ;
131
+ const int64_t tid = threadIdx .x ;
132
132
#if QK_K == 256
133
- const int n = tid/32 ;
134
- const int l = tid - 32 *n;
135
- const int is = 8 *n + l/16 ;
133
+ const int64_t n = tid/32 ;
134
+ const int64_t l = tid - 32 *n;
135
+ const int64_t is = 8 *n + l/16 ;
136
136
137
137
const uint8_t q = x[i].qs [32 *n + l];
138
138
dst_t * y = yy + i*QK_K + 128 *n;
@@ -144,8 +144,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
144
144
y[l+64 ] = dall * (x[i].scales [is+4 ] & 0xF ) * ((q >> 4 ) & 3 ) - dmin * (x[i].scales [is+4 ] >> 4 );
145
145
y[l+96 ] = dall * (x[i].scales [is+6 ] & 0xF ) * ((q >> 6 ) & 3 ) - dmin * (x[i].scales [is+6 ] >> 4 );
146
146
#else
147
- const int is = tid/16 ; // 0 or 1
148
- const int il = tid%16 ; // 0...15
147
+ const int64_t is = tid/16 ; // 0 or 1
148
+ const int64_t il = tid%16 ; // 0...15
149
149
const uint8_t q = x[i].qs [il] >> (2 *is);
150
150
dst_t * y = yy + i*QK_K + 16 *is + il;
151
151
float dall = __low2half (x[i].dm );
@@ -159,19 +159,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
159
159
template <typename dst_t >
160
160
static __global__ void dequantize_block_q3_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
161
161
162
- const int i = blockIdx .x ;
162
+ const int64_t i = blockIdx .x ;
163
163
const block_q3_K * x = (const block_q3_K *) vx;
164
164
165
165
#if QK_K == 256
166
- const int r = threadIdx .x /4 ;
167
- const int tid = r/2 ;
168
- const int is0 = r%2 ;
169
- const int l0 = 16 *is0 + 4 *(threadIdx .x %4 );
170
- const int n = tid / 4 ;
171
- const int j = tid - 4 *n;
166
+ const int64_t r = threadIdx .x /4 ;
167
+ const int64_t tid = r/2 ;
168
+ const int64_t is0 = r%2 ;
169
+ const int64_t l0 = 16 *is0 + 4 *(threadIdx .x %4 );
170
+ const int64_t n = tid / 4 ;
171
+ const int64_t j = tid - 4 *n;
172
172
173
173
uint8_t m = 1 << (4 *n + j);
174
- int is = 8 *n + 2 *j + is0;
174
+ int64_t is = 8 *n + 2 *j + is0;
175
175
int shift = 2 *j;
176
176
177
177
int8_t us = is < 4 ? (x[i].scales [is-0 ] & 0xF ) | (((x[i].scales [is+8 ] >> 0 ) & 3 ) << 4 ) :
@@ -187,11 +187,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
187
187
188
188
for (int l = l0; l < l0+4 ; ++l) y[l] = dl * ((int8_t )((q[l] >> shift) & 3 ) - ((hm[l] & m) ? 0 : 4 ));
189
189
#else
190
- const int tid = threadIdx .x ;
191
- const int is = tid/16 ; // 0 or 1
192
- const int il = tid%16 ; // 0...15
193
- const int im = il/8 ; // 0...1
194
- const int in = il%8 ; // 0...7
190
+ const int64_t tid = threadIdx .x ;
191
+ const int64_t is = tid/16 ; // 0 or 1
192
+ const int64_t il = tid%16 ; // 0...15
193
+ const int64_t im = il/8 ; // 0...1
194
+ const int64_t in = il%8 ; // 0...7
195
195
196
196
dst_t * y = yy + i*QK_K + 16 *is + il;
197
197
@@ -225,15 +225,15 @@ template<typename dst_t>
225
225
static __global__ void dequantize_block_q4_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
226
226
const block_q4_K * x = (const block_q4_K *) vx;
227
227
228
- const int i = blockIdx .x ;
228
+ const int64_t i = blockIdx .x ;
229
229
230
230
#if QK_K == 256
231
231
// assume 32 threads
232
- const int tid = threadIdx .x ;
233
- const int il = tid/8 ;
234
- const int ir = tid%8 ;
235
- const int is = 2 *il;
236
- const int n = 4 ;
232
+ const int64_t tid = threadIdx .x ;
233
+ const int64_t il = tid/8 ;
234
+ const int64_t ir = tid%8 ;
235
+ const int64_t is = 2 *il;
236
+ const int64_t n = 4 ;
237
237
238
238
dst_t * y = yy + i*QK_K + 64 *il + n*ir;
239
239
@@ -252,7 +252,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
252
252
y[l +32 ] = d2 * (q[l] >> 4 ) - m2;
253
253
}
254
254
#else
255
- const int tid = threadIdx .x ;
255
+ const int64_t tid = threadIdx .x ;
256
256
const uint8_t * q = x[i].qs ;
257
257
dst_t * y = yy + i*QK_K;
258
258
const float d = (float )x[i].dm [0 ];
@@ -266,14 +266,14 @@ template<typename dst_t>
266
266
static __global__ void dequantize_block_q5_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
267
267
const block_q5_K * x = (const block_q5_K *) vx;
268
268
269
- const int i = blockIdx .x ;
269
+ const int64_t i = blockIdx .x ;
270
270
271
271
#if QK_K == 256
272
272
// assume 64 threads - this is very slightly better than the one below
273
- const int tid = threadIdx .x ;
274
- const int il = tid/16 ; // il is in 0...3
275
- const int ir = tid%16 ; // ir is in 0...15
276
- const int is = 2 *il; // is is in 0...6
273
+ const int64_t tid = threadIdx .x ;
274
+ const int64_t il = tid/16 ; // il is in 0...3
275
+ const int64_t ir = tid%16 ; // ir is in 0...15
276
+ const int64_t is = 2 *il; // is is in 0...6
277
277
278
278
dst_t * y = yy + i*QK_K + 64 *il + 2 *ir;
279
279
@@ -296,11 +296,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
296
296
y[32 ] = d2 * ((ql[ 0 ] >> 4 ) + (qh[ 0 ] & hm ? 16 : 0 )) - m2;
297
297
y[33 ] = d2 * ((ql[ 1 ] >> 4 ) + (qh[ 1 ] & hm ? 16 : 0 )) - m2;
298
298
#else
299
- const int tid = threadIdx .x ;
299
+ const int64_t tid = threadIdx .x ;
300
300
const uint8_t q = x[i].qs [tid];
301
- const int im = tid/8 ; // 0...3
302
- const int in = tid%8 ; // 0...7
303
- const int is = tid/16 ; // 0 or 1
301
+ const int64_t im = tid/8 ; // 0...3
302
+ const int64_t in = tid%8 ; // 0...7
303
+ const int64_t is = tid/16 ; // 0 or 1
304
304
const uint8_t h = x[i].qh [in] >> im;
305
305
const float d = x[i].d ;
306
306
dst_t * y = yy + i*QK_K + tid;
@@ -357,13 +357,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
357
357
template <typename dst_t >
358
358
static __global__ void dequantize_block_iq2_xxs (const void * __restrict__ vx, dst_t * __restrict__ yy) {
359
359
360
- const int i = blockIdx .x ;
360
+ const int64_t i = blockIdx .x ;
361
361
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
362
362
363
- const int tid = threadIdx .x ;
363
+ const int64_t tid = threadIdx .x ;
364
364
#if QK_K == 256
365
- const int il = tid/8 ; // 0...3
366
- const int ib = tid%8 ; // 0...7
365
+ const int64_t il = tid/8 ; // 0...3
366
+ const int64_t ib = tid%8 ; // 0...7
367
367
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
368
368
const uint16_t * q2 = x[i].qs + 4 *ib;
369
369
const uint8_t * aux8 = (const uint8_t *)q2;
@@ -381,13 +381,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
381
381
template <typename dst_t >
382
382
static __global__ void dequantize_block_iq2_xs (const void * __restrict__ vx, dst_t * __restrict__ yy) {
383
383
384
- const int i = blockIdx .x ;
384
+ const int64_t i = blockIdx .x ;
385
385
const block_iq2_xs * x = (const block_iq2_xs *) vx;
386
386
387
- const int tid = threadIdx .x ;
387
+ const int64_t tid = threadIdx .x ;
388
388
#if QK_K == 256
389
- const int il = tid/8 ; // 0...3
390
- const int ib = tid%8 ; // 0...7
389
+ const int64_t il = tid/8 ; // 0...3
390
+ const int64_t ib = tid%8 ; // 0...7
391
391
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
392
392
const uint16_t * q2 = x[i].qs + 4 *ib;
393
393
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511 ));
@@ -403,13 +403,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
403
403
template <typename dst_t >
404
404
static __global__ void dequantize_block_iq2_s (const void * __restrict__ vx, dst_t * __restrict__ yy) {
405
405
406
- const int i = blockIdx .x ;
406
+ const int64_t i = blockIdx .x ;
407
407
const block_iq2_s * x = (const block_iq2_s *) vx;
408
408
409
- const int tid = threadIdx .x ;
409
+ const int64_t tid = threadIdx .x ;
410
410
#if QK_K == 256
411
- const int il = tid/8 ; // 0...3
412
- const int ib = tid%8 ; // 0...7
411
+ const int64_t il = tid/8 ; // 0...3
412
+ const int64_t ib = tid%8 ; // 0...7
413
413
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
414
414
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs [4 *ib+il] | ((x[i].qh [ib] << (8 -2 *il)) & 0x300 )));
415
415
const float d = (float )x[i].d * (0 .5f + ((x[i].scales [ib] >> 4 *(il/2 )) & 0xf )) * 0 .25f ;
@@ -424,13 +424,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
424
424
template <typename dst_t >
425
425
static __global__ void dequantize_block_iq3_xxs (const void * __restrict__ vx, dst_t * __restrict__ yy) {
426
426
427
- const int i = blockIdx .x ;
427
+ const int64_t i = blockIdx .x ;
428
428
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
429
429
430
- const int tid = threadIdx .x ;
430
+ const int64_t tid = threadIdx .x ;
431
431
#if QK_K == 256
432
- const int il = tid/8 ; // 0...3
433
- const int ib = tid%8 ; // 0...7
432
+ const int64_t il = tid/8 ; // 0...3
433
+ const int64_t ib = tid%8 ; // 0...7
434
434
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
435
435
const uint8_t * q3 = x[i].qs + 8 *ib;
436
436
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4 ) + 2 *ib;
@@ -452,13 +452,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
452
452
template <typename dst_t >
453
453
static __global__ void dequantize_block_iq3_s (const void * __restrict__ vx, dst_t * __restrict__ yy) {
454
454
455
- const int i = blockIdx .x ;
455
+ const int64_t i = blockIdx .x ;
456
456
const block_iq3_s * x = (const block_iq3_s *) vx;
457
457
458
- const int tid = threadIdx .x ;
458
+ const int64_t tid = threadIdx .x ;
459
459
#if QK_K == 256
460
- const int il = tid/8 ; // 0...3
461
- const int ib = tid%8 ; // 0...7
460
+ const int64_t il = tid/8 ; // 0...3
461
+ const int64_t ib = tid%8 ; // 0...7
462
462
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
463
463
const uint8_t * qs = x[i].qs + 8 *ib;
464
464
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2 *il+0 ] | ((x[i].qh [ib] << (8 -2 *il)) & 256 )));
@@ -478,13 +478,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
478
478
template <typename dst_t >
479
479
static __global__ void dequantize_block_iq1_s (const void * __restrict__ vx, dst_t * __restrict__ yy) {
480
480
481
- const int i = blockIdx .x ;
481
+ const int64_t i = blockIdx .x ;
482
482
const block_iq1_s * x = (const block_iq1_s *) vx;
483
483
484
- const int tid = threadIdx .x ;
484
+ const int64_t tid = threadIdx .x ;
485
485
#if QK_K == 256
486
- const int il = tid/8 ; // 0...3
487
- const int ib = tid%8 ; // 0...7
486
+ const int64_t il = tid/8 ; // 0...3
487
+ const int64_t ib = tid%8 ; // 0...7
488
488
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
489
489
const float delta = x[i].qh [ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
490
490
const float d = (float )x[i].d * (2 *((x[i].qh [ib] >> 12 ) & 7 ) + 1 );
@@ -504,18 +504,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
504
504
template <typename dst_t >
505
505
static __global__ void dequantize_block_iq1_m (const void * __restrict__ vx, dst_t * __restrict__ yy) {
506
506
507
- const int i = blockIdx .x ;
507
+ const int64_t i = blockIdx .x ;
508
508
const block_iq1_m * x = (const block_iq1_m *) vx;
509
509
510
- const int tid = threadIdx .x ;
510
+ const int64_t tid = threadIdx .x ;
511
511
#if QK_K == 256
512
- const int il = tid/8 ; // 0...3
513
- const int ib = tid%8 ; // 0...7
512
+ const int64_t il = tid/8 ; // 0...3
513
+ const int64_t ib = tid%8 ; // 0...7
514
514
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
515
515
const uint16_t * sc = (const uint16_t *)x[i].scales ;
516
516
iq1m_scale_t scale;
517
517
scale.u16 = (sc[0 ] >> 12 ) | ((sc[1 ] >> 8 ) & 0x00f0 ) | ((sc[2 ] >> 4 ) & 0x0f00 ) | (sc[3 ] & 0xf000 );
518
- const int ib16 = 2 *ib + il/2 ; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
518
+ const int64_t ib16 = 2 *ib + il/2 ; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
519
519
const float d = (float )scale.f16 * (2 *((sc[ib16/4 ] >> 3 *(ib16%4 )) & 0x7 ) + 1 );
520
520
const float delta = x[i].qh [2 *ib+il/2 ] & (0x08 << 4 *(il%2 )) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
521
521
uint32_t grid32[2 ]; const int8_t * q = (const int8_t *)grid32;
@@ -535,12 +535,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
535
535
template <typename dst_t >
536
536
static __global__ void dequantize_block_iq4_nl (const void * __restrict__ vx, dst_t * __restrict__ yy) {
537
537
538
- const int i = blockIdx .x ;
538
+ const int64_t i = blockIdx .x ;
539
539
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
540
540
541
- const int tid = threadIdx .x ;
542
- const int il = tid/8 ; // 0...3
543
- const int ib = tid%8 ; // 0...7
541
+ const int64_t tid = threadIdx .x ;
542
+ const int64_t il = tid/8 ; // 0...3
543
+ const int64_t ib = tid%8 ; // 0...7
544
544
dst_t * y = yy + i*QK_K + 32 *ib + 4 *il;
545
545
const uint8_t * q4 = x[ib].qs + 4 *il;
546
546
const float d = (float )x[ib].d ;
@@ -554,12 +554,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
554
554
#if QK_K != 64
555
555
template <typename dst_t >
556
556
static __global__ void dequantize_block_iq4_xs (const void * __restrict__ vx, dst_t * __restrict__ yy) {
557
- const int i = blockIdx .x ;
557
+ const int64_t i = blockIdx .x ;
558
558
const block_iq4_xs * x = (const block_iq4_xs *)vx;
559
559
560
- const int tid = threadIdx .x ;
561
- const int il = tid/8 ; // 0...3
562
- const int ib = tid%8 ; // 0...7
560
+ const int64_t tid = threadIdx .x ;
561
+ const int64_t il = tid/8 ; // 0...3
562
+ const int64_t ib = tid%8 ; // 0...7
563
563
dst_t * y = yy + i*QK_K + 32 *ib + 4 *il;
564
564
const uint8_t * q4 = x[i].qs + 16 *ib + 4 *il;
565
565
const float d = (float )x[i].d * ((((x[i].scales_l [ib/2 ] >> 4 *(ib%2 )) & 0xf ) | (((x[i].scales_h >> 2 *ib) & 3 ) << 4 )) - 32 );
0 commit comments