Skip to content

Commit da2d8cb

Browse files
slaren authored and olexiyb committed
ggml-cuda : move row numbers to x grid dim in mmv kernels (ggml-org#3921)
1 parent d98979c commit da2d8cb

File tree

1 file changed

+27
-26
lines changed

1 file changed

+27
-26
lines changed

ggml-cuda.cu

Lines changed: 27 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -989,7 +989,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
989989

990990
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
991991

992-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
992+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
993993
if (row > nrows) return;
994994

995995
const int num_blocks_per_row = ncols / QK_K;
@@ -1093,7 +1093,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
10931093

10941094
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
10951095

1096-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
1096+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
10971097
if (row > nrows) return;
10981098

10991099
const int num_blocks_per_row = ncols / QK_K;
@@ -1197,7 +1197,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
11971197

11981198
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
11991199

1200-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
1200+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
12011201
if (row > nrows) return;
12021202
const int num_blocks_per_row = ncols / QK_K;
12031203
const int ib0 = row*num_blocks_per_row;
@@ -1451,7 +1451,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
14511451

14521452
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
14531453

1454-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
1454+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
14551455
if (row > nrows) return;
14561456

14571457
const int num_blocks_per_row = ncols / QK_K;
@@ -4261,7 +4261,7 @@ template <bool need_check> static __global__ void
42614261

42624262
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
42634263
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
4264-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
4264+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
42654265

42664266
if (row >= nrows) {
42674267
return;
@@ -4301,7 +4301,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
43014301
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
43024302
// qk = quantized weights per x block
43034303
// qr = number of quantized weights per data value in x block
4304-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
4304+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
43054305

43064306
if (row >= nrows) {
43074307
return;
@@ -4874,7 +4874,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
48744874
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
48754875
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
48764876
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4877-
const dim3 block_nums(1, block_num_y, 1);
4877+
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
4878+
const dim3 block_nums(block_num_y, 1, 1);
48784879
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
48794880
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
48804881
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4883,7 +4884,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
48834884
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
48844885
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
48854886
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4886-
const dim3 block_nums(1, block_num_y, 1);
4887+
const dim3 block_nums(block_num_y, 1, 1);
48874888
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
48884889
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
48894890
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4892,7 +4893,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
48924893
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
48934894
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
48944895
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4895-
const dim3 block_nums(1, block_num_y, 1);
4896+
const dim3 block_nums(block_num_y, 1, 1);
48964897
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
48974898
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
48984899
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4901,7 +4902,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
49014902
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49024903
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
49034904
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4904-
const dim3 block_nums(1, block_num_y, 1);
4905+
const dim3 block_nums(block_num_y, 1, 1);
49054906
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
49064907
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
49074908
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4910,7 +4911,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
49104911
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49114912
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
49124913
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4913-
const dim3 block_nums(1, block_num_y, 1);
4914+
const dim3 block_nums(block_num_y, 1, 1);
49144915
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
49154916
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
49164917
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
@@ -4920,7 +4921,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
49204921
GGML_ASSERT(ncols % QK_K == 0);
49214922
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
49224923
const int block_num_y = (nrows + ny - 1) / ny;
4923-
const dim3 block_nums(1, block_num_y, 1);
4924+
const dim3 block_nums(block_num_y, 1, 1);
49244925
const dim3 block_dims(32, ny, 1);
49254926
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
49264927
}
@@ -4929,7 +4930,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
49294930
GGML_ASSERT(ncols % QK_K == 0);
49304931
const int ny = 2 / K_QUANTS_PER_ITERATION;
49314932
const int block_num_y = (nrows + ny - 1) / ny;
4932-
const dim3 block_nums(1, block_num_y, 1);
4933+
const dim3 block_nums(block_num_y, 1, 1);
49334934
const dim3 block_dims(32, ny, 1);
49344935
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
49354936
}
@@ -4938,7 +4939,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
49384939
GGML_ASSERT(ncols % QK_K == 0);
49394940
const int ny = 2 / K_QUANTS_PER_ITERATION;
49404941
const int block_num_y = (nrows + ny - 1) / ny;
4941-
const dim3 block_nums(1, block_num_y, 1);
4942+
const dim3 block_nums(block_num_y, 1, 1);
49424943
const dim3 block_dims(32, ny, 1);
49434944
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
49444945
}
@@ -4953,15 +4954,15 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
49534954
GGML_ASSERT(ncols % QK_K == 0);
49544955
const int ny = 2 / K_QUANTS_PER_ITERATION;
49554956
const int block_num_y = (nrows + ny - 1) / ny;
4956-
const dim3 block_nums(1, block_num_y, 1);
4957+
const dim3 block_nums(block_num_y, 1, 1);
49574958
const dim3 block_dims(32, ny, 1);
49584959
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
49594960
}
49604961

