Skip to content

Commit 9584527

Browse files
committed
cuda : tweak mm stride to double perf on P40 + GTX 970
1 parent: 3e73d31 · commit: 9584527

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -467,7 +467,8 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
467467
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
468468
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
469469

470-
#define MUL_MAT_SRC1_COL_STRIDE 128
470+
#define MUL_MAT_SRC1_COL_STRIDE_MMQ 128
471+
#define MUL_MAT_SRC1_COL_STRIDE 4096
471472

472473
#define MAX_STREAMS 8
473474
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };
@@ -7158,7 +7159,10 @@ static void ggml_cuda_op_mul_mat(
71587159
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
71597160
}
71607161

7161-
const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
7162+
const int64_t src1_col_stride = !split || used_devices == 1 ? ne11 :
7163+
convert_src1_to_q8_1 ? MUL_MAT_SRC1_COL_STRIDE_MMQ :
7164+
MUL_MAT_SRC1_COL_STRIDE;
7165+
71627166
for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
71637167
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
71647168
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -7296,7 +7300,7 @@ static void ggml_cuda_op_mul_mat(
72967300

72977301
// main device waits for all other devices to be finished
72987302
if (split && g_device_count > 1) {
7299-
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
7303+
int64_t is_max = (ne11 + src1_col_stride - 1) / src1_col_stride;
73007304
is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
73017305

73027306
CUDA_CHECK(ggml_cuda_set_device(g_main_device));

0 commit comments

Comments
 (0)