@@ -10602,7 +10602,7 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

-    float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

    for (int i1 = ir0; i1 < ir1; i1++) {
        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
@@ -10611,9 +10611,10 @@ static void ggml_compute_forward_soft_max_f32(
        // broadcast the mask across rows
        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;

-        float * wp = wdata;
-        for (int i = 0; i < nc; i++) {
-            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp) {
+            ggml_vec_acc_f32(nc, wp, mp);
        }

#ifndef NDEBUG
@@ -15939,7 +15940,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                } break;
            case GGML_OP_SOFT_MAX:
                {
-                    n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+                    n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));

                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                } break;
0 commit comments