Skip to content

Commit 2c0ed7a

Browse files
committed
multithreaded dequantize in mul_mat when using blas library
1 parent 20fefdf commit 2c0ed7a

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

ggml.c

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9828,21 +9828,45 @@ static void ggml_compute_forward_mul_mat(
98289828

98299829
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
98309830
if (ggml_compute_forward_mul_mat_use_blas(dst)) {
9831-
if (params->ith != 0) {
9832-
return;
9833-
}
9831+
const int64_t ne_plane = ne01*ne00;
9832+
const int64_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
9833+
UNUSED(desired_wsize);
98349834

98359835
if (params->type == GGML_TASK_INIT) {
9836+
if (type != GGML_TYPE_F32) {
9837+
assert(params->wsize >= desired_wsize);
9838+
// parallelize by src0 rows
9839+
for (int64_t i13 = 0; i13 < ne13; i13++) {
9840+
for (int64_t i12 = 0; i12 < ne12; i12++) {
9841+
// broadcast src0 into src1 across 2nd,3rd dimension
9842+
const int64_t i03 = i13/r3;
9843+
const int64_t i02 = i12/r2;
9844+
9845+
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9846+
float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
9847+
ggml_to_float_t const to_float = type_traits[type].to_float;
9848+
9849+
for (int64_t i01 = ith; i01 < ne01; i01+=nth) {
9850+
to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
9851+
}
9852+
}
9853+
}
9854+
}
98369855
return;
98379856
}
98389857

98399858
if (params->type == GGML_TASK_FINALIZE) {
98409859
return;
98419860
}
98429861

9862+
// perform sgemm, parallelization controlled by blas lib
9863+
if (ith != 0) {
9864+
return;
9865+
}
9866+
9867+
const int64_t tgemm0 = ggml_perf_time_us();
98439868
for (int64_t i13 = 0; i13 < ne13; i13++) {
98449869
for (int64_t i12 = 0; i12 < ne12; i12++) {
9845-
// broadcast src0 into src1 across 2nd,3rd dimension
98469870
const int64_t i03 = i13/r3;
98479871
const int64_t i02 = i12/r2;
98489872

@@ -9851,17 +9875,7 @@ static void ggml_compute_forward_mul_mat(
98519875
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
98529876

98539877
if (type != GGML_TYPE_F32) {
9854-
float * const wdata = params->wdata;
9855-
ggml_to_float_t const to_float = type_traits[type].to_float;
9856-
9857-
size_t id = 0;
9858-
for (int64_t i01 = 0; i01 < ne01; ++i01) {
9859-
to_float((const char *) x + i01*nb01, wdata + id, ne00);
9860-
id += ne00;
9861-
}
9862-
9863-
assert(id*sizeof(float) <= params->wsize);
9864-
x = wdata;
9878+
x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
98659879
}
98669880

98679881
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -9871,6 +9885,7 @@ static void ggml_compute_forward_mul_mat(
98719885
0.0f, d, ne01);
98729886
}
98739887
}
9888+
//printf("cblas_sgemm = %.3f ms, %lld flops\n", (ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2);
98749889

98759890
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
98769891

@@ -16782,7 +16797,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
1678216797
if (ggml_compute_forward_mul_mat_use_blas(node)) {
1678316798
if (node->src[0]->type != GGML_TYPE_F32) {
1678416799
// here we need memory just for single 2D matrix from src0
16785-
cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
16800+
cur = ggml_type_size(GGML_TYPE_F32)
16801+
* node->src[0]->ne[0]*node->src[0]->ne[1]
16802+
* node->src[1]->ne[2]*node->src[1]->ne[3];
1678616803
}
1678716804
} else
1678816805
#endif

0 commit comments

Comments
 (0)