
Commit 8a93a05

Only one CUDA stream per device for async compute
1 parent 3d59ec5 commit 8a93a05
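
The commit collapses the per-device pools of compute streams, copy streams, and sync events into a single non-blocking CUDA stream per device; all copies and kernels for a device are issued in order on that one stream. A minimal sketch of the resulting pattern (illustrative only, with placeholder names such as `MAX_DEVICES` and `init_streams`, not code from this commit):

```cpp
// Sketch of the one-stream-per-device model this commit adopts.
#include <cuda_runtime.h>

#define MAX_DEVICES 16  // placeholder, stands in for GGML_CUDA_MAX_DEVICES

static cudaStream_t stream_main[MAX_DEVICES] = { nullptr };

void init_streams(int device_count) {
    for (int id = 0; id < device_count; ++id) {
        cudaSetDevice(id);
        // one non-blocking stream per device; copies and kernels issued on it
        // execute in submission order, so no inter-stream events are needed
        cudaStreamCreateWithFlags(&stream_main[id], cudaStreamNonBlocking);
    }
}
```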

3 files changed: +20 additions, -38 deletions

README.md

Lines changed: 0 additions & 1 deletion
@@ -336,7 +336,6 @@ Building the program with BLAS support may lead to some performance improvements
     cmake .. -DLLAMA_CUBLAS=ON
     cmake --build . --config Release
     ```
-Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1.
 
 The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.

examples/common.cpp

Lines changed: 0 additions & 3 deletions
@@ -106,9 +106,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         }
 
         if (arg == "-s" || arg == "--seed") {
-#if defined(GGML_USE_CUBLAS)
-            fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
-#endif
            if (++i >= argc) {
                invalid_param = true;
                break;

ggml-cuda.cu

Lines changed: 20 additions & 34 deletions
@@ -1467,19 +1467,13 @@ static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
 static size_t g_scratch_offset = 0;
 
-#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
-#define GGML_CUDA_MAX_EVENTS 64
-
 static int g_device_count = -1;
 static int g_main_device = 0;
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-
-static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 void ggml_init_cublas() {
     static bool initialized = false;
@@ -1503,15 +1497,8 @@ void ggml_init_cublas() {
     for (int id = 0; id < g_device_count; ++id) {
         CUDA_CHECK(cudaSetDevice(id));
 
-        // create streams
-        for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
-        }
-        // create events
-        for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
-        }
+        // create main stream
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
 
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -1978,6 +1965,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
+    // if multiple GPUs are used they need to wait for the main GPU to finish
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
     for (int id = 0; id < g_device_count; ++id) {
         if (!split && id != g_main_device) {
             continue;
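
With the dedicated memcpy stream and its events removed, intra-device ordering falls out of CUDA's stream semantics: work on a single stream executes in submission order, so the kernel launched after the `src1` copy (later in this diff) no longer needs `cudaEventRecord`/`cudaStreamWaitEvent`. Only cross-device ordering remains, which the `cudaDeviceSynchronize()` added above handles for split tensors. A hedged sketch of the intra-stream guarantee being relied on (placeholder kernel and names, not commit code):

```cpp
// Copy and compute on the same stream: the kernel is ordered after the copy
// by the stream itself, without any event or explicit wait.
#include <cuda_runtime.h>

__global__ void scale2x(float * x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= 2.0f;
}

void copy_then_compute(float * x_dev, const float * x_host, int n, cudaStream_t stream) {
    cudaMemcpyAsync(x_dev, x_host, n * sizeof(float), cudaMemcpyHostToDevice, stream);
    scale2x<<<(n + 255) / 256, 256, 0, stream>>>(x_dev, n);  // starts only after the copy finishes
}
```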
@@ -2076,9 +2069,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
             }
             const int64_t i11 = i13*ne12 + i12;
 
-            cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+            cudaStream_t cudaStream_main = g_cudaStreams_main[id];
 
             // for split tensors the data begins at i0 == i0_offset_low
             char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -2106,14 +2097,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                 if (src1->backend == GGML_BACKEND_CPU) {
                     GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
                     int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
                 } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                     if (id != g_main_device) {
                         GGML_ASSERT(!flatten_rows);
                         float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                         src1_ddf_i_source += i11*src1_stride;
                         CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
-                                                   cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
+                                                   cudaMemcpyDeviceToDevice, cudaStream_main));
                     }
                 } else if (src1_on_device && !src1_is_contiguous) {
                     GGML_ASSERT(!split);
@@ -2122,7 +2113,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                     GGML_ASSERT(false);
                 }
             }
-            CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
 
             if (!src0_on_device || !src0_is_contiguous) {
                 if (src0_is_f32) {
@@ -2138,9 +2128,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                 CUDA_CHECK(cudaGetLastError());
             }
 
-            // wait with main stream until src1 memcpy is done
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
-
             // do the computation
             op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
 
@@ -2178,8 +2165,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
 
     // wait until each device is finished, then free their buffers
     for (int id = 0; id < g_device_count; ++id) {
+        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+            continue;
+        }
+
         CUDA_CHECK(cudaSetDevice(id));
         CUDA_CHECK(cudaDeviceSynchronize());
+
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
         }
@@ -2245,7 +2237,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     const int64_t ne02 = src0->ne[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2257,8 +2249,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2276,7 +2266,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     const int64_t nb02 = src0->nb[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2291,8 +2281,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     const int channel_stride_x = nb02 / sizeof(half);
 
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2348,7 +2336,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t nb12 = src1->nb[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2366,8 +2354,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         GGML_ASSERT(false);
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-
     (void) dst;
 }
 
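The trailing `cudaDeviceSynchronize()` calls removed from `ggml_cuda_mul_mat_vec_p021`, `ggml_cuda_mul_mat_vec_nc`, and `ggml_cuda_cpy` follow the same logic: any later operation issued on the main stream is already ordered after these kernels, so the host only has to block when it actually reads a result back. A sketch of that deferred-synchronization pattern (assumed placeholder code, not from the commit):

```cpp
// One synchronization point when the host needs the data, instead of a
// device-wide sync after every operation.
#include <cuda_runtime.h>

void fetch_result(cudaStream_t stream, const float * dev_buf, float * host_out, size_t nbytes) {
    // kernels previously enqueued on `stream` may still be running; this copy
    // is ordered after them, and the single sync below covers the whole chain
    cudaMemcpyAsync(host_out, dev_buf, nbytes, cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
}
```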