@@ -536,14 +536,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
536536 device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
537537
538538 float sumf = 0 ;
539- for (int i = tiisg; i < ne00; i += 32 ) {
540- sumf += (float ) x[i] * (float ) y[i];
539+ if (ne00 < 128 ) {
540+ for (int i = tiisg; i < ne00; i += 32 ) {
541+ sumf += (float ) x[i] * (float ) y[i];
542+ }
543+ float all_sum = simd_sum (sumf);
544+ if (tiisg == 0 ) {
545+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
546+ }
547+ } else {
548+ device const half4 * x4 = (device const half4 *) x;
549+ device const float4 * y4 = (device const float4 *) y;
550+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
551+ for (int k = 0 ; k < 4 ; ++k) sumf += (float )x4[i][k] * y4[i][k];
552+ }
553+ float all_sum = simd_sum (sumf);
554+ if (tiisg == 0 ) {
555+ for (int i = 4 *(ne00/4 ); i < ne00; ++i) all_sum += (float ) x[i] * y[i];
556+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
557+ }
541558 }
542559
543- float all_sum = simd_sum (sumf);
544- if (tiisg == 0 ) {
545- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
546- }
547560}
548561
549562#define N_F16_F32 4
@@ -570,29 +583,54 @@ kernel void kernel_mul_mat_f16_f32(
570583 uint tiisg[[thread_index_in_simdgroup]]) {
571584
572585 const int64_t r0 = tgpig.x ;
573- const int64_t rb = N_F16_F32* tgpig.y ;
586+ const int64_t rb = tgpig.y *N_F16_F32 ;
574587 const int64_t im = tgpig.z ;
575588
576589 device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
577590
578- for (int row = 0 ; row < N_F16_F32; ++row) {
579- int r1 = rb + row;
580- if (r1 >= ne11) {
581- break ;
582- }
591+ if (ne00 < 128 ) {
592+ for (int row = 0 ; row < N_F16_F32; ++row) {
593+ int r1 = rb + row;
594+ if (r1 >= ne11) {
595+ break ;
596+ }
583597
584- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
598+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
585599
586- float sumf = 0 ;
587- for (int i = tiisg; i < ne00; i += 32 ) {
588- sumf += (float ) x[i] * (float ) y[i];
600+ float sumf = 0 ;
601+ for (int i = tiisg; i < ne00; i += 32 ) {
602+ sumf += (float ) x[i] * (float ) y[i];
603+ }
604+
605+ float all_sum = simd_sum (sumf);
606+ if (tiisg == 0 ) {
607+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
608+ }
589609 }
610+ } else {
611+ device const half4 * x4 = (device const half4 *)x;
612+ for (int row = 0 ; row < N_F16_F32; ++row) {
613+ int r1 = rb + row;
614+ if (r1 >= ne11) {
615+ break ;
616+ }
590617
591- float all_sum = simd_sum (sumf);
592- if (tiisg == 0 ) {
593- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
618+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
619+ device const float4 * y4 = (device const float4 *) y;
620+
621+ float sumf = 0 ;
622+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
623+ for (int k = 0 ; k < 4 ; ++k) sumf += (float ) x4[i][k] * y4[i][k];
624+ }
625+
626+ float all_sum = simd_sum (sumf);
627+ if (tiisg == 0 ) {
628+ for (int i = 4 *(ne00/4 ); i < ne00; ++i) all_sum += (float ) x[i] * y[i];
629+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
630+ }
594631 }
595632 }
633+
596634}
597635
598636kernel void kernel_alibi_f32 (
0 commit comments