@@ -9828,21 +9828,49 @@ static void ggml_compute_forward_mul_mat(

 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(dst)) {
-        if (params->ith != 0) {
-            return;
-        }
+        const int64_t ne_plane      = ne01*ne00;
+        const int64_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
+        UNUSED(desired_wsize);
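+        // desired_wsize is read only by the assert below; UNUSED() keeps release (-DNDEBUG) builds warning-free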

         if (params->type == GGML_TASK_INIT) {
+            if (type != GGML_TYPE_F32) {
+                assert(params->wsize >= desired_wsize);
+                // parallelize by src0 rows
+                for (int64_t i13 = 0; i13 < ne13; i13++) {
+                    for (int64_t i12 = 0; i12 < ne12; i12++) {
+                        // broadcast src0 into src1 across 2nd,3rd dimension
+                        const int64_t i03 = i13/r3;
+                        const int64_t i02 = i12/r2;
+
+                        const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                        float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
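+                        // each (i12, i13) pair of src1 gets its own ne01*ne00 plane in wdata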
+                        ggml_to_float_t const to_float = type_traits[type].to_float;
+
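+                        // rows are strided across threads: thread ith converts rows ith, ith+nth, ith+2*nth, ...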
+                        for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                            to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
+                        }
+                    }
+                }
+            }
             return;
         }

         if (params->type == GGML_TASK_FINALIZE) {
             return;
         }

+        // perform sgemm, parallelization controlled by blas lib
+        if (ith != 0) {
+            return;
+        }
+
+        const int64_t tgemm0 = ggml_perf_time_us();
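+        // tgemm0 feeds the (commented-out) sgemm timing printf below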
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
-                // broadcast src0 into src1 across 2nd,3rd dimension
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;

@@ -9851,17 +9879,8 @@ static void ggml_compute_forward_mul_mat(
                 float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);

                 if (type != GGML_TYPE_F32) {
-                    float * const wdata = params->wdata;
-                    ggml_to_float_t const to_float = type_traits[type].to_float;
-
-                    size_t id = 0;
-                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        to_float((const char *) x + i01*nb01, wdata + id, ne00);
-                        id += ne00;
-                    }
-
-                    assert(id*sizeof(float) <= params->wsize);
-                    x = wdata;
+                    x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
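+                    // reuse the plane dequantized during GGML_TASK_INIT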
                 }

                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -9871,6 +9890,7 @@ static void ggml_compute_forward_mul_mat(
                         0.0f, d, ne01);
             }
         }
+        //printf("cblas_sgemm = %.3f ms, %lld flops\n", (ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2);

         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

@@ -16782,7 +16802,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
                 if (ggml_compute_forward_mul_mat_use_blas(node)) {
                     if (node->src[0]->type != GGML_TYPE_F32) {
-                        // here we need memory just for single 2D matrix from src0
-                        cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
+                        // here we need memory for the full dequantized src0, broadcast across src1's dims 2 and 3
+                        cur = ggml_type_size(GGML_TYPE_F32)
+                            * node->src[0]->ne[0]*node->src[0]->ne[1]
+                            * node->src[1]->ne[2]*node->src[1]->ne[3];
                     }
                 } else
 #endif