49614962
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49624963
GGML_ASSERT(ncols % QK4_0 == 0);
49634964
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4964-
const dim3 block_nums(1, block_num_y, 1);
4965+
const dim3 block_nums(block_num_y, 1, 1);
49654966
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
49664967
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
49674968
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4970,7 +4971,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
49704971
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49714972
GGML_ASSERT(ncols % QK4_1 == 0);
49724973
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4973-
const dim3 block_nums(1, block_num_y, 1);
4974+
const dim3 block_nums(block_num_y, 1, 1);
49744975
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
49754976
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
49764977
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4979,7 +4980,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
49794980
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49804981
GGML_ASSERT(ncols % QK5_0 == 0);
49814982
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4982-
const dim3 block_nums(1, block_num_y, 1);
4983+
const dim3 block_nums(block_num_y, 1, 1);
49834984
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
49844985
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
49854986
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4988,7 +4989,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
49884989
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49894990
GGML_ASSERT(ncols % QK5_1 == 0);
49904991
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
4991-
const dim3 block_nums(1, block_num_y, 1);
4992+
const dim3 block_nums(block_num_y, 1, 1);
49924993
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
49934994
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
49944995
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -4997,7 +4998,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
49974998
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
49984999
GGML_ASSERT(ncols % QK8_0 == 0);
49995000
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5000-
const dim3 block_nums(1, block_num_y, 1);
5001+
const dim3 block_nums(block_num_y, 1, 1);
50015002
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50025003
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
50035004
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5006,7 +5007,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
50065007
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
50075008
GGML_ASSERT(ncols % QK_K == 0);
50085009
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5009-
const dim3 block_nums(1, block_num_y, 1);
5010+
const dim3 block_nums(block_num_y, 1, 1);
50105011
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50115012
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
50125013
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5015,7 +5016,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
50155016
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
50165017
GGML_ASSERT(ncols % QK_K == 0);
50175018
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5018-
const dim3 block_nums(1, block_num_y, 1);
5019+
const dim3 block_nums(block_num_y, 1, 1);
50195020
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50205021
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
50215022
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5024,7 +5025,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
50245025
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
50255026
GGML_ASSERT(ncols % QK_K == 0);
50265027
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5027-
const dim3 block_nums(1, block_num_y, 1);
5028+
const dim3 block_nums(block_num_y, 1, 1);
50285029
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50295030
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
50305031
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5033,7 +5034,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
50335034
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
50345035
GGML_ASSERT(ncols % QK_K == 0);
50355036
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5036-
const dim3 block_nums(1, block_num_y, 1);
5037+
const dim3 block_nums(block_num_y, 1, 1);
50375038
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50385039
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
50395040
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5042,7 +5043,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
50425043
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
50435044
GGML_ASSERT(ncols % QK_K == 0);
50445045
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5045-
const dim3 block_nums(1, block_num_y, 1);
5046+
const dim3 block_nums(block_num_y, 1, 1);
50465047
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50475048
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
50485049
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
@@ -5061,7 +5062,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
50615062
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
50625063
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
50635064
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
5064-
const dim3 block_nums(1, block_num_y, 1);
5065+
const dim3 block_nums(block_num_y, 1, 1);
50655066
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
50665067
dequantize_mul_mat_vec<1, 1, convert_f16>
50675068
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);

0 commit comments

Comments
 (0)