Skip to content

Commit d50f889

Browse files
CUDA: stream-k decomposition for MMQ (#8018)
* CUDA: stream-k decomposition for MMQ
* fix undefined memory reads for small matrices
1 parent 2075a66 commit d50f889

File tree

4 files changed

+291
-112
lines changed

4 files changed

+291
-112
lines changed

ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
635635
}
636636

637637
const int cc = ggml_cuda_info().devices[id].cc;
638-
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
638+
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
639639
}
640640
return row_rounding;
641641
}

ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
652652
}
653653

654654
// Round rows to this value for --split-mode row:
655-
static int get_mmq_y_host(const int cc, const int mmq_x) {
656-
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
655+
static int get_mmq_y_host(const int cc) {
656+
return cc >= CC_VOLTA ? 128 : 64;
657657
}
658658

659659
//////////////////////

ggml-cuda/mmq.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
3030

3131
switch (src0->type) {
3232
case GGML_TYPE_Q4_0:
33-
mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
33+
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
3434
break;
3535
case GGML_TYPE_Q4_1:
36-
mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
36+
mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
3737
break;
3838
case GGML_TYPE_Q5_0:
39-
mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
39+
mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
4040
break;
4141
case GGML_TYPE_Q5_1:
42-
mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
42+
mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
4343
break;
4444
case GGML_TYPE_Q8_0:
45-
mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
45+
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
4646
break;
4747
case GGML_TYPE_Q2_K:
48-
mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
48+
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
4949
break;
5050
case GGML_TYPE_Q3_K:
51-
mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
51+
mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
5252
break;
5353
case GGML_TYPE_Q4_K:
54-
mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
54+
mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
5555
break;
5656
case GGML_TYPE_Q5_K:
57-
mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
57+
mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
5858
break;
5959
case GGML_TYPE_Q6_K:
60-
mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
60+
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
6161
break;
6262
default:
6363
GGML_ASSERT(false);

0 commit comments

Comments
 (0)