Commit cb828c8

ggml: parallelize to_float when using blas
* converting fp16 to fp32, or dequantizing, on a single thread can be the bottleneck rather than the gemm itself.
1 parent 862f5e4 commit cb828c8
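
The change below stripes this per-row conversion across all of the compute threads instead of leaving it on thread 0. As a rough illustration of the striping pattern the diff applies, here is a minimal standalone sketch; the names convert_rows_striped and to_float_t are illustrative and not part of the commit, and the byte/float strides stand in for ggml's nb01 and ne00.

#include <stdint.h>
#include <stddef.h>

// Convert one quantized/fp16 row of n elements into fp32
// (signature mirrors ggml_to_float_t).
typedef void (*to_float_t)(const void * src_row, float * dst_row, int64_t n);

// Thread ith of nth converts rows ith, ith+nth, ith+2*nth, ...
// so the conversion cost is split across threads instead of
// being serialized on thread 0.
static void convert_rows_striped(
        const char * src, size_t src_row_stride,   // source rows, nb01-style byte stride
        float      * dst, int64_t n_per_row,       // destination fp32 plane, ne00 floats per row
        int64_t n_rows, to_float_t to_float,
        int ith, int nth) {
    for (int64_t i01 = ith; i01 < n_rows; i01 += nth) {
        to_float(src + i01*src_row_stride, dst + i01*n_per_row, n_per_row);
    }
}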


1 file changed: +40 −15 lines changed

ggml.c

Lines changed: 40 additions & 15 deletions
@@ -1883,6 +1883,8 @@ struct ggml_state {
 static struct ggml_state g_state;
 static atomic_int g_state_barrier = 0;
 
+static atomic_int g_blas_pending = 0;
+
 // barrier via spin lock
 inline static void ggml_critical_section_start(void) {
     int processing = atomic_fetch_add(&g_state_barrier, 1);
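
The new g_blas_pending counter is used as an ad-hoc countdown barrier: the init pass stores params->nth into it, each thread decrements it after converting its share of rows, and every thread spins until it reaches zero so that the converted data is complete before thread 0 runs the GEMM. Below is a self-contained sketch of that pattern under the assumption of C11 atomics plus pthreads; pending, run_worker, do_partial_work and N_THREADS are illustrative names, not code from the commit.

#include <pthread.h>
#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

#define N_THREADS 4

static atomic_int pending = 0;   // plays the role of g_blas_pending

static void do_partial_work(int ith, int nth) {
    // stand-in for each thread's slice of the to_float conversion
    (void) ith; (void) nth;
}

static void * run_worker(void * p) {
    const int ith = (int)(intptr_t) p;

    do_partial_work(ith, N_THREADS);

    // signal completion, then spin until every thread has finished
    atomic_fetch_sub(&pending, 1);
    while (atomic_load(&pending) != 0) {
        // busy-wait (the commit leaves a sched_yield() call commented out here)
    }

    if (ith == 0) {
        // only thread 0 would go on to call cblas_sgemm
        printf("all threads done, thread 0 continues\n");
    }
    return NULL;
}

int main(void) {
    pthread_t th[N_THREADS];
    atomic_store(&pending, N_THREADS);   // like atomic_store(&g_blas_pending, params->nth)
    for (int i = 0; i < N_THREADS; i++) {
        pthread_create(&th[i], NULL, run_worker, (void *)(intptr_t) i);
    }
    for (int i = 0; i < N_THREADS; i++) {
        pthread_join(th[i], NULL);
    }
    return 0;
}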
@@ -9835,21 +9837,53 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(dst)) {
-        if (params->ith != 0) {
-            return;
-        }
+        const int64_t ne_plane      = ne01*ne00;
+        const int64_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
+        UNUSED(desired_wsize);
 
         if (params->type == GGML_TASK_INIT) {
+            if (type != GGML_TYPE_F32) {
+                assert(params->wsize >= desired_wsize);
+                atomic_store(&g_blas_pending, params->nth);
+            }
             return;
         }
 
         if (params->type == GGML_TASK_FINALIZE) {
             return;
         }
 
+        if (type != GGML_TYPE_F32) {
+            // parallelize by src0 rows
+            for (int64_t i13 = 0; i13 < ne13; i13++) {
+                for (int64_t i12 = 0; i12 < ne12; i12++) {
+                    // broadcast src0 into src1 across 2nd,3rd dimension
+                    const int64_t i03 = i13/r3;
+                    const int64_t i02 = i12/r2;
+
+                    const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                    float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
+                    ggml_to_float_t const to_float = type_traits[type].to_float;
+
+                    for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                        to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
+                    }
+                }
+            }
+            atomic_fetch_sub(&g_blas_pending, 1);
+            while (atomic_load(&g_blas_pending) != 0) {
+                // sched_yield();
+            }
+        }
+
+        // perform sgemm, parallelization controlled by blas lib
+        if (ith != 0) {
+            return;
+        }
+
+        const int64_t tgemm0 = ggml_perf_time_us();
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
-                // broadcast src0 into src1 across 2nd,3rd dimension
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;
 
@@ -9858,17 +9892,7 @@ static void ggml_compute_forward_mul_mat(
                 float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
-                    float * const wdata = params->wdata;
-                    ggml_to_float_t const to_float = type_traits[type].to_float;
-
-                    size_t id = 0;
-                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        to_float((const char *) x + i01*nb01, wdata + id, ne00);
-                        id += ne00;
-                    }
-
-                    assert(id*sizeof(float) <= params->wsize);
-                    x = wdata;
+                    x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
                 }
 
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -9878,6 +9902,7 @@ static void ggml_compute_forward_mul_mat(
                         0.0f, d, ne01);
             }
         }
+        //printf("cblas_sgemm = %.3f ms, %lld flops\n", (ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2);
 
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
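
A note on the workspace layout the diff relies on: each (i12, i13) plane now owns a contiguous block of ne01*ne00 floats inside params->wdata, which is why the running id offset from the removed code is no longer needed and why the init pass can assert wsize >= ne13*ne12*ne01*ne00*sizeof(float). A small sketch of those two index/size computations, reusing the dimension names from the diff (the helper names are illustrative):

#include <stdint.h>
#include <stddef.h>

// Offset (in floats) of row i01 of plane (i12, i13) inside params->wdata,
// matching wdata + i13*ne12*ne_plane + i12*ne_plane + i01*ne00 in the diff.
static inline int64_t wdata_row_offset(int64_t i01, int64_t i12, int64_t i13,
                                       int64_t ne00, int64_t ne01, int64_t ne12) {
    const int64_t ne_plane = ne01*ne00;
    return i13*ne12*ne_plane + i12*ne_plane + i01*ne00;
}

// Total scratch size in bytes, matching desired_wsize in the diff.
static inline size_t wdata_size(int64_t ne00, int64_t ne01, int64_t ne12, int64_t ne13) {
    return (size_t)(ne13*ne12*ne01*ne00)*sizeof(float);
}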