@@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
 } \
 } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#define CUDA_CHECK(err)                                                         \
+    do {                                                                        \
+        cudaError_t err_ = (err);                                               \
+        if (err_ != cudaSuccess) {                                              \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,    \
+                cudaGetErrorString(err_));                                      \
+            exit(1);                                                            \
+        }                                                                       \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                       \
+    do {                                                                        \
+        cublasStatus_t err_ = (err);                                            \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                    \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);     \
+            exit(1);                                                            \
+        }                                                                       \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
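The two `*_CHECK` macros follow the usual CUDA error-handling idiom: every runtime and cuBLAS call is wrapped so a non-success status aborts with file/line context instead of being silently ignored, and `init_cublas()` lazily creates one global handle bound to a non-blocking stream. A minimal usage sketch (the helper below is hypothetical, not part of the patch):

```c
// Hypothetical caller of the macros and the lazy global handle above;
// example_device_buffer() does not exist in ggml.
static float * example_device_buffer(size_t n_floats) {
    float * dptr = NULL;

    init_cublas(); // idempotent: creates the handle and stream on first call only

    CUDA_CHECK(cudaMalloc((void **) &dptr, n_floats * sizeof(float)));
    CUDA_CHECK(cudaMemsetAsync(dptr, 0, n_floats * sizeof(float), cudaStream));
    CUDA_CHECK(cudaStreamSynchronize(cudaStream));

    return dptr; // caller releases with cudaFree()
}
```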
@@ -3836,6 +3872,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        init_cublas();
+        #endif
+
         is_first_call = false;
     }
 
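Hooking `init_cublas()` into `ggml_init()` means the handle is created exactly once, on the `is_first_call` path; callers never initialize cuBLAS themselves. A hedged usage sketch (field values illustrative, struct layout assumed from ggml of this vintage):

```c
// With GGML_USE_CUBLAS defined, the first ggml_init() also sets up cuBLAS.
struct ggml_init_params params = {
    /*.mem_size   =*/ 16*1024*1024, // illustrative size
    /*.mem_buffer =*/ NULL,
};
struct ggml_context * ctx = ggml_init(params); // init_cublas() runs here once
```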
@@ -7567,7 +7608,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
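The helper's body is unchanged by this patch; for context, a sketch of the check it performs, reconstructed from the surrounding ggml code (the contiguity test and the threshold of 32 are assumptions here, not something this diff shows):

```c
// Sketch (assumed, not part of the diff): BLAS only pays off when the
// operands are contiguous and all three GEMM dimensions are reasonably
// large, since dispatch and (for cuBLAS) host<->device copies have fixed cost.
static bool ggml_compute_forward_mul_mat_use_blas(
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
              struct ggml_tensor * dst) {
    const int64_t ne10 = src1->ne[0];
    const int64_t ne0  = dst->ne[0];
    const int64_t ne1  = dst->ne[1];

    return ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
           ne0 >= 32 && ne1 >= 32 && ne10 >= 32;
}
```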
@@ -7607,7 +7648,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -7664,7 +7705,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7678,22 +7719,59 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float * d_X = NULL;
+        float * d_Y = NULL;
+        float * d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
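Why `CUBLAS_OP_T`/`CUBLAS_OP_N` here reproduces the row-major `cblas_sgemm` call in the `#else` branch: cuBLAS assumes column-major storage, and reading a row-major buffer as column-major transposes it for free. Writing $x$ (ne01 × ne00), $y$ (ne11 × ne10) and $d$ (ne11 × ne01) for the row-major views, cuBLAS sees $X_{\text{col}} = x^{\mathsf T}$ and $Y_{\text{col}} = y^{\mathsf T}$, so

$$
D_{\text{col}} = X_{\text{col}}^{\mathsf T}\, Y_{\text{col}} = x\, y^{\mathsf T}
\quad\Longrightarrow\quad
d = D_{\text{col}}^{\mathsf T} = y\, x^{\mathsf T},
$$

which is exactly the `zT = y * xT` product computed on the CPU path, with no explicit transpose kernel needed.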
@@ -7823,7 +7901,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -7839,10 +7917,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
 
+        float * d_X = NULL;
+        float * d_Y = NULL;
+        float * d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
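Note that the two branches convert in opposite directions: the CPU BLAS path needs `src0` promoted to fp32, while the cuBLAS path leaves the fp16 `src0` untouched and instead demotes `src1` into `wdata` as fp16, keeping the large weight matrix at half size on the bus. Rough illustrative arithmetic (hypothetical shapes, not from the patch):

```c
// Illustrative per-GEMM host-to-device traffic for a 4096 x 4096 fp16 src0
// multiplied by 4096 x 512 fp32 activations (hypothetical shapes):
//
//   promote src0 to fp32:  4096*4096*4 B = 64 MiB weights + 8 MiB activations
//   demote  src1 to fp16:  4096*4096*2 B = 32 MiB weights + 4 MiB activations
//
// so converting src1 instead of src0 roughly halves the upload volume and
// lets cublasGemmEx read both operands natively as CUDA_R_16F.
```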
@@ -7851,7 +7956,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -7863,9 +7993,15 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
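The `cublasGemmEx` configuration (fp16 A and B, fp32 C, `CUBLAS_COMPUTE_32F`) accumulates in fp32 even though the inputs are half precision, so the result should match the CPU path up to fp16 rounding of the inputs. A self-contained sanity check of this exact argument layout, hypothetical and not part of the patch (build with `nvcc -o gemm_test gemm_test.cu -lcublas`; `CUBLAS_COMPUTE_32F` assumes CUDA 11+):

```c
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>

int main(void) {
    // m, n, k play the roles of ne01, ne11, ne10 in the patch.
    const int m = 4, n = 3, k = 2;
    float  x[m*k], y[n*k], ref[m*n], out[m*n];
    __half xh[m*k], yh[n*k];

    for (int i = 0; i < m*k; ++i) { x[i] = (float)(i + 1); xh[i] = __float2half(x[i]); }
    for (int i = 0; i < n*k; ++i) { y[i] = (float)(i - 2); yh[i] = __float2half(y[i]); }

    // CPU reference for column-major D = X^T * Y (i.e. row-major d = y * x^T).
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
            float s = 0.0f;
            for (int l = 0; l < k; ++l) s += x[i*k + l] * y[j*k + l];
            ref[i + j*m] = s;
        }

    __half *d_X, *d_Y; float *d_D;
    cudaMalloc((void **) &d_X, sizeof(xh));
    cudaMalloc((void **) &d_Y, sizeof(yh));
    cudaMalloc((void **) &d_D, sizeof(out));
    cudaMemcpy(d_X, xh, sizeof(xh), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Y, yh, sizeof(yh), cudaMemcpyHostToDevice);

    cublasHandle_t h;
    cublasCreate(&h);
    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmEx(h, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                 &alpha, d_X, CUDA_R_16F, k,
                         d_Y, CUDA_R_16F, k,
                 &beta,  d_D, CUDA_R_32F, m,
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    cudaMemcpy(out, d_D, sizeof(out), cudaMemcpyDeviceToHost);

    for (int i = 0; i < m*n; ++i)
        if (fabsf(out[i] - ref[i]) > 1e-2f) { printf("mismatch at %d\n", i); return 1; }
    printf("cublasGemmEx matches the CPU reference\n");

    cublasDestroy(h);
    cudaFree(d_X); cudaFree(d_Y); cudaFree(d_D);
    return 0;
}
```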
@@ -8017,7 +8153,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8034,6 +8170,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
+#if defined(GGML_USE_CUBLAS)
+        float * d_X = NULL;
+        float * d_Y = NULL;
+        float * d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
@@ -8049,15 +8200,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
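The `{ ... }` block elided between the two hunks above is the per-matrix dequantization that fills `wdata`; this patch leaves it untouched, so quantized weights are still expanded to fp32 on the host and uploaded in full for every GEMM. A sketch of that loop, reconstructed from the surrounding ggml code (treat details as assumptions):

```c
// Sketch of the elided host-side dequantization (unchanged by the patch):
// each quantized row of src0 is expanded to ne00 floats in wdata, which the
// BLAS/cuBLAS branch then consumes as a dense fp32 matrix x.
{
    size_t id = 0;
    for (int64_t i01 = 0; i01 < ne01; ++i01) {
        dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01,
                         wdata + id, ne00);
        id += ne00;
    }
}
```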
@@ -10874,7 +11048,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -10891,7 +11065,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
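For the quantized path, the BLAS branch forces a single task and reserves enough `wdata` to hold `src0` dequantized to fp32, independent of the quantization type. A quick worked size for a hypothetical tensor:

```c
// Worked example (hypothetical 4096 x 4096 quantized src0):
//   cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * ne[0] * ne[1]
//       = 4 * 4096 * 4096 bytes = 64 MiB of scratch for this one node.
```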
@@ -12231,7 +12405,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;
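The new `ggml_cpu_has_cublas()` lets callers distinguish a GPU-backed BLAS from a CPU one, while `ggml_cpu_has_blas()` now reports true for all three backends. A hypothetical caller, in the style of a system-info printout:

```c
// Hypothetical capability report (format string is illustrative):
printf("BLAS = %d | cuBLAS = %d\n", ggml_cpu_has_blas(), ggml_cpu_has_cublas());
```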