@@ -1883,6 +1883,8 @@ struct ggml_state {
1883
1883
// Global library state; mutated under the g_state_barrier spin lock
// (see ggml_critical_section_start below). Contents of struct ggml_state
// are defined above this view — TODO confirm fields when editing.
static struct ggml_state g_state;
1884
1884
// Spin-lock counter for the critical section protecting g_state:
// ggml_critical_section_start() does atomic_fetch_add on it and spins
// until it observes that it is the sole holder.
static atomic_int g_state_barrier = 0;
1885
1885
1886
+ static atomic_int g_blas_pending = 0; // threads still dequantizing src0 before the BLAS sgemm: set to params->nth in GGML_TASK_INIT, decremented per thread, busy-waited to 0 (NOTE(review): bare spin, sched_yield is commented out)
1887
+
1886
1888
// barrier via spin lock
1887
1889
inline static void ggml_critical_section_start(void) {
1888
1890
int processing = atomic_fetch_add(&g_state_barrier, 1);
@@ -9835,21 +9837,53 @@ static void ggml_compute_forward_mul_mat(
9835
9837
9836
9838
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9837
9839
if (ggml_compute_forward_mul_mat_use_blas(dst)) {
9838
- if (params->ith != 0) {
9839
- return ;
9840
- }
9840
+ const int64_t ne_plane = ne01*ne00;
9841
+ const int64_t desired_wsize = ne13*ne12*ne_plane*sizeof(float) ;
9842
+ UNUSED(desired_wsize);
9841
9843
9842
9844
if (params->type == GGML_TASK_INIT) {
9845
+ if (type != GGML_TYPE_F32) {
9846
+ assert(params->wsize >= desired_wsize);
9847
+ atomic_store(&g_blas_pending, params->nth);
9848
+ }
9843
9849
return;
9844
9850
}
9845
9851
9846
9852
if (params->type == GGML_TASK_FINALIZE) {
9847
9853
return;
9848
9854
}
9849
9855
9856
+ if (type != GGML_TYPE_F32) {
9857
+ // parallelize by src0 rows
9858
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
9859
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
9860
+ // broadcast src0 into src1 across 2nd,3rd dimension
9861
+ const int64_t i03 = i13/r3;
9862
+ const int64_t i02 = i12/r2;
9863
+
9864
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9865
+ float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
9866
+ ggml_to_float_t const to_float = type_traits[type].to_float;
9867
+
9868
+ for (int64_t i01 = ith; i01 < ne01; i01+=nth) {
9869
+ to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
9870
+ }
9871
+ }
9872
+ }
9873
+ atomic_fetch_sub(&g_blas_pending, 1);
9874
+ while (atomic_load(&g_blas_pending) != 0) {
9875
+ // sched_yield();
9876
+ }
9877
+ }
9878
+
9879
+ // perform sgemm, parallelization controlled by blas lib
9880
+ if (ith != 0) {
9881
+ return;
9882
+ }
9883
+
9884
+ const int64_t tgemm0 = ggml_perf_time_us();
9850
9885
for (int64_t i13 = 0; i13 < ne13; i13++) {
9851
9886
for (int64_t i12 = 0; i12 < ne12; i12++) {
9852
- // broadcast src0 into src1 across 2nd,3rd dimension
9853
9887
const int64_t i03 = i13/r3;
9854
9888
const int64_t i02 = i12/r2;
9855
9889
@@ -9858,17 +9892,7 @@ static void ggml_compute_forward_mul_mat(
9858
9892
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
9859
9893
9860
9894
if (type != GGML_TYPE_F32) {
9861
- float * const wdata = params->wdata;
9862
- ggml_to_float_t const to_float = type_traits[type].to_float;
9863
-
9864
- size_t id = 0;
9865
- for (int64_t i01 = 0; i01 < ne01; ++i01) {
9866
- to_float((const char *) x + i01*nb01, wdata + id, ne00);
9867
- id += ne00;
9868
- }
9869
-
9870
- assert(id*sizeof(float) <= params->wsize);
9871
- x = wdata;
9895
+ x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
9872
9896
}
9873
9897
9874
9898
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -9878,6 +9902,7 @@ static void ggml_compute_forward_mul_mat(
9878
9902
0.0f, d, ne01);
9879
9903
}
9880
9904
}
9905
+ //printf("cblas_sgemm = %.3f ms, %lld flops\n", (ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2);
9881
9906
9882
9907
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
9883
9908
0 commit comments