@@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
     } \
 } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#define CUDA_CHECK(err)                                                          \
+    do {                                                                         \
+        cudaError_t err_ = (err);                                                \
+        if (err_ != cudaSuccess) {                                               \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,     \
+                cudaGetErrorString(err_));                                       \
+            exit(1);                                                             \
+        }                                                                        \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                        \
+    do {                                                                         \
+        cublasStatus_t err_ = (err);                                             \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                     \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);      \
+            exit(1);                                                             \
+        }                                                                        \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
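
For reference, the two `*_CHECK` macros plus the lazy initializer distill to the self-contained sketch below. The `free_cublas` teardown and the `main` driver are illustrative additions only (the patch never destroys the handle), and the build line assumes a working CUDA toolkit.

```c
// Minimal sketch of the error-check + lazy-init pattern added above.
// Build (assumption): nvcc init_sketch.c -lcublas
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

#define CUDA_CHECK(err)                                                      \
    do {                                                                     \
        cudaError_t err_ = (err);                                            \
        if (err_ != cudaSuccess) {                                           \
            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                cudaGetErrorString(err_));                                   \
            exit(1);                                                         \
        }                                                                    \
    } while (0)

#define CUBLAS_CHECK(err)                                                    \
    do {                                                                     \
        cublasStatus_t err_ = (err);                                         \
        if (err_ != CUBLAS_STATUS_SUCCESS) {                                 \
            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
            exit(1);                                                         \
        }                                                                    \
    } while (0)

static cublasHandle_t cublasH    = NULL;
static cudaStream_t   cudaStream = NULL;

// create the handle and a non-blocking stream exactly once, on first use
static void init_cublas(void) {
    if (cublasH == NULL) {
        CUBLAS_CHECK(cublasCreate(&cublasH));
        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
    }
}

// teardown is NOT part of the patch; included only to keep the sketch complete
static void free_cublas(void) {
    if (cublasH != NULL) {
        CUBLAS_CHECK(cublasDestroy(cublasH));
        CUDA_CHECK(cudaStreamDestroy(cudaStream));
        cublasH    = NULL;
        cudaStream = NULL;
    }
}

int main(void) {
    init_cublas();
    printf("cuBLAS handle ready\n");
    free_cublas();
    return 0;
}
```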
@@ -3836,6 +3872,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        init_cublas();
+        #endif
+
         is_first_call = false;
     }
 
@@ -7567,7 +7608,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
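
The body of `ggml_compute_forward_mul_mat_use_blas` is outside this hunk; conceptually it only offloads when the operands are contiguous and large enough that the BLAS (or, here, cuBLAS plus host/device copies) overhead is amortized. A hedged sketch of that kind of gate is shown below; the conditions and the threshold of 32 are assumptions for illustration, not necessarily the exact values used in ggml.

```c
// Sketch of a size-gated BLAS/cuBLAS dispatch check (illustrative only; the
// real function in ggml.c may use different conditions and thresholds).
#include <stdbool.h>
#include <stdint.h>

static bool mul_mat_use_blas_sketch(int64_t ne0, int64_t ne1, int64_t ne10,
                                    bool src0_contiguous, bool src1_contiguous) {
    // for small matrices the per-call overhead (and, with cuBLAS, the
    // host<->device transfers) dominates, so stay on the hand-written path
    const int64_t min_dim = 32; // assumed threshold
    return src0_contiguous && src1_contiguous &&
           ne0 >= min_dim && ne1 >= min_dim && ne10 >= min_dim;
}
```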
@@ -7607,7 +7648,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -7664,7 +7705,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7678,22 +7719,59 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float * d_X = NULL;
+        float * d_Y = NULL;
+        float * d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
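
cuBLAS expects column-major operands while ggml tensors are row-major, so the call above never transposes anything explicitly: a row-major `ne01 x ne00` buffer read column-major with leading dimension `ne00` is already Xᵀ, and requesting `op(X)=T`, `op(Y)=N` with `m=ne01, n=ne11, k=ne10` produces a column-major result that, read back row-major, equals the `D = Y·Xᵀ` computed by the CPU `cblas_sgemm` branch. A small self-contained check of that equivalence follows; the dimensions, fill values, and tolerance are arbitrary choices for the sketch, and error handling is kept minimal.

```c
// Verifies that cublasSgemm(OP_T, OP_N, ne01, ne11, ne10, ...) on row-major
// buffers reproduces the row-major D = Y * X^T used by the CPU BLAS path.
// Build (assumption): nvcc gemm_check.c -lcublas
#include <math.h>
#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main(void) {
    const int ne01 = 2, ne10 = 3, ne11 = 4; // X: ne01 x ne10, Y: ne11 x ne10 (row-major)
    float X[2*3], Y[4*3], D_ref[4*2], D_gpu[4*2];

    for (int i = 0; i < ne01*ne10; ++i) X[i] = (float)(i + 1);
    for (int i = 0; i < ne11*ne10; ++i) Y[i] = 0.5f * (float)i;

    // CPU reference: D = Y * X^T, row-major ne11 x ne01
    for (int r = 0; r < ne11; ++r) {
        for (int c = 0; c < ne01; ++c) {
            float acc = 0.0f;
            for (int k = 0; k < ne10; ++k) acc += Y[r*ne10 + k] * X[c*ne10 + k];
            D_ref[r*ne01 + c] = acc;
        }
    }

    cublasHandle_t handle;
    if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) { printf("cublasCreate failed\n"); return 1; }

    float *d_X, *d_Y, *d_D;
    cudaMalloc((void **)&d_X, sizeof(X));
    cudaMalloc((void **)&d_Y, sizeof(Y));
    cudaMalloc((void **)&d_D, sizeof(D_gpu));
    cudaMemcpy(d_X, X, sizeof(X), cudaMemcpyHostToDevice);
    cudaMemcpy(d_Y, Y, sizeof(Y), cudaMemcpyHostToDevice);

    const float alpha = 1.0f, beta = 0.0f;
    // same argument pattern as the patch (here ne00 == ne10)
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                &alpha, d_X, ne10,
                        d_Y, ne10,
                &beta,  d_D, ne01);

    cudaMemcpy(D_gpu, d_D, sizeof(D_gpu), cudaMemcpyDeviceToHost);

    float max_err = 0.0f;
    for (int i = 0; i < ne11*ne01; ++i) max_err = fmaxf(max_err, fabsf(D_gpu[i] - D_ref[i]));
    printf("max abs difference vs CPU reference: %g\n", max_err);

    cudaFree(d_X); cudaFree(d_Y); cudaFree(d_D);
    cublasDestroy(handle);
    return max_err < 1e-4f ? 0 : 1;
}
```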
@@ -7823,7 +7901,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -7839,10 +7917,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
 
+        float * d_X = NULL;
+        float * d_Y = NULL;
+        float * d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7851,7 +7956,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -7863,9 +7993,15 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
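
The f16 path keeps `src0` in fp16, converts `src1` rows to fp16 in `wdata`, and asks `cublasGemmEx` to accumulate in fp32 (`CUBLAS_COMPUTE_32F`) while writing an fp32 result. Stripped of the ggml plumbing, the call reduces to the wrapper below; `gemm_f16_f32` and its parameter layout are illustrative, not ggml API, and the device buffers are assumed to already hold the fp16 operands (in the patch, `d_Y` is filled from `wdata` after the `GGML_FP32_TO_FP16` loop).

```c
// Shape of the mixed-precision GEMM used by the f16 x f32 path:
// fp16 inputs, fp32 accumulation and fp32 output, same column-major
// transpose trick as the fp32 path (op(X)=T, op(Y)=N).
#include <cuda_runtime.h>
#include <cublas_v2.h>

// returns the cuBLAS status so the caller can route it through CUBLAS_CHECK
static cublasStatus_t gemm_f16_f32(cublasHandle_t handle,
                                   const void * d_X,  // fp16, ne01 x ne00 row-major
                                   const void * d_Y,  // fp16, ne11 x ne10 row-major
                                   float      * d_D,  // fp32, ne11 x ne01 row-major result
                                   int ne00, int ne01, int ne10, int ne11) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                        ne01, ne11, ne10,
                        &alpha, d_X, CUDA_R_16F, ne00,
                                d_Y, CUDA_R_16F, ne10,
                        &beta,  d_D, CUDA_R_32F, ne01,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT);
}
```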
@@ -8017,7 +8153,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8034,6 +8170,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
+#if defined(GGML_USE_CUBLAS)
+        float * d_X = NULL;
+        float * d_Y = NULL;
+        float * d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
@@ -8049,15 +8200,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
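
The quantized path reuses the same fp32 `cublasSgemm` as above: each `(i02, i03)` slice of `src0` is first expanded row by row into the `wdata` scratch buffer as fp32, and that buffer then plays the role of `x`. A sketch of that per-slice flow is below; the dequantizer signature mirrors the `dequantize_row_q_t` typedef referenced in this hunk but is an assumption here, since the typedef itself sits outside the diff.

```c
// Illustrative per-slice flow of the quantized mul_mat path: dequantize src0
// rows to an fp32 scratch buffer, then feed it to the same Sgemm as the fp32 path.
#include <stddef.h>
#include <stdint.h>

// assumed shape of ggml's dequantize_row_q_t: one quantized row in, k floats out
typedef void (*dequantize_row_fn)(const void * x, float * y, int k);

static void dequantize_slice_for_gemm(const char * src0_slice, // start of one (i02,i03) slice
                                      size_t nb01,             // byte stride between src0 rows
                                      int64_t ne00, int64_t ne01,
                                      float * wdata,           // scratch: ne01 * ne00 floats
                                      dequantize_row_fn dequantize_row) {
    size_t id = 0;
    for (int64_t i01 = 0; i01 < ne01; ++i01) {
        // expand one quantized row into ne00 fp32 values
        dequantize_row(src0_slice + i01*nb01, wdata + id, ne00);
        id += ne00;
    }
    // wdata now holds the slice as a dense row-major ne01 x ne00 fp32 matrix,
    // ready to be copied to d_X and multiplied exactly like the fp32 path above.
}
```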
@@ -10874,7 +11048,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 size_t cur = 0;
 
                 if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                         node->n_tasks = 1; // TODO: this actually is doing nothing
                                            //       the threads are still spinning
@@ -10891,7 +11065,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                     cur = 0;
                 } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                         node->n_tasks = 1;
                         cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -12231,7 +12405,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;
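
The new probe mirrors the existing `ggml_cpu_has_*` family. Assuming the matching `ggml_cpu_has_cublas()` prototype is exported from ggml.h alongside this patch (the header change is not shown in this diff), application code can report the build configuration like so:

```c
// Reporting the build configuration from application code; assumes the
// ggml_cpu_has_cublas() prototype is declared in ggml.h as part of this patch.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    printf("BLAS backend available:  %s\n", ggml_cpu_has_blas()   ? "yes" : "no");
    printf("cuBLAS backend compiled: %s\n", ggml_cpu_has_cublas() ? "yes" : "no");
    return 0;
}
```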