Skip to content

Commit 866b502

Browse files
ggml_cuda_set_device
1 parent bd79c94 commit 866b502

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

ggml-cuda.cu

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,16 @@ struct ggml_tensor_extra_gpu {
409409
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
410410
};
411411

412+
cudaError_t ggml_cuda_set_device(int device) {
413+
static int current_device = -1;
414+
415+
if (device == current_device) {
416+
return cudaSuccess;
417+
}
418+
419+
return cudaSetDevice(device);
420+
}
421+
412422
static int g_device_count = -1;
413423
static int g_main_device = 0;
414424
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
@@ -5151,7 +5161,7 @@ void ggml_init_cublas() {
51515161
}
51525162

51535163
for (int64_t id = 0; id < g_device_count; ++id) {
5154-
CUDA_CHECK(cudaSetDevice(id));
5164+
CUDA_CHECK(ggml_cuda_set_device(id));
51555165

51565166
// create cuda streams
51575167
for (int64_t is = 0; is < MAX_STREAMS; ++is) {
@@ -5795,7 +5805,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
57955805
size_t src1_asf = 0;
57965806
size_t dst_asf = 0;
57975807

5798-
cudaSetDevice(g_main_device);
5808+
ggml_cuda_set_device(g_main_device);
57995809
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
58005810

58015811
if (src0_on_device) {
@@ -5940,7 +5950,7 @@ static void ggml_cuda_op_mul_mat(
59405950
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
59415951
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
59425952

5943-
cudaSetDevice(id);
5953+
ggml_cuda_set_device(id);
59445954
const cudaStream_t stream = g_cudaStreams[id][0];
59455955

59465956
if (src0_on_device && src0_is_contiguous) {
@@ -5976,7 +5986,7 @@ static void ggml_cuda_op_mul_mat(
59765986
// if multiple devices are used they need to wait for the main device
59775987
// here an event is recorded that signals that the main device has finished calculating the input data
59785988
if (split && g_device_count > 1) {
5979-
CUDA_CHECK(cudaSetDevice(g_main_device));
5989+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
59805990
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
59815991
}
59825992

@@ -5994,7 +6004,7 @@ static void ggml_cuda_op_mul_mat(
59946004
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
59956005
const int64_t row_diff = row_high[id] - row_low[id];
59966006

5997-
cudaSetDevice(id);
6007+
ggml_cuda_set_device(id);
59986008
const cudaStream_t stream = g_cudaStreams[id][is];
59996009

60006010
// wait for main GPU data if necessary
@@ -6096,7 +6106,7 @@ static void ggml_cuda_op_mul_mat(
60966106
}
60976107

60986108
for (int64_t id = 0; id < g_device_count; ++id) {
6099-
CUDA_CHECK(cudaSetDevice(id));
6109+
CUDA_CHECK(ggml_cuda_set_device(id));
61006110

61016111
// free buffers again when done
61026112
if (src0_as[id] > 0) {
@@ -6118,7 +6128,7 @@ static void ggml_cuda_op_mul_mat(
61186128
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
61196129
is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
61206130

6121-
CUDA_CHECK(cudaSetDevice(g_main_device));
6131+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
61226132
for (int64_t id = 0; id < g_device_count; ++id) {
61236133
for (int64_t is = 0; is < is_max; ++is) {
61246134
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
@@ -6127,7 +6137,7 @@ static void ggml_cuda_op_mul_mat(
61276137
}
61286138

61296139
if (dst->backend == GGML_BACKEND_CPU) {
6130-
CUDA_CHECK(cudaSetDevice(g_main_device));
6140+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
61316141
CUDA_CHECK(cudaDeviceSynchronize());
61326142
}
61336143
}
@@ -6187,7 +6197,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
61876197

61886198
const int64_t ne12 = src1->ne[2];
61896199

6190-
CUDA_CHECK(cudaSetDevice(g_main_device));
6200+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
61916201
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
61926202

61936203
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6218,7 +6228,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
62186228
const int64_t nb01 = src0->nb[1];
62196229
const int64_t nb02 = src0->nb[2];
62206230

6221-
CUDA_CHECK(cudaSetDevice(g_main_device));
6231+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
62226232
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
62236233

62246234
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6310,7 +6320,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
63106320
const int64_t nb11 = src1->nb[1];
63116321
const int64_t nb12 = src1->nb[2];
63126322

6313-
CUDA_CHECK(cudaSetDevice(g_main_device));
6323+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
63146324
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
63156325

63166326
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
@@ -6376,7 +6386,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
63766386
continue;
63776387
}
63786388

6379-
cudaSetDevice(id);
6389+
ggml_cuda_set_device(id);
63806390

63816391
int64_t row_low, row_high;
63826392
if (backend == GGML_BACKEND_GPU) {
@@ -6446,13 +6456,13 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
64466456

64476457
for (int64_t id = 0; id < g_device_count; ++id) {
64486458
if (extra->data_device[id] != nullptr) {
6449-
CUDA_CHECK(cudaSetDevice(id));
6459+
CUDA_CHECK(ggml_cuda_set_device(id));
64506460
CUDA_CHECK(cudaFree(extra->data_device[id]));
64516461
}
64526462

64536463
for (int64_t is = 0; is < MAX_STREAMS; ++is) {
64546464
if (extra->events[id][is] != nullptr) {
6455-
CUDA_CHECK(cudaSetDevice(id));
6465+
CUDA_CHECK(ggml_cuda_set_device(id));
64566466
CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
64576467
}
64586468
}
@@ -6506,7 +6516,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
65066516
force_inplace;
65076517
const size_t size = ggml_nbytes(tensor);
65086518

6509-
CUDA_CHECK(cudaSetDevice(g_main_device));
6519+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
65106520
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
65116521
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
65126522
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];

0 commit comments

Comments
 (0)