@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
607
607
assert (k % QK == 0 );
608
608
609
609
const int nb = k / QK ;
610
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
610
611
611
- float * restrict pm = (float * ) ( y );
612
- float * restrict pd = (float * ) ( pm + nb );
613
- uint8_t * restrict pb = (uint8_t * ) ( pd + nb );
612
+ uint8_t * restrict pd = (( uint8_t * )y + 0 * bs );
613
+ uint8_t * restrict pm = (( uint8_t * )y + 0 * bs + sizeof ( float ) );
614
+ uint8_t * restrict pb = (( uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
614
615
615
616
uint8_t pp [QK /2 ];
616
617
@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
627
628
const float d = (max - min ) / ((1 << 4 ) - 1 );
628
629
const float id = d ? 1.0f /d : 0.0f ;
629
630
630
- pm [i ] = min ;
631
- pd [i ] = d ;
631
+ * (float * )pm = min ;
632
+ * (float * )pd = d ;
633
+ pm += bs ;
634
+ pd += bs ;
632
635
633
636
for (int l = 0 ; l < QK ; l += 2 ) {
634
637
const float v0 = (x [i * QK + l + 0 ] - min )* id ;
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
643
646
pp [l /2 ] = vi0 | (vi1 << 4 );
644
647
}
645
648
646
- memcpy (pb + i * QK /2 , pp , sizeof (pp ));
649
+ memcpy (pb , pp , sizeof (pp ));
650
+ pb += bs ;
647
651
}
648
652
}
649
653
@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
687
691
assert (k % QK == 0 );
688
692
689
693
const int nb = k / QK ;
694
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
690
695
691
- const float * restrict pm = (const float * ) ( x );
692
- const float * restrict pd = (const float * ) ( pm + nb );
693
- const uint8_t * restrict pb = (const uint8_t * ) ( pd + nb );
696
+ const uint8_t * restrict pd = (( const uint8_t * )x + 0 * bs );
697
+ const uint8_t * restrict pm = (( const uint8_t * )x + 0 * bs + sizeof ( float ) );
698
+ const uint8_t * restrict pb = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
694
699
695
700
for (int i = 0 ; i < nb ; i ++ ) {
696
- const float m = pm [ i ] ;
697
- const float d = pd [ i ] ;
701
+ const float d = * ( const float * ) ( pd + i * bs ) ;
702
+ const float m = * ( const float * ) ( pm + i * bs ) ;
698
703
699
- const uint8_t * restrict pp = pb + i * QK / 2 ;
704
+ const uint8_t * restrict pp = pb + i * bs ;
700
705
701
706
for (int l = 0 ; l < QK ; l += 2 ) {
702
707
const uint8_t vi = pp [l /2 ];
@@ -1584,28 +1589,109 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1584
1589
inline static void ggml_vec_dot_q4_1 (const int n , float * restrict s , const void * restrict x , const void * restrict y ) {
1585
1590
const int nb = n / QK ;
1586
1591
1587
- const float * restrict pm0 = (const float * ) x ;
1588
- const float * restrict pm1 = (const float * ) y ;
1592
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
1589
1593
1590
- const float * restrict pd0 = (const float * ) (pm0 + nb );
1591
- const float * restrict pd1 = (const float * ) (pm1 + nb );
1594
+ const uint8_t * restrict pd0 = ((const uint8_t * )x + 0 * bs );
1595
+ const uint8_t * restrict pd1 = ((const uint8_t * )y + 0 * bs );
1596
+
1597
+ const uint8_t * restrict pm0 = ((const uint8_t * )x + 0 * bs + sizeof (float ));
1598
+ const uint8_t * restrict pm1 = ((const uint8_t * )y + 0 * bs + sizeof (float ));
1592
1599
1593
- const uint8_t * restrict pb0 = (const uint8_t * ) ( pd0 + nb );
1594
- const uint8_t * restrict pb1 = (const uint8_t * ) ( pd1 + nb );
1600
+ const uint8_t * restrict pb0 = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
1601
+ const uint8_t * restrict pb1 = (( const uint8_t * )y + 0 * bs + 2 * sizeof ( float ) );
1595
1602
1596
1603
float sumf = 0.0 ;
1597
1604
1598
- #if 1
1605
+ #if defined(__AVX2__ )
1606
+ #if QK == 32
1607
+ // Initialize accumulator with zeros
1608
+ __m256 acc = _mm256_setzero_ps ();
1609
+ // Accumulator for constant offsets
1610
+ float acc_offset = 0.0f ;
1611
+
1612
+ // Main loop
1613
+ for (int i = 0 ; i < nb ; ++ i ) {
1614
+ const float * m0 = (const float * ) (pm0 + i * bs );
1615
+ const float * m1 = (const float * ) (pm1 + i * bs );
1616
+
1617
+ const float * d0 = (const float * ) (pd0 + i * bs );
1618
+ const float * d1 = (const float * ) (pd1 + i * bs );
1619
+
1620
+ const uint8_t * restrict p0 = pb0 + i * bs ;
1621
+ const uint8_t * restrict p1 = pb1 + i * bs ;
1622
+
1623
+ const __m256 d0v = _mm256_broadcast_ss ( d0 );
1624
+ const __m256 d1v = _mm256_broadcast_ss ( d1 );
1625
+ const __m256 m0v = _mm256_broadcast_ss ( m0 );
1626
+ const __m256 m1v = _mm256_broadcast_ss ( m1 );
1627
+
1628
+
1629
+ // Compute combined scale for the block
1630
+ const __m256 scale_01 = _mm256_mul_ps ( d0v , d1v );
1631
+
1632
+ // Compute cross scales for the block
1633
+ const __m256 scale_0 = _mm256_mul_ps ( d0v , m1v );
1634
+ const __m256 scale_1 = _mm256_mul_ps ( m0v , d1v );
1635
+ const __m256 cross_scales = _mm256_blend_ps ( scale_0 , scale_1 , 0b10101010 );
1636
+
1637
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1638
+ __m256i bx = bytesFromNibbles ( p0 );
1639
+ __m256i by = bytesFromNibbles ( p1 );
1640
+
1641
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval.
1642
+
1643
+ // Sign-extend first 16 signed bytes into int16_t
1644
+ __m256i x16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( bx ) );
1645
+ __m256i y16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
1646
+ // Compute products of int16_t integers, add pairwise
1647
+ __m256i i32 = _mm256_madd_epi16 ( x16 , y16 );
1648
+
1649
+ // Sign-extend last 16 signed bytes into int16_t vectors
1650
+ __m256i x16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( bx , 1 ) );
1651
+ __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
1652
+ // Accumulate products of int16_t integers
1653
+ i32 = _mm256_add_epi32 ( i32 , _mm256_madd_epi16 ( x16_h , y16_h ) );
1654
+
1655
+ // compute sums of unsigned bytes in bx, by in blocks of 8.
1656
+ // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
1657
+ // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
1658
+ // so when we then convert the 8 int32 lanes to singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
1659
+ __m256i xsumi = _mm256_sad_epu8 ( bx , _mm256_setzero_si256 () );
1660
+ __m256i ysumi = _mm256_sad_epu8 ( by , _mm256_setzero_si256 () );
1661
+ __m256i sumsi = _mm256_or_si256 ( xsumi , _mm256_slli_si256 ( ysumi , 4 ) );
1662
+ __m256 sums = _mm256_cvtepi32_ps ( sumsi );
1663
+
1664
+ // Convert int32_t to float
1665
+ __m256 p = _mm256_cvtepi32_ps ( i32 );
1666
+ // Apply the scale, and accumulate
1667
+ // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
1668
+ acc = _mm256_fmadd_ps ( scale_01 , p , acc );
1669
+ acc = _mm256_fmadd_ps ( cross_scales , sums , acc );
1670
+ // acc_offset += m0*m1 (added once per block here; scaled by QK after the loop to cover every entry in the block)
1671
+ acc_offset += (* m0 )* (* m1 );
1672
+ }
1673
+
1674
+ // Return horizontal sum of the acc vector
1675
+ __m128 res = _mm256_extractf128_ps ( acc , 1 );
1676
+ res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
1677
+ res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
1678
+ res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
1679
+
1680
+ sumf = _mm_cvtss_f32 ( res ) + acc_offset * QK ;
1681
+ #else
1682
+ #error "not implemented for QK"
1683
+ #endif
1684
+ #else
1599
1685
// scalar
1600
1686
for (int i = 0 ; i < nb ; i ++ ) {
1601
- const float m0 = pm0 [ i ] ;
1602
- const float m1 = pm1 [ i ] ;
1687
+ const float m0 = * ( const float * ) ( pm0 + i * bs ) ;
1688
+ const float m1 = * ( const float * ) ( pm1 + i * bs ) ;
1603
1689
1604
- const float d0 = pd0 [ i ] ;
1605
- const float d1 = pd1 [ i ] ;
1690
+ const float d0 = * ( const float * ) ( pd0 + i * bs ) ;
1691
+ const float d1 = * ( const float * ) ( pd1 + i * bs ) ;
1606
1692
1607
- const uint8_t * restrict p0 = pb0 + i * QK / 2 ;
1608
- const uint8_t * restrict p1 = pb1 + i * QK / 2 ;
1693
+ const uint8_t * restrict p0 = pb0 + i * bs ;
1694
+ const uint8_t * restrict p1 = pb1 + i * bs ;
1609
1695
1610
1696
for (int j = 0 ; j < QK /2 ; j ++ ) {
1611
1697
const uint8_t v0 = p0 [j ];
@@ -1839,16 +1925,17 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
1839
1925
assert (n % QK == 0 );
1840
1926
1841
1927
const int nb = n / QK ;
1928
+ const size_t bs = 2 * sizeof (float ) + QK /2 ;
1842
1929
1843
- const float * restrict pm = (const float * ) ( x );
1844
- const float * restrict pd = (const float * ) ( pm + nb );
1845
- const uint8_t * restrict pb = (const uint8_t * ) ( pd + nb );
1930
+ const uint8_t * restrict pd = (( const uint8_t * )x + 0 * bs );
1931
+ const uint8_t * restrict pm = (( const uint8_t * )x + 0 * bs + sizeof ( float ));
1932
+ const uint8_t * restrict pb = (( const uint8_t * )x + 0 * bs + 2 * sizeof ( float ) );
1846
1933
1847
1934
for (int i = 0 ; i < nb ; i ++ ) {
1848
- const float m = pm [ i ] ;
1849
- const float d = pd [ i ] ;
1935
+ const float d = * ( const float * ) ( pd + i * bs ) ;
1936
+ const float m = * ( const float * ) ( pm + i * bs ) ;
1850
1937
1851
- const uint8_t * restrict pp = pb + i * QK / 2 ;
1938
+ const uint8_t * restrict pp = pb + i * bs ;
1852
1939
1853
1940
for (int l = 0 ; l < QK ; l += 2 ) {
1854
1941
const uint8_t vi = pp [l /2 ];
0 commit comments