Skip to content

Commit 3b495c3

Browse files
JohannesGaesslerjordankanter
authored andcommitted
CUDA: fixed mmvq kernel for bs 2,3,4 and -sm row (ggml-org#5386)
1 parent 829ac69 commit 3b495c3

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

ggml-cuda.cu

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5313,7 +5313,7 @@ template <bool need_check> static __global__ void
53135313
template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
53145314
static __global__ void mul_mat_vec_q(
53155315
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
5316-
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
5316+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
53175317

53185318
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
53195319

@@ -5352,7 +5352,7 @@ static __global__ void mul_mat_vec_q(
53525352
tmp[j] = warp_reduce_sum(tmp[j]);
53535353

53545354
if (threadIdx.x == 0) {
5355-
dst[j*nrows_x + row] = tmp[j];
5355+
dst[j*nrows_dst + row] = tmp[j];
53565356
}
53575357
}
53585358
}
@@ -6828,7 +6828,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
68286828
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
68296829
static void mul_mat_vec_q_cuda(
68306830
const void * vx, const void * vy, float * dst,
6831-
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
6831+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
68326832

68336833
GGML_ASSERT(ncols_x % qk == 0);
68346834
GGML_ASSERT(ncols_y <= 4);
@@ -6839,40 +6839,40 @@ static void mul_mat_vec_q_cuda(
68396839
switch (ncols_y) {
68406840
case 1:
68416841
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
6842-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6842+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68436843
break;
68446844
case 2:
68456845
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
6846-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6846+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68476847
break;
68486848
case 3:
68496849
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
6850-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6850+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68516851
break;
68526852
case 4:
68536853
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
6854-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6854+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68556855
break;
68566856
// case 5:
68576857
// mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6858+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68596859
// break;
68606860
// case 6:
68616861
// mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6862+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68636863
// break;
68646864
// case 7:
68656865
// mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6866+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68676867
// break;
68686868
// case 8:
68696869
// mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6870+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68716871
// break;
68726872
default:
68736873
GGML_ASSERT(false);
68746874
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6875-
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6875+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
68766876
break;
68776877
}
68786878
}
@@ -8391,7 +8391,7 @@ static void ggml_cuda_op_mul_mat_q(
83918391
CUDA_CHECK(cudaGetDevice(&id));
83928392

83938393
// the main device has a larger memory buffer to hold the results from all GPUs
8394-
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
8394+
// nrows_dst == nrows of the matrix that the kernel writes into
83958395
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
83968396

83978397
switch (src0->type) {
@@ -8525,58 +8525,70 @@ static void ggml_cuda_op_mul_mat_vec_q(
85258525
const int64_t ne00 = src0->ne[0];
85268526
const int64_t row_diff = row_high - row_low;
85278527

8528+
const int64_t ne10 = src1->ne[0];
8529+
GGML_ASSERT(ne10 % QK8_1 == 0);
8530+
8531+
const int64_t ne0 = dst->ne[0];
8532+
8533+
int id;
8534+
CUDA_CHECK(cudaGetDevice(&id));
8535+
8536+
// the main device has a larger memory buffer to hold the results from all GPUs
8537+
// nrows_dst == nrows of the matrix that the kernel writes into
8538+
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
8539+
85288540
switch (src0->type) {
85298541
case GGML_TYPE_Q4_0:
85308542
mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
8531-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8543+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85328544
break;
85338545
case GGML_TYPE_Q4_1:
85348546
mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
8535-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8547+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85368548
break;
85378549
case GGML_TYPE_Q5_0:
85388550
mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
8539-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8551+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85408552
break;
85418553
case GGML_TYPE_Q5_1:
85428554
mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
8543-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8555+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85448556
break;
85458557
case GGML_TYPE_Q8_0:
85468558
mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
8547-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8559+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85488560
break;
85498561
case GGML_TYPE_Q2_K:
85508562
mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
8551-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8563+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85528564
break;
85538565
case GGML_TYPE_Q3_K:
85548566
mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
8555-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8567+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85568568
break;
85578569
case GGML_TYPE_Q4_K:
85588570
mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
8559-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8571+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85608572
break;
85618573
case GGML_TYPE_Q5_K:
85628574
mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
8563-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8575+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85648576
break;
85658577
case GGML_TYPE_Q6_K:
85668578
mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
8567-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8579+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85688580
break;
85698581
case GGML_TYPE_IQ2_XXS:
85708582
mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
8571-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8583+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85728584
break;
85738585
case GGML_TYPE_IQ2_XS:
85748586
mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
8575-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8587+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85768588
break;
85778589
case GGML_TYPE_IQ3_XXS:
85788590
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
8579-
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8591+
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85808592
break;
85818593
default:
85828594
GGML_ASSERT(false);
@@ -9909,7 +9921,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
99099921
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
99109922
}
99119923
} else {
9912-
if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
9924+
if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
99139925
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
99149926
} else if (use_mul_mat_q) {
99159927
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);

0 commit comments

Comments
 (0)