@@ -467,7 +467,8 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
467
467
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
468
468
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
469
469
470
- #define MUL_MAT_SRC1_COL_STRIDE 128
470
+ #define MUL_MAT_SRC1_COL_STRIDE_MMQ 128
471
+ #define MUL_MAT_SRC1_COL_STRIDE 4096
471
472
472
473
#define MAX_STREAMS 8
473
474
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };
@@ -7158,7 +7159,10 @@ static void ggml_cuda_op_mul_mat(
7158
7159
CUDA_CHECK (cudaEventRecord (src0_extra->events [g_main_device][0 ], g_cudaStreams[g_main_device][0 ]));
7159
7160
}
7160
7161
7161
- const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
7162
+ const int64_t src1_col_stride = !split || used_devices == 1 ? ne11 :
7163
+ convert_src1_to_q8_1 ? MUL_MAT_SRC1_COL_STRIDE_MMQ :
7164
+ MUL_MAT_SRC1_COL_STRIDE;
7165
+
7162
7166
for (int64_t src1_col_0 = 0 ; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
7163
7167
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0 ;
7164
7168
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -7296,7 +7300,7 @@ static void ggml_cuda_op_mul_mat(
7296
7300
7297
7301
// main device waits for all other devices to be finished
7298
7302
if (split && g_device_count > 1 ) {
7299
- int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1 ) / MUL_MAT_SRC1_COL_STRIDE ;
7303
+ int64_t is_max = (ne11 + src1_col_stride - 1 ) / src1_col_stride ;
7300
7304
is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
7301
7305
7302
7306
CUDA_CHECK (ggml_cuda_set_device (g_main_device));
0 commit comments