Skip to content

Commit c6a933c

Browse files
fix out-of-bounds for q4_0
1 parent 214fc99 commit c6a933c

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

ggml-cuda.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
162162
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
163163
typedef void (*load_tiles_cuda_t)(
164164
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
165-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row);
165+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
166166
typedef float (*vec_dot_q_mul_mat_cuda_t)(
167167
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
168168
const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
@@ -1406,7 +1406,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 **
14061406

14071407
static __device__ __forceinline__ void load_tiles_q4_0(
14081408
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1409-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
1409+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
14101410

14111411
__builtin_assume(i_offset >= 0);
14121412
__builtin_assume(i_offset < 8);
@@ -1420,7 +1420,7 @@ static __device__ __forceinline__ void load_tiles_q4_0(
14201420

14211421
#pragma unroll
14221422
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
1423-
const int i = i0 + i_offset;
1423+
const int i = min(i0 + i_offset, i_max);
14241424

14251425
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
14261426

@@ -1515,7 +1515,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 **
15151515

15161516
static __device__ __forceinline__ void load_tiles_q4_1(
15171517
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1518-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
1518+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
15191519

15201520
__builtin_assume(i_offset >= 0);
15211521
__builtin_assume(i_offset < 8);
@@ -1619,7 +1619,7 @@ static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 **
16191619

16201620
static __device__ __forceinline__ void load_tiles_q5_0(
16211621
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1622-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
1622+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
16231623

16241624
__builtin_assume(i_offset >= 0);
16251625
__builtin_assume(i_offset < 8);
@@ -1735,7 +1735,7 @@ static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 **
17351735

17361736
static __device__ __forceinline__ void load_tiles_q5_1(
17371737
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1738-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
1738+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
17391739

17401740
__builtin_assume(i_offset >= 0);
17411741
__builtin_assume(i_offset < 8);
@@ -1826,7 +1826,7 @@ static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 **
18261826

18271827
static __device__ __forceinline__ void load_tiles_q8_0(
18281828
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1829-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
1829+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
18301830

18311831
__builtin_assume(i_offset >= 0);
18321832
__builtin_assume(i_offset < 8);
@@ -1949,7 +1949,7 @@ static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 **
19491949

19501950
static __device__ __forceinline__ void load_tiles_q2_K(
19511951
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1952-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
1952+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
19531953

19541954
__builtin_assume(i_offset >= 0);
19551955
__builtin_assume(i_offset < 8);
@@ -2101,7 +2101,7 @@ static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 **
21012101

21022102
static __device__ __forceinline__ void load_tiles_q3_K(
21032103
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
2104-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
2104+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
21052105

21062106
__builtin_assume(i_offset >= 0);
21072107
__builtin_assume(i_offset < 8);
@@ -2322,7 +2322,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 **
23222322

23232323
static __device__ __forceinline__ void load_tiles_q4_K(
23242324
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
2325-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
2325+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
23262326

23272327
__builtin_assume(i_offset >= 0);
23282328
__builtin_assume(i_offset < 8);
@@ -2550,7 +2550,7 @@ static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 **
25502550

25512551
static __device__ __forceinline__ void load_tiles_q5_K(
25522552
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
2553-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
2553+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
25542554

25552555
__builtin_assume(i_offset >= 0);
25562556
__builtin_assume(i_offset < 8);
@@ -2719,7 +2719,7 @@ static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 **
27192719

27202720
static __device__ __forceinline__ void load_tiles_q6_K(
27212721
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
2722-
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
2722+
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
27232723

27242724
__builtin_assume(i_offset >= 0);
27252725
__builtin_assume(i_offset < 8);
@@ -2849,7 +2849,7 @@ static __global__ void mul_mat_q(
28492849
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
28502850

28512851
load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
2852-
tid_y, tid_x, blocks_per_row_x);
2852+
tid_y, nrows_x-row_x_0-1, tid_x, blocks_per_row_x);
28532853

28542854
for (int ir = 0; ir < qr; ++ir) {
28552855
const int kqs = ir*WARP_SIZE + tid_x;

0 commit comments

Comments
 (0)