@@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
1034210342 return;
1034310343 }
1034410344
10345- const int nc = src0->ne[0];
10346- const int nr = ggml_nelements(src1);
10345+ GGML_TENSOR_BINARY_OP_LOCALS
10346+
10347+ const int64_t nc = ne00;
10348+ const int64_t nr = ggml_nelements(src1);
10349+
1034710350 const enum ggml_type type = src0->type;
1034810351 ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
1034910352
10350- assert( dst->ne[0] == nc);
10353+ assert(ne0 == nc);
10354+ assert(ne02 == ne11);
10355+ assert(nb00 == ggml_type_size(type));
1035110356 assert(ggml_nrows(dst) == nr);
10352- assert(src0->nb[0] == ggml_type_size(type));
1035310357
10354- for (int i = 0; i < nr; ++i) {
10355- const int r = ((int32_t *) src1->data)[i];
10358+ // TODO: multi-thread
10359+ for (int64_t i = 0; i < nr; ++i) {
10360+ const int64_t r = ((int32_t *) src1->data)[i];
10361+
10362+ const int64_t i02 = i/ne10;
1035610363
1035710364 dequantize_row_q(
10358- (const void *) ((char *) src0->data + r*src0->nb[1] ),
10365+ (const void *) ((char *) src0->data + i02*nb02 + r*nb01 ),
1035910366 (float *) ((char *) dst->data + i*dst->nb[1]), nc);
1036010367 }
1036110368}
@@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
1037110378 return;
1037210379 }
1037310380
10374- const int nc = src0->ne[0];
10375- const int nr = ggml_nelements(src1);
10381+ GGML_TENSOR_BINARY_OP_LOCALS
10382+
10383+ const int64_t nc = ne00;
10384+ const int64_t nr = ggml_nelements(src1);
1037610385
10377- assert( dst->ne[0] == nc);
10386+ assert(ne0 == nc);
10387+ assert(ne02 == ne11);
10388+ assert(nb00 == sizeof(ggml_fp16_t));
1037810389 assert(ggml_nrows(dst) == nr);
10379- assert(src0->nb[0] == sizeof(ggml_fp16_t));
1038010390
10381- for (int i = 0; i < nr; ++i) {
10382- const int r = ((int32_t *) src1->data)[i];
10391+ // TODO: multi-thread
10392+ for (int64_t i = 0; i < nr; ++i) {
10393+ const int64_t r = ((int32_t *) src1->data)[i];
10394+
10395+ const int64_t i02 = i/ne10;
1038310396
1038410397 for (int j = 0; j < nc; ++j) {
10385- ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1] ))[j];
10386- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
10398+ ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01 ))[j];
10399+ ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
1038710400 }
1038810401 }
1038910402}
@@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
1039910412 return;
1040010413 }
1040110414
10402- const int nc = src0->ne[0];
10403- const int nr = ggml_nelements(src1);
10415+ GGML_TENSOR_BINARY_OP_LOCALS
10416+
10417+ const int64_t nc = ne00;
10418+ const int64_t nr = ggml_nelements(src1);
1040410419
10405- assert( dst->ne[0] == nc);
10420+ assert(ne0 == nc);
10421+ assert(ne02 == ne11);
10422+ assert(nb00 == sizeof(float));
1040610423 assert(ggml_nrows(dst) == nr);
10407- assert(src0->nb[0] == sizeof(float));
1040810424
10409- for (int i = 0; i < nr; ++i) {
10410- const int r = ((int32_t *) src1->data)[i];
10425+ // TODO: multi-thread
10426+ for (int64_t i = 0; i < nr; ++i) {
10427+ const int64_t r = ((int32_t *) src1->data)[i];
10428+
10429+ const int64_t i02 = i/ne10;
1041110430
1041210431 ggml_vec_cpy_f32(nc,
1041310432 (float *) ((char *) dst->data + i*dst->nb[1]),
10414- (float *) ((char *) src0->data + r*src0->nb[1] ));
10433+ (float *) ((char *) src0->data + i02*nb02 + r*nb01 ));
1041510434 }
1041610435}
1041710436
0 commit comments