From 4440d198c0aecc2fc4a91a7730917d2689ba0ca8 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Tue, 18 Apr 2023 18:16:27 +0200
Subject: [PATCH 1/3] Add NVIDIA cuBLAS support

---
 Makefile  |   4 ++
 ggml.c    | 209 ++++++++++++++++++++++++++++++++++++++++++++++++++----
 ggml.h    |   1 +
 llama.cpp |   2 +-
 4 files changed, 203 insertions(+), 13 deletions(-)

diff --git a/Makefile b/Makefile
index e7470d51a22..e0f2578dd13 100644
--- a/Makefile
+++ b/Makefile
@@ -97,6 +97,10 @@ ifdef LLAMA_OPENBLAS
     CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
     LDFLAGS += -lopenblas
 endif
+ifdef LLAMA_CUBLAS
+    CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
+    LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
+endif
 ifdef LLAMA_GPROF
     CFLAGS   += -pg
     CXXFLAGS += -pg
diff --git a/ggml.c b/ggml.c
index acdba0333c8..c53c9c4f187 100644
--- a/ggml.c
+++ b/ggml.c
@@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
         } \
     } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#define CUDA_CHECK(err) \
+    do { \
+        cudaError_t err_ = (err); \
+        if (err_ != cudaSuccess) { \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                cudaGetErrorString(err_)); \
+            exit(1); \
+        } \
+    } while (0)
+
+#define CUBLAS_CHECK(err) \
+    do { \
+        cublasStatus_t err_ = (err); \
+        if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            exit(1); \
+        } \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        /* step 1: create cublas handle, bind a stream */
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        //CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
@@ -3605,6 +3641,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
+    // initialize cuBLAS
+    #if defined(GGML_USE_CUBLAS)
+    init_cublas();
+    #endif
+
     is_first_call = false;
 }
 
@@ -7161,7 +7202,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7201,7 +7242,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -7258,7 +7299,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7272,6 +7313,21 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7279,15 +7335,38 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                /* step 2: copy data to device */
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                /* step 3: compute */
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                /* step 4: copy data to host */
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        /* free resources */
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -7417,7 +7496,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -7433,10 +7512,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+            // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7445,7 +7551,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                /* step 2: copy data to device */
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                /* step 3: compute */
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                /* step 4: copy data to host */
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -7457,9 +7588,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        /* free resources */
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
@@ -7611,7 +7749,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7628,6 +7766,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
@@ -7643,15 +7796,39 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                /* step 2: copy data to device */
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                /* step 3: compute */
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                /* step 4: copy data to host */
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        /* free resources */
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -10466,7 +10643,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             size_t cur = 0;
 
                             if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                     node->n_tasks = 1; // TODO: this actually is doing nothing
                                                        //       the threads are still spinning
@@ -10483,7 +10660,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                                 cur = 0;
                             } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                     node->n_tasks = 1;
                                     cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -11800,7 +11977,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;
diff --git a/ggml.h b/ggml.h
index 59de0cb1257..f436fbd56d9 100644
--- a/ggml.h
+++ b/ggml.h
@@ -823,6 +823,7 @@ int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
+int ggml_cpu_has_cublas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
 
diff --git a/llama.cpp b/llama.cpp
index db71c03bdd4..4a5b6b068f8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1066,7 +1066,7 @@ static bool llama_eval_internal(
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     ggml_cgraph gf = {};
-    gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
+    gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));

From 5fc6799f05926808f64db44c2e7a7c8a5d798ab6 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Tue, 18 Apr 2023 23:20:11 +0200
Subject: [PATCH 2/3] Add support to cmake

---
 CMakeLists.txt | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9189b6fc20b..261dd71ef08 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
+cmake_minimum_required(VERSION 3.17) # Don't bump this version for no reason
 project("llama.cpp" C CXX)
 
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -66,6 +66,7 @@ endif()
 
 # 3rd party libs
 option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
 option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
+option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
 
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -142,6 +143,24 @@ if (LLAMA_OPENBLAS)
     endif()
 endif()
 
+if (LLAMA_CUBLAS)
+    find_package(CUDAToolkit)
+    if (CUDAToolkit_FOUND)
+        message(STATUS "cuBLAS found")
+
+        add_compile_definitions(GGML_USE_CUBLAS)
+
+        if (LLAMA_STATIC)
+            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        else()
+            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        endif()
+
+    else()
+        message(WARNING "cuBLAS not found")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags

From 40846bd28d7909c7de96709fe3563c75b4bcc0a3 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Wed, 19 Apr 2023 00:37:33 +0200
Subject: [PATCH 3/3] Cleanup cublas comments

---
 CMakeLists.txt |  4 +++-
 ggml.c         | 27 ++++++++++++---------------
 2 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 261dd71ef08..1641a96157a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.17) # Don't bump this version for no reason
+cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
 project("llama.cpp" C CXX)
 
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -144,6 +144,8 @@ if (LLAMA_OPENBLAS)
 endif()
 
 if (LLAMA_CUBLAS)
+    cmake_minimum_required(VERSION 3.17)
+
     find_package(CUDAToolkit)
     if (CUDAToolkit_FOUND)
         message(STATUS "cuBLAS found")
diff --git a/ggml.c b/ggml.c
index c53c9c4f187..6cd7619dd46 100644
--- a/ggml.c
+++ b/ggml.c
@@ -172,14 +172,14 @@ static cublasHandle_t cublasH = NULL;
 static cudaStream_t cudaStream = NULL;
 static void init_cublas(void) {
     if (cublasH == NULL) {
-        /* step 1: create cublas handle, bind a stream */
+        // create cublas handle, bind a stream
         CUBLAS_CHECK(cublasCreate(&cublasH));
 
         CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
         CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
 
         // configure logging to stdout
-        //CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
 #endif
@@ -7336,11 +7336,11 @@ static void ggml_compute_forward_mul_mat_f32(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
 #if defined(GGML_USE_CUBLAS)
-                /* step 2: copy data to device */
+                // copy data to device
                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
 
-                /* step 3: compute */
+                // compute
                 CUBLAS_CHECK(
                     cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
@@ -7348,7 +7348,7 @@ static void ggml_compute_forward_mul_mat_f32(
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
-                /* step 4: copy data to host */
+                // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
                 CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
@@ -7362,7 +7362,6 @@ static void ggml_compute_forward_mul_mat_f32(
             }
         }
 #if defined(GGML_USE_CUBLAS)
-        /* free resources */
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
@@ -7533,7 +7532,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
-            // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne11; ++i01) {
@@ -7559,11 +7558,11 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                /* step 2: copy data to device */
+                // copy data to device
                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
 
-                /* step 3: compute */
+                // compute
                 CUBLAS_CHECK(
                     cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
@@ -7573,7 +7572,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                             CUBLAS_COMPUTE_32F,
                             CUBLAS_GEMM_DEFAULT));
 
-                /* step 4: copy data to host */
+                // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
                 CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
@@ -7593,7 +7592,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        /* free resources */
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
@@ -7797,11 +7795,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
 #if defined(GGML_USE_CUBLAS)
-                /* step 2: copy data to device */
+                // copy data to device
                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
 
-                /* step 3: compute */
+                // compute
                 CUBLAS_CHECK(
                     cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
@@ -7809,7 +7807,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
-                /* step 4: copy data to host */
+                // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
                 CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
@@ -7824,7 +7822,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        /* free resources */
        CUDA_CHECK(cudaFree(d_X));
        CUDA_CHECK(cudaFree(d_Y));
        CUDA_CHECK(cudaFree(d_D));