Skip to content

Commit 93356bd

Browse files
authored
ggml : mul mat tweaks (#2372)
* ggml : mul mat wip (ggml-ci)
* ggml : alternative thread distribution for mul_mat (ggml-ci)
* ggml : mul_mat block tiling attempt
* ggml : mul_mat threads yield (ggml-ci)
1 parent 60baff7 commit 93356bd

File tree

1 file changed

+79
-55
lines changed

1 file changed

+79
-55
lines changed

ggml.c

Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10731,71 +10731,95 @@ static void ggml_compute_forward_mul_mat(
1073110731
return;
1073210732
}
1073310733

10734-
// parallelize by src0 rows
10735-
const int64_t dr = (ne01 + nth - 1)/nth;
10734+
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10735+
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
1073610736

10737-
const int64_t ir10 = dr*ith;
10738-
const int64_t ir11 = MIN(ir10 + dr, ne01);
10737+
const int64_t nr0 = ne01; // src0 rows
10738+
const int64_t nr1 = ne11*ne12*ne13; // src1 rows
1073910739

10740-
// src1 rows
10741-
const int64_t nr1 = ne11*ne12*ne13;
10740+
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
1074210741

10743-
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10744-
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
10742+
// distribute the thread work across the inner or outer loop based on which one is larger
1074510743

10746-
for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
10747-
const int64_t i13 = (ir1/(ne12*ne11));
10748-
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
10749-
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
10750-
10751-
const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
10752-
const int64_t i03 = (ir0/(ne02));
10753-
// Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
10754-
// See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
10755-
// GG: this is likely the correct way to broadcast, though need some more thought
10756-
// therefore leaving the comments to remind us for now
10757-
const int64_t i02 = (i12 / (ne12 / ne02));
10758-
// Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
10759-
// const int64_t i02 = (ir0 - i03*ne02);
10760-
10761-
const int64_t i1 = i11;
10762-
const int64_t i2 = i12;
10763-
const int64_t i3 = i13;
10764-
10765-
const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
10766-
10767-
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
10768-
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
10769-
// the original src1 data pointer, so we should index using the indices directly
10770-
// TODO: this is a bit of a hack, we should probably have a better way to handle this
10771-
const char * src1_col = (const char *) wdata +
10772-
(src1_cont || src1->type != vec_dot_type
10773-
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10774-
: (i11*nb11 + i12*nb12 + i13*nb13));
10775-
10776-
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10777-
10778-
for (int64_t ir = ir10; ir < ir11; ++ir) {
10779-
vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
10780-
}
10744+
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
10745+
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
10746+
10747+
const int64_t ith0 = ith % nth0;
10748+
const int64_t ith1 = ith / nth0;
10749+
10750+
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
10751+
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
10752+
10753+
const int64_t ir010 = dr0*ith0;
10754+
const int64_t ir011 = MIN(ir010 + dr0, nr0);
10755+
10756+
const int64_t ir110 = dr1*ith1;
10757+
const int64_t ir111 = MIN(ir110 + dr1, nr1);
10758+
10759+
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
10760+
10761+
// threads with no work simply yield (not sure if it helps)
10762+
if (ir010 >= ir011 || ir110 >= ir111) {
10763+
sched_yield();
10764+
return;
1078110765
}
1078210766

10783-
//int64_t t1 = ggml_time_us();
10784-
//static int64_t acc = 0;
10785-
//acc += t1 - t0;
10786-
//if (t1 - t0 > 10) {
10787-
// printf("\n");
10788-
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10789-
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10790-
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10767+
assert(ne12 % ne02 == 0);
10768+
assert(ne13 % ne03 == 0);
1079110769

10792-
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10793-
//}
10794-
}
10770+
// broadcast factors
10771+
const int64_t r2 = ne12/ne02;
10772+
const int64_t r3 = ne13/ne03;
1079510773

10774+
// block-tiling attempt
10775+
const int64_t blck_0 = 16;
10776+
const int64_t blck_1 = 16;
1079610777

10797-
// ggml_compute_forward_out_prod
10778+
// attempt to reduce false-sharing (does not seem to make a difference)
10779+
float tmp[16];
10780+
10781+
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
10782+
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
10783+
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
10784+
const int64_t i13 = (ir1/(ne12*ne11));
10785+
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
10786+
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
10787+
10788+
// broadcast src0 into src1
10789+
const int64_t i03 = i13/r3;
10790+
const int64_t i02 = i12/r2;
10791+
10792+
const int64_t i1 = i11;
10793+
const int64_t i2 = i12;
10794+
const int64_t i3 = i13;
10795+
10796+
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
10797+
10798+
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
10799+
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
10800+
// the original src1 data pointer, so we should index using the indices directly
10801+
// TODO: this is a bit of a hack, we should probably have a better way to handle this
10802+
const char * src1_col = (const char *) wdata +
10803+
(src1_cont || src1->type != vec_dot_type
10804+
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10805+
: (i11*nb11 + i12*nb12 + i13*nb13));
1079810806

10807+
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10808+
10809+
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10810+
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
10811+
//}
10812+
10813+
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
10814+
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
10815+
}
10816+
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
10817+
}
10818+
}
10819+
}
10820+
}
10821+
10822+
// ggml_compute_forward_out_prod
1079910823

1080010824
static void ggml_compute_forward_out_prod_f32(
1080110825
const struct ggml_compute_params * params,

0 commit comments

Comments (0)