@@ -6893,6 +6893,8 @@ static void ggml_cuda_op_mul_mat(
     int64_t row_low[GGML_CUDA_MAX_DEVICES];
     int64_t row_high[GGML_CUDA_MAX_DEVICES];

+    int used_devices = 0;
+
     for (int64_t id = 0; id < g_device_count; ++id) {
         // by default, use all rows
         row_low[id] = 0;
@@ -6920,6 +6922,8 @@ static void ggml_cuda_op_mul_mat(
             continue;
         }

+        used_devices++;
+
         const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
         const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;

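Note: taken together, the two hunks above count only the GPUs that were actually assigned a non-empty slice of src0's rows; an empty range (row_low[id] == row_high[id]) means the device takes no part in the operation. A minimal standalone sketch of the idiom, with hypothetical row ranges in place of the real tensor-split computation:

#include <cstdint>

int main() {
    // Hypothetical example: two visible devices, but the split assigns
    // every row to device 0, leaving device 1 with an empty range.
    const int g_device_count = 2;
    int64_t row_low[2]  = {0,    4096};
    int64_t row_high[2] = {4096, 4096};

    int used_devices = 0;
    for (int id = 0; id < g_device_count; ++id) {
        if (row_low[id] == row_high[id]) {
            continue; // empty range: skip this device entirely
        }
        used_devices++;
    }
    return used_devices; // 1, even though g_device_count == 2
}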
@@ -6958,12 +6962,12 @@ static void ggml_cuda_op_mul_mat(

     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
+    if (split && used_devices > 1) {
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }

-    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
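Note: this hunk replaces g_device_count with used_devices in both conditions, so the cross-device event and the column chunking only kick in when more than one GPU actually holds rows. When a single device participates, src1_col_stride becomes ne11 and the column loop degenerates to one full-width iteration. A sketch of the chunking arithmetic, under the assumption that MUL_MAT_SRC1_COL_STRIDE is 128 (its value in ggml-cuda.cu at the time of this patch; verify against the source):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne11 = 300;            // hypothetical src1 column count
    const int64_t src1_col_stride = 128; // stands in for MUL_MAT_SRC1_COL_STRIDE
    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11
            ? ne11 - src1_col_0 : src1_col_stride;
        printf("chunk at col %lld: %lld cols\n",
               (long long) src1_col_0, (long long) src1_ncols);
    }
    // Prints chunks of 128, 128 and 44 columns. With used_devices == 1 the
    // stride equals ne11, so the loop body runs exactly once.
    return 0;
}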
@@ -7079,6 +7083,9 @@ static void ggml_cuda_op_mul_mat(
     }

     for (int64_t id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+            continue;
+        }
         CUDA_CHECK(ggml_cuda_set_device(id));

         // free buffers again when done
@@ -7103,6 +7110,9 @@ static void ggml_cuda_op_mul_mat(

         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
+            if (row_low[id] == row_high[id]) {
+                continue;
+            }
             for (int64_t is = 0; is < is_max; ++is) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
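Note: these two guard hunks mirror the skip logic of the compute loop: devices that were never entered there (non-main devices in the non-split case, or devices with an empty row range) have no buffers to free and, presumably, no recorded events, so the cleanup loop and the main-device event fan-in must step over them as well.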
@@ -7400,7 +7410,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU) &&
+        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         (src1->backend == GGML_BACKEND_GPU) &&
         ( dst->backend == GGML_BACKEND_GPU);

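Note: the last hunk lets a row-split weight tensor (GGML_BACKEND_GPU_SPLIT) also satisfy all_on_device, which feeds the dispatch checks further down in ggml_cuda_mul_mat. A hypothetical helper expressing the new predicate (illustration only, not part of ggml):

// Hypothetical helper, assuming the ggml_tensor/backend definitions of
// this version of ggml-cuda.cu; not actual ggml code.
static bool src0_on_gpu(const ggml_tensor * src0) {
    return src0->backend == GGML_BACKEND_GPU        // fully on one device
        || src0->backend == GGML_BACKEND_GPU_SPLIT; // rows split across devices
}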