Skip to content

Commit c7c8dab

Browse files
committed
ggml : update soft max cpu
1 parent ebd062b commit c7c8dab

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ggml.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10602,7 +10602,7 @@ static void ggml_compute_forward_soft_max_f32(
1060210602
const int ir0 = dr*ith;
1060310603
const int ir1 = MIN(ir0 + dr, nr);
1060410604

10605-
float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
10605+
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
1060610606

1060710607
for (int i1 = ir0; i1 < ir1; i1++) {
1060810608
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
@@ -10611,9 +10611,10 @@ static void ggml_compute_forward_soft_max_f32(
1061110611
// broadcast the mask across rows
1061210612
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
1061310613

10614-
float * wp = wdata;
10615-
for (int i = 0; i < nc; i++) {
10616-
wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
10614+
ggml_vec_cpy_f32 (nc, wp, sp);
10615+
ggml_vec_scale_f32(nc, wp, scale);
10616+
if (mp) {
10617+
ggml_vec_acc_f32(nc, wp, mp);
1061710618
}
1061810619

1061910620
#ifndef NDEBUG
@@ -15939,7 +15940,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1593915940
} break;
1594015941
case GGML_OP_SOFT_MAX:
1594115942
{
15942-
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
15943+
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
1594315944

1594415945
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1594515946
} break;

0 commit comments

Comments
 (0)