Commit 9ae9b86

CUDA: enable peer access between devices
1 parent: 70ffed6

File tree — 6 files changed: 66 additions, 10 deletions

- CMakeLists.txt
- Makefile
- README.md
- ggml-cuda.cu
- ggml-cuda.h
- llama.cpp

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,8 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern
 set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+                                   "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
@@ -304,6 +306,7 @@ if (LLAMA_CUBLAS)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
     add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})

     if (LLAMA_STATIC)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
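A configure-time example (the value 256 is only illustrative): with the cuBLAS build enabled, the new cache variable can be overridden on the command line, e.g. `cmake .. -DLLAMA_CUBLAS=ON -DLLAMA_CUDA_PEER_MAX_BATCH_SIZE=256`.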

Makefile

Lines changed: 5 additions & 0 deletions
@@ -368,6 +368,11 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
 else
 	NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
 endif
+ifdef LLAMA_CUDA_PEER_MAX_BATCH_SIZE
+	NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(LLAMA_CUDA_PEER_MAX_BATCH_SIZE)
+else
+	NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128
+endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE
 #ifdef LLAMA_CUDA_CUBLAS
 #	NVCCFLAGS += -DGGML_CUDA_CUBLAS
 #endif # LLAMA_CUDA_CUBLAS
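For the Makefile build the same override can be passed as a make variable, e.g. `make LLAMA_CUBLAS=1 LLAMA_CUDA_PEER_MAX_BATCH_SIZE=256` (the `LLAMA_CUBLAS=1` flag is assumed from the surrounding Makefile, and 256 is only an example); if the variable is left unset, the `else` branch above falls back to the default of 128.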

README.md

Lines changed: 8 additions & 7 deletions
@@ -391,13 +391,14 @@ Building the program with BLAS support may lead to some performance improvements
 <!---
 | LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
 --->
-| Option | Legal values | Default | Description |
-|-------------------------|------------------------|---------|-------------|
-| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
-| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
-| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
-| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
-| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
+| Option | Legal values | Default | Description |
+|--------------------------------|------------------------|---------|-------------|
+| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
+| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
+| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. |
+| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
+| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
+| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
 
 - #### hipBLAS

ggml-cuda.cu

Lines changed: 45 additions & 3 deletions
@@ -31,6 +31,9 @@
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -424,6 +427,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif

+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
 #define MUL_MAT_SRC1_COL_STRIDE 128

 #define MAX_STREAMS 8
@@ -7012,7 +7019,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }

-void ggml_cuda_set_main_device(int main_device) {
+void ggml_cuda_set_main_device(const int main_device) {
     if (main_device >= g_device_count) {
         fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
                 main_device, g_device_count, g_main_device);
@@ -7026,14 +7033,49 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }

-void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
     g_mul_mat_q = mul_mat_q;
 }

-void ggml_cuda_set_scratch_size(size_t scratch_size) {
+void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     g_scratch_size = scratch_size;
 }

+void ggml_cuda_set_peer_access(const int n_tokens) {
+    static int last_n_tokens = INT_MAX;
+
+    if (n_tokens == last_n_tokens) {
+        return;
+    }
+
+#ifdef NDEBUG
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(cudaSetDevice(id));
+
+        for (int id_other = 0; id_other < g_device_count; ++id_other) {
+            if (id == id_other) {
+                continue;
+            }
+            if (id != g_main_device && id_other != g_main_device) {
+                continue;
+            }
+
+            int canAccessPeer;
+            CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, id, id_other));
+            if (canAccessPeer) {
+                if (n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE && last_n_tokens > GGML_CUDA_PEER_MAX_BATCH_SIZE) {
+                    CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
+                } else if (n_tokens > GGML_CUDA_PEER_MAX_BATCH_SIZE && last_n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE) {
+                    CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
+                }
+            }
+        }
+    }
+#endif // NDEBUG
+
+    last_n_tokens = n_tokens;
+}
+
 void ggml_cuda_free_scratch() {
     if (g_scratch_buffer == nullptr) {
         return;
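For context, the new `ggml_cuda_set_peer_access` only toggles peer access for device pairs that include the main device, and only when the batch size crosses `GGML_CUDA_PEER_MAX_BATCH_SIZE`; the `#ifdef NDEBUG` guard means debug builds never touch peer access, and the cached `last_n_tokens` avoids redundant CUDA calls. Below is a minimal standalone sketch of the same enable pattern using only the plain CUDA runtime API (device 0 assumed to stand in for `g_main_device`, no error-check macro, no batch-size hysteresis); it is an illustration, not the llama.cpp code itself.

```cpp
// Standalone sketch of the peer-access pattern above (not llama.cpp code).
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device_count = 0;
    cudaGetDeviceCount(&device_count);

    const int main_device = 0; // assumption: device 0 plays the role of g_main_device

    for (int id = 0; id < device_count; ++id) {
        cudaSetDevice(id);

        for (int id_other = 0; id_other < device_count; ++id_other) {
            if (id == id_other) continue;                               // skip self pairs
            if (id != main_device && id_other != main_device) continue; // only pairs involving the main device

            int can_access_peer = 0;
            cudaDeviceCanAccessPeer(&can_access_peer, id, id_other);
            if (can_access_peer) {
                // The flags argument must be 0; enabling twice returns
                // cudaErrorPeerAccessAlreadyEnabled rather than failing hard.
                const cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
                printf("peer access %d -> %d: %s\n", id, id_other, cudaGetErrorString(err));
            }
        }
    }
    return 0;
}
```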

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, siz
 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
+GGML_API void ggml_cuda_set_peer_access(int n_tokens);
 GGML_API void ggml_cuda_free_scratch(void);
 GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);

llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -3718,6 +3718,10 @@ static bool llama_eval_internal(

     const int64_t t_start_us = ggml_time_us();

+#ifdef GGML_USE_CUBLAS
+    ggml_cuda_set_peer_access(n_tokens);
+#endif // GGML_USE_CUBLAS
+
 #ifdef GGML_USE_MPI
     ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
 #endif
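With this hook in place, `llama_eval_internal` adjusts peer access once per evaluation based on the current token batch size; because `ggml_cuda_set_peer_access` returns early when `n_tokens` equals the cached `last_n_tokens`, repeated evaluations with an unchanged batch size incur no extra CUDA calls.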
