Skip to content

Commit 17c97fb

Browse files
CUDA: mul_mat_vec_q max. batch size 8 -> 4 (#5370)
1 parent: b08f22c; commit: 17c97fb

File tree

1 file changed

18 additions (+18) and 18 deletions (-18) lines changed

1 file changed

18 additions (+18) and 18 deletions (-18) lines changed

ggml-cuda.cu

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -6831,7 +6831,7 @@ static void mul_mat_vec_q_cuda(
68316831
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
68326832

68336833
GGML_ASSERT(ncols_x % qk == 0);
6834-
GGML_ASSERT(ncols_y <= 8);
6834+
GGML_ASSERT(ncols_y <= 4);
68356835

68366836
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
68376837
const dim3 block_nums(block_num_y, 1, 1);
@@ -6853,22 +6853,22 @@ static void mul_mat_vec_q_cuda(
68536853
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
68546854
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
68556855
break;
6856-
case 5:
6857-
mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6859-
break;
6860-
case 6:
6861-
mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6863-
break;
6864-
case 7:
6865-
mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6867-
break;
6868-
case 8:
6869-
mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6871-
break;
6856+
// case 5:
6857+
// mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6859+
// break;
6860+
// case 6:
6861+
// mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6863+
// break;
6864+
// case 7:
6865+
// mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6867+
// break;
6868+
// case 8:
6869+
// mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870+
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6871+
// break;
68726872
default:
68736873
GGML_ASSERT(false);
68746874
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
@@ -9909,7 +9909,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
99099909
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
99109910
}
99119911
} else {
9912-
if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
9912+
if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
99139913
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
99149914
} else if (use_mul_mat_q) {
99159915
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);

0 commit comments

Comments
 (0)