Skip to content

Commit 4e85b43

Browse files
Only one CUDA stream per device for async compute
1 parent a09f919 commit 4e85b43

File tree

4 files changed

+20
-41
lines changed

4 files changed

+20
-41
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,6 @@ Building the program with BLAS support may lead to some performance improvements
336336
cmake .. -DLLAMA_CUBLAS=ON
337337
cmake --build . --config Release
338338
```
339-
Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication, results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1.
340339
341340
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
342341

examples/common.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
102102
}
103103

104104
if (arg == "-s" || arg == "--seed") {
105-
#if defined(GGML_USE_CUBLAS)
106-
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
107-
#endif
108105
if (++i >= argc) {
109106
invalid_param = true;
110107
break;

examples/server/server.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,6 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
445445
}
446446
else if (arg == "-s" || arg == "--seed")
447447
{
448-
#if defined(GGML_USE_CUBLAS)
449-
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
450-
#endif
451448
if (++i >= argc)
452449
{
453450
invalid_param = true;

ggml-cuda.cu

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,19 +1306,13 @@ static void * g_scratch_buffer = nullptr;
13061306
static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
13071307
static size_t g_scratch_offset = 0;
13081308

1309-
#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
1310-
#define GGML_CUDA_MAX_EVENTS 64
1311-
13121309
static int g_device_count = -1;
13131310
static int g_main_device = 0;
13141311
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
13151312

13161313
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
13171314

1318-
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
1319-
1320-
static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
1321-
static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
1315+
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
13221316

13231317
void ggml_init_cublas() {
13241318
static bool initialized = false;
@@ -1342,15 +1336,8 @@ void ggml_init_cublas() {
13421336
for (int id = 0; id < g_device_count; ++id) {
13431337
CUDA_CHECK(cudaSetDevice(id));
13441338

1345-
// create streams
1346-
for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
1347-
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
1348-
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
1349-
}
1350-
// create events
1351-
for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
1352-
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
1353-
}
1339+
// create main stream
1340+
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
13541341

13551342
// create cublas handle
13561343
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -1817,6 +1804,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
18171804
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
18181805
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
18191806

1807+
// if multiple GPUs are used they need to wait for the main GPU to finish
1808+
if (split && g_device_count > 1) {
1809+
CUDA_CHECK(cudaSetDevice(g_main_device));
1810+
CUDA_CHECK(cudaDeviceSynchronize());
1811+
}
1812+
18201813
for (int id = 0; id < g_device_count; ++id) {
18211814
if (!split && id != g_main_device) {
18221815
continue;
@@ -1915,9 +1908,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
19151908
}
19161909
const int64_t i11 = i13*ne12 + i12;
19171910

1918-
cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
1919-
cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
1920-
cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
1911+
cudaStream_t cudaStream_main = g_cudaStreams_main[id];
19211912

19221913
// for split tensors the data begins at i0 == i0_offset_low
19231914
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -1945,14 +1936,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
19451936
if (src1->backend == GGML_BACKEND_CPU) {
19461937
GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
19471938
int64_t nrows1 = flatten_rows ? nrows0 : ne11;
1948-
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
1939+
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
19491940
} else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
19501941
if (id != g_main_device) {
19511942
GGML_ASSERT(!flatten_rows);
19521943
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
19531944
src1_ddf_i_source += i11*src1_stride;
19541945
CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
1955-
cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
1946+
cudaMemcpyDeviceToDevice, cudaStream_main));
19561947
}
19571948
} else if (src1_on_device && !src1_is_contiguous) {
19581949
GGML_ASSERT(!split);
@@ -1961,7 +1952,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
19611952
GGML_ASSERT(false);
19621953
}
19631954
}
1964-
CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
19651955

19661956
if (!src0_on_device || !src0_is_contiguous) {
19671957
if (src0_is_f32) {
@@ -1977,9 +1967,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
19771967
CUDA_CHECK(cudaGetLastError());
19781968
}
19791969

1980-
// wait with main stream until src1 memcpy is done
1981-
CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
1982-
19831970
// do the computation
19841971
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
19851972

@@ -2017,8 +2004,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
20172004

20182005
// wait until each device is finished, then free their buffers
20192006
for (int id = 0; id < g_device_count; ++id) {
2007+
if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
2008+
continue;
2009+
}
2010+
20202011
CUDA_CHECK(cudaSetDevice(id));
20212012
CUDA_CHECK(cudaDeviceSynchronize());
2013+
20222014
if (src0_asq[id] > 0) {
20232015
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
20242016
}
@@ -2084,7 +2076,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
20842076
const int64_t ne02 = src0->ne[2];
20852077

20862078
CUDA_CHECK(cudaSetDevice(g_main_device));
2087-
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
2079+
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
20882080

20892081
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
20902082
void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2096,8 +2088,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
20962088
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
20972089

20982090
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
2099-
2100-
CUDA_CHECK(cudaDeviceSynchronize());
21012091
}
21022092

21032093
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2115,7 +2105,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
21152105
const int64_t nb02 = src0->nb[2];
21162106

21172107
CUDA_CHECK(cudaSetDevice(g_main_device));
2118-
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
2108+
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
21192109

21202110
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
21212111
void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2130,8 +2120,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
21302120
const int channel_stride_x = nb02 / sizeof(half);
21312121

21322122
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
2133-
2134-
CUDA_CHECK(cudaDeviceSynchronize());
21352123
}
21362124

21372125
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2187,7 +2175,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
21872175
const int64_t nb12 = src1->nb[2];
21882176

21892177
CUDA_CHECK(cudaSetDevice(g_main_device));
2190-
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
2178+
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
21912179

21922180
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
21932181
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2205,8 +2193,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
22052193
GGML_ASSERT(false);
22062194
}
22072195

2208-
CUDA_CHECK(cudaDeviceSynchronize());
2209-
22102196
(void) dst;
22112197
}
22122198

0 commit comments

Comments
 (0)