@@ -720,28 +720,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
720
720
721
721
#if defined(__POWER9_VECTOR__ )
722
722
const vector float v85 = vec_splats (8.5f );
723
+ const vector signed int v15 = vec_splats (15 );
723
724
for (int i = 0 ; i < nb ; i ++ ) {
724
- float amax = 0.0f ; // absolute max
725
+ float max = 0.0f ;
726
+ float min = 0.0f ;
725
727
726
728
vector float srcv [8 ];
727
- vector float asrcv [8 ];
728
- vector float amaxv [8 ];
729
+ vector float maxv [8 ];
730
+ vector float minv [8 ];
729
731
730
732
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = * (vector float * )(x + i * 32 + 4 * l );
731
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vec_abs (srcv [l ]);
733
+ // for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
732
734
733
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
734
- //for (int l = 0; l < 2; l++) amaxv [4*l] = vec_max(amaxv [4*l], amaxv [4*l+2]);
735
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [2 ]);
736
- amaxv [4 ] = vec_max (amaxv [4 ], amaxv [6 ]);
737
- //for (int l = 0; l < 1; l++) amaxv [8*l] = vec_max(amaxv [8*l], amaxv [8*l+4]);
738
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [4 ]);
735
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
736
+ //for (int l = 0; l < 2; l++) maxv [4*l] = vec_max(maxv [4*l], maxv [4*l+2]);
737
+ maxv [0 ] = vec_max (maxv [0 ], maxv [2 ]);
738
+ maxv [4 ] = vec_max (maxv [4 ], maxv [6 ]);
739
+ //for (int l = 0; l < 1; l++) maxv [8*l] = vec_max(maxv [8*l], maxv [8*l+4]);
740
+ maxv [0 ] = vec_max (maxv [0 ], maxv [4 ]);
739
741
740
- amax = MAX (
741
- MAX (vec_extract (amaxv [0 ], 0 ), vec_extract (amaxv [0 ], 1 )),
742
- MAX (vec_extract (amaxv [0 ], 2 ), vec_extract (amaxv [0 ], 3 )));
742
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vec_min (asrcv [2 * l ], asrcv [2 * l + 1 ]);
743
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
744
+ minv [0 ] = vec_min (minv [0 ], minv [2 ]);
745
+ minv [4 ] = vec_min (minv [4 ], minv [6 ]);
746
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
747
+ minv [0 ] = vec_min (minv [0 ], minv [4 ]);
743
748
744
- const float d = amax / ((1 << 3 ) - 1 );
749
+
750
+ max = MAX (
751
+ MAX (vec_extract (maxv [0 ], 0 ), vec_extract (maxv [0 ], 1 )),
752
+ MAX (vec_extract (maxv [0 ], 2 ), vec_extract (maxv [0 ], 3 )));
753
+ min = MIN (
754
+ MIN (vec_extract (minv [0 ], 0 ), vec_extract (minv [0 ], 1 )),
755
+ MIN (vec_extract (minv [0 ], 2 ), vec_extract (minv [0 ], 3 )));
756
+
757
+ const float magnitude = max >= fabsf (min ) ? max : min ;
758
+ const float d = magnitude / -8 ;
745
759
const float id = d ? 1.0 /d : 0.0 ;
746
760
747
761
y [i ].d = d ;
@@ -751,9 +765,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
751
765
for (int l = 0 ; l < 8 ; l ++ ) {
752
766
const vector float vf = vec_madd (srcv [l ], vid , v85 );
753
767
const vector signed int vi = vec_signed (vf );
768
+ const vector signed int vc = vec_min (vi , v15 );
754
769
755
- pb [2 * l + 0 ] = vec_extract (vi , 0 ) | (vec_extract (vi , 1 ) << 4 );
756
- pb [2 * l + 1 ] = vec_extract (vi , 2 ) | (vec_extract (vi , 3 ) << 4 );
770
+ pb [2 * l + 0 ] = vec_extract (vc , 0 ) | (vec_extract (vc , 1 ) << 4 );
771
+ pb [2 * l + 1 ] = vec_extract (vc , 2 ) | (vec_extract (vc , 3 ) << 4 );
757
772
}
758
773
}
759
774
#elif __ARM_NEON
0 commit comments