Skip to content

Commit 0bcfc87

Browse files
committed
Fix more int overflow during quant.
1 parent d292407 commit 0bcfc87

File tree

6 files changed

+51
-51
lines changed

6 files changed

+51
-51
lines changed

ggml-cuda.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12251225

12261226
// the main device has a larger memory buffer to hold the results from all GPUs
12271227
// ldc == nrows of the matrix that cuBLAS writes into
1228-
int ldc = id == ctx.device ? ne0 : row_diff;
1228+
int64_t ldc = id == ctx.device ? ne0 : row_diff;
12291229

12301230
const int compute_capability = ggml_cuda_info().devices[id].cc;
12311231

@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
13771377
const int64_t ne0 = dst->ne[0];
13781378
const int64_t ne1 = dst->ne[1];
13791379

1380-
const int nb2 = dst->nb[2];
1381-
const int nb3 = dst->nb[3];
1380+
const int64_t nb2 = dst->nb[2];
1381+
const int64_t nb3 = dst->nb[3];
13821382

13831383
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
13841384
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));

ggml-cuda/convert.cu

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
#define CUDA_Q8_0_NE_ALIGN 2048
55

66
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
7-
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
8-
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
7+
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
8+
const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
99

1010
if (i >= k) {
1111
return;
1212
}
1313

14-
const int ib = i/qk; // block index
14+
const int64_t ib = i/qk; // block index
1515
const int iqs = (i%qk)/qr; // quant index
1616
const int iybs = i - i%qk; // y block start index
1717
const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
2525
}
2626

2727
template <bool need_check>
28-
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
28+
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
2929
#if __CUDA_ARCH__ >= CC_PASCAL
3030
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
3131

@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
6868
template<typename dst_t>
6969
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
7070

71-
const int i = blockIdx.x;
71+
const int64_t i = blockIdx.x;
7272

7373
// assume 32 threads
7474
const int tid = threadIdx.x;
7575
const int il = tid/8;
7676
const int ir = tid%8;
77-
const int ib = 8*i + ir;
77+
const int64_t ib = 8*i + ir;
7878
if (ib >= nb32) {
7979
return;
8080
}
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
9696
template<typename dst_t>
9797
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
9898

99-
const int i = blockIdx.x;
99+
const int64_t i = blockIdx.x;
100100

101101
// assume 32 threads
102102
const int tid = threadIdx.x;
103103
const int il = tid/8;
104104
const int ir = tid%8;
105-
const int ib = 8*i + ir;
105+
const int64_t ib = 8*i + ir;
106106
if (ib >= nb32) {
107107
return;
108108
}
@@ -313,14 +313,14 @@ template<typename dst_t>
313313
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
314314
const block_q6_K * x = (const block_q6_K *) vx;
315315

316-
const int i = blockIdx.x;
316+
const int64_t i = blockIdx.x;
317317
#if QK_K == 256
318318

319319
// assume 64 threads - this is very slightly better than the one below
320-
const int tid = threadIdx.x;
321-
const int ip = tid/32; // ip is 0 or 1
322-
const int il = tid - 32*ip; // 0...31
323-
const int is = 8*ip + il/16;
320+
const int64_t tid = threadIdx.x;
321+
const int64_t ip = tid/32; // ip is 0 or 1
322+
const int64_t il = tid - 32*ip; // 0...31
323+
const int64_t is = 8*ip + il/16;
324324

325325
dst_t * y = yy + i*QK_K + 128*ip + il;
326326

@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
337337
#else
338338

339339
// assume 32 threads
340-
const int tid = threadIdx.x;
341-
const int ip = tid/16; // 0 or 1
342-
const int il = tid - 16*ip; // 0...15
340+
const int64_t tid = threadIdx.x;
341+
const int64_t ip = tid/16; // 0 or 1
342+
const int64_t il = tid - 16*ip; // 0...15
343343

344344
dst_t * y = yy + i*QK_K + 16*ip + il;
345345

@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
571571
#endif
572572

573573
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
574-
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
574+
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
575575
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
576576
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
577577
}
578578

579-
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
579+
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
580580
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
581581
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
582582
const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
588588
}
589589

590590
template<typename dst_t>
591-
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
591+
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
592592
const int nb = k / QK_K;
593593
#if QK_K == 256
594594
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
598598
}
599599

600600
template<typename dst_t>
601-
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
601+
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
602602
const int nb = k / QK_K;
603603
#if QK_K == 256
604604
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
608608
}
609609

610610
template<typename dst_t>
611-
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
611+
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
612612
const int nb32 = k / 32;
613613
const int nb = (k + 255) / 256;
614614
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
615615
}
616616

617617
template<typename dst_t>
618-
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
618+
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
619619
const int nb32 = k / 32;
620620
const int nb = (k + 255) / 256;
621621
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
622622
}
623623

624624
template<typename dst_t>
625-
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
625+
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
626626
const int nb = k / QK_K;
627627
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
628628
}
629629

630630
template<typename dst_t>
631-
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
631+
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
632632
const int nb = k / QK_K;
633633
#if QK_K == 256
634634
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
638638
}
639639

640640
template<typename dst_t>
641-
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
641+
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
642642
const int nb = k / QK_K;
643643
#if QK_K == 256
644644
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
648648
}
649649

650650
template<typename dst_t>
651-
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
651+
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
652652
const int nb = k / QK_K;
653653
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
654654
}
655655

656656
template<typename dst_t>
657-
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
657+
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
658658
const int nb = k / QK_K;
659659
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
660660
}
661661

662662
template<typename dst_t>
663-
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
663+
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
664664
const int nb = k / QK_K;
665665
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
666666
}
667667

668668
template<typename dst_t>
669-
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
669+
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
670670
const int nb = k / QK_K;
671671
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
672672
}
673673

674674
template<typename dst_t>
675-
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
675+
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
676676
const int nb = k / QK_K;
677677
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
678678
}
679679

680680
template<typename dst_t>
681-
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
681+
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
682682
const int nb = k / QK_K;
683683
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
684684
}
685685

686686
template<typename dst_t>
687-
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
687+
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
688688
const int nb = (k + QK_K - 1) / QK_K;
689689
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
690690
}
691691

692692
template<typename dst_t>
693-
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
693+
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
694694
const int nb = k / QK_K;
695695
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
696696
}
697697

698698
template<typename dst_t>
699-
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
699+
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
700700
const int nb = (k + QK_K - 1) / QK_K;
701701
#if QK_K == 64
702702
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
706706
}
707707

708708
template <typename src_t, typename dst_t>
709-
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
710-
const int i = blockDim.x*blockIdx.x + threadIdx.x;
709+
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
710+
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
711711

712712
if (i >= k) {
713713
return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
719719
}
720720

721721
template <typename src_t, typename dst_t>
722-
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
722+
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
723723
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
724724
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
725725
}

ggml-cuda/convert.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
44

55
template<typename T>
6-
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
6+
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
77

88
typedef to_t_cuda_t<float> to_fp32_cuda_t;
99
typedef to_t_cuda_t<half> to_fp16_cuda_t;

ggml-cuda/dmmv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
577577
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
578578
// qk = quantized weights per x block
579579
// qr = number of quantized weights per data value in x block
580-
const int row = blockIdx.x*blockDim.y + threadIdx.y;
580+
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
581581

582582
if (row >= nrows) {
583583
return;

ggml-cuda/quantize.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
#include "quantize.cuh"
22

3-
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
4-
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
3+
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
4+
const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
55

66
if (ix >= kx_padded) {
77
return;
88
}
99

10-
const int iy = blockDim.y*blockIdx.y + threadIdx.y;
10+
const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
1111

12-
const int i_padded = iy*kx_padded + ix;
12+
const int64_t i_padded = (int64_t)iy*kx_padded + ix;
1313

1414
block_q8_1 * y = (block_q8_1 *) vy;
1515

16-
const int ib = i_padded / QK8_1; // block index
17-
const int iqs = i_padded % QK8_1; // quant index
16+
const int64_t ib = i_padded / QK8_1; // block index
17+
const int64_t iqs = i_padded % QK8_1; // quant index
1818

1919
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
2020
float amax = fabsf(xi);
@@ -36,8 +36,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
3636
reinterpret_cast<half&>(y[ib].ds.y) = sum;
3737
}
3838

39-
void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
40-
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
39+
void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
40+
const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
4141
const dim3 num_blocks(block_num_x, ky, 1);
4242
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
4343
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);

ggml-cuda/quantize.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
#define CUDA_QUANTIZE_BLOCK_SIZE 256
44

5-
void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream);
5+
void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);

0 commit comments

Comments
 (0)