@@ -5350,6 +5350,8 @@ static void ggml_compute_forward_add_q_f32(
53505350 const int ir0 = dr * ith ;
53515351 const int ir1 = MIN (ir0 + dr , nr );
53525352
5353+ float * wdata = (float * ) params -> wdata + ne00 * ith ;
5354+
53535355 for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
53545356 // src0 indices
53555357 const int i03 = ir /(ne02 * ne01 );
@@ -5372,12 +5374,11 @@ static void ggml_compute_forward_add_q_f32(
53725374 assert (ne00 % 32 == 0 );
53735375
53745376 // unquantize row from src0 to temp buffer
5375- float tmp [ne00 ];
5376- dequantize_row_q (src0_row , tmp , ne00 );
5377+ dequantize_row_q (src0_row , wdata , ne00 );
53775378 // add src1
5378- ggml_vec_acc_f32 (ne00 , tmp , src1_row );
5379+ ggml_vec_acc_f32 (ne00 , wdata , src1_row );
53795380 // quantize row to dst
5380- quantize_row_q (tmp , dst_row , ne00 );
5381+ quantize_row_q (wdata , dst_row , ne00 );
53815382 }
53825383}
53835384
@@ -9566,6 +9567,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
95669567 case GGML_OP_ADD :
95679568 {
95689569 node -> n_tasks = n_threads ;
9570+
9571+ size_t cur = 0 ;
9572+
9573+ if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 ) {
9574+ cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
9575+ }
9576+
9577+ work_size = MAX (work_size , cur );
95699578 } break ;
95709579 case GGML_OP_SUB :
95719580 case GGML_OP_MUL :
0 commit comments