@@ -1743,33 +1743,187 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
}

#if __AVX512F__ && QK == 32
- static inline __m512 dot_q4_0_oneblock_avx512(
+ static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
+     // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
+     // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+     // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+     // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+     // |                                                                        :. =_ () [] <> () Zz Yy|
+     // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+     // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+     // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+     // |Xx Ww Vv Uu Tt Ss Rr Qq             Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa            |
+     // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+     //
+     // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
+     // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
+     // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
+     // Bytes 40..63 are masked when loading the data, so they are zeroed out.
+ #ifdef __AVX512VBMI__
+     const __m512i byte_perm = _mm512_set_epi8(
+         39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
+         31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
+         19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
+         11, 10, 11, 10,  9,  8,  9,  8,  7,  6,  7,  6,  5,  4,  5,  4
+     );
+     const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
+     // After applying VPERMB, `permuted` looks like this:
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+ #else
+     const __m512i word_perm = _mm512_set_epi16(
+         19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
+          9,  9,  8,  8,  7,  7,  6,  6,  5,  5,  4,  4,  3,  3,  2,  2
+     );
+     const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
+     // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
+     // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
+     // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
+ #endif
+
+     // Shift every odd-numbered 16-bit group to the right by 4 bits.
+     const __mmask32 shift_mask = 0xaaaaaaaa;
+     const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
+     // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // | : .= :. =_  ( )[ () []  < >( <> ()  Z zY Zz Yy  X xW Xx Ww  V vU Vv Uu  T tS Tt Ss  R rQ Rr Qq|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // | P pO Pp Oo  N nM Nn Mm  L lK Ll Kk  J jI Jj Ii  H hG Hh Gg  F fE Ff Ee  D dC Dd Cc  B bA Bb Aa|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+
+     // Now we just need to zero out the higher nibble in each byte, and we're done.
+     const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
+     return _mm512_and_si512( low_nibble_mask, shifted );
+     // The final result looks like this:
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // | :  =  .  _  (  [  )  ]  <  (  >  )  Z  Y  z  y  X  W  x  w  V  U  v  u  T  S  t  s  R  Q  r  q|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+     // | P  O  p  o  N  M  n  m  L  K  l  k  J  I  j  i  H  G  h  g  F  E  f  e  D  C  d  c  B  A  b  a|
+     // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+ }
+
+ static inline __m512 dot_q4_0_twoblocks_avx512(
    __m512 acc,
    const block_q4_0 * restrict x,
    const block_q4_0 * restrict y,
    int i
) {
-     // Compute combined scale for the block
-     __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
-
-     __m256i bx = bytesFromNibbles( x[i].qs );
-     __m256i by = bytesFromNibbles( y[i].qs );
-
-     // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-     const __m256i off = _mm256_set1_epi8( 8 );
-     bx = _mm256_sub_epi8( bx, off );
-     by = _mm256_sub_epi8( by, off );
-
-     // Sign-extend 16 signed bytes into int16_t
-     __m512i x32 = _mm512_cvtepi8_epi16( bx );
-     __m512i y32 = _mm512_cvtepi8_epi16( by );
-     // Compute products of int16_t integers, add pairwise
-     __m512i i64 = _mm512_madd_epi16( x32, y32 );
+     // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
+     // can potentially be unaddressable, so we make sure to mask them out before the load, even though
+     // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
+     // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
+     const __mmask8 load_mask = 0x1f;
+     const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
+     const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
+
+     // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     //  blocks_0_float
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     // |    |    |    |    |    |    | xx | xx | xx | xx |  B | xx | xx | xx | xx |  A |
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     //  blocks_1_float
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     // |    |    |    |    |    |    | xx | xx | xx | xx |  D | xx | xx | xx | xx |  C |
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
+     const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
+     // We absolutely shouldn't touch the floats marked with `xx`: they contain some
+     // random data, which might very well underflow. At least on Intel, this leads
+     // to a huge penalty that can't be ignored (easily 100x or more) unless you
+     // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
+     // (and ggml can't assume that you do)...
+     const __mmask16 scale_mul_mask = 0x21;
+ #ifdef __clang__
+     // ...however, clang decides to optimize the multiplication mask away:
+     // https://godbolt.org/z/P8PqdsfvW
+     // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
+     __m512i scales;
+     __asm__(
+         "vmulps %1, %2, %0%{%3%}"
+         : "=v" ( scales )
+         : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
+     );
+ #else
+     const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
+ #endif
+     const __m512i scale_perm = _mm512_set_epi32(
+         5, 5, 5, 5, 5, 5, 5, 5,
+         0, 0, 0, 0, 0, 0, 0, 0
+     );
+     const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
+     // After VMULPS and VPERMPS, `permuted_scales` looks like this:
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+     // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
+     // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+
+     const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
+     const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
+
+     // Now we want to compute dot products of 4-element byte vectors and store them in
+     // 32-bit integers. That is (only one 4-element vector is shown for clarity):
+     //     +----+----+----+----+
+     // ... | 03 | 02 | 01 | 00 |
+     //     +----+----+----+----+
+     //  bytes_0
+     //     +----+----+----+----+
+     // ... |  D |  C |  B |  A |
+     //     +----+----+----+----+
+     //  bytes_1
+     //     +----+----+----+----+
+     // ... |  H |  G |  F |  E |
+     //     +----+----+----+----+
+     //  final_res_int
+     //     +----+----+----+----+
+     // ... |  A*E+B*F+C*G+D*H  |
+     //     +----+----+----+----+
+     const __m512i plus_8 = _mm512_set1_epi8( 8 );
+     const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
+
+ #ifdef __AVX512VNNI__
+     // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
+     // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
+     // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
+     // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
+     // which means we only need 2 instructions.
+     const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
+     const __m512i minus_8 = _mm512_set1_epi8( -8 );
+     const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
+     const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
+ #else
+     // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
+     // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
+     // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
+     // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
+     const __m512i one = _mm512_set1_epi16( 1 );
+     const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
+     const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
+     const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
+     const __m512i final_res_int = _mm512_madd_epi16( diff, one );
+ #endif

-     // Convert int32_t to float
-     __m512 p = _mm512_cvtepi32_ps( i64 );
-     // Apply the scale, and accumulate
-     return _mm512_fmadd_ps( d, p, acc );
+     // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
+     const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
+     return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
}

#endif
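For reference, the arithmetic that one call to the new two-block kernel performs can be written out in scalar form. The sketch below is not part of the commit; it assumes QK == 32 and the ggml Q4_0 block layout implied by the code above (an fp32 scale `d` followed by QK/2 bytes of packed nibbles in `qs`), and it spells out the identity `(a - 8) * (b - 8) == a * (b - 8) - 8 * b + 64` that the VPDPBUSDS path exploits.

    // Scalar sketch (illustrative, not the committed code): the contribution of one call to
    // dot_q4_0_twoblocks_avx512 for the block pairs (x[i], y[i]) and (x[i+1], y[i+1]).
    static inline float dot_q4_0_twoblocks_scalar(float acc,
                                                  const block_q4_0 * restrict x,
                                                  const block_q4_0 * restrict y,
                                                  int i) {
        for (int b = i; b < i + 2; b++) {
            int32_t isum = 0;
            for (int k = 0; k < QK / 2; k++) {
                // each qs byte packs two nibbles; a quantized value is (nibble - 8)
                const int x0 = (x[b].qs[k] & 0x0f) - 8, x1 = (x[b].qs[k] >> 4) - 8;
                const int y0 = (y[b].qs[k] & 0x0f) - 8, y1 = (y[b].qs[k] >> 4) - 8;
                // the VNNI path computes the same products via
                // x*(y - 8) + y*(-8) + 64 == (x - 8)*(y - 8), with x, y in [0, 15]
                isum += x0 * y0 + x1 * y1;
            }
            acc += x[b].d * y[b].d * (float) isum;
        }
        return acc;
    }

The nibble pairing inside a block differs from the vectorized path, but since the same unpacking is applied to both operands the dot product is unchanged.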
@@ -1919,25 +2073,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
    __m512 acc0 = _mm512_setzero_ps();
    __m512 acc1 = _mm512_setzero_ps();

-     const int superblock_size = 8;
+     const int superblock_size = 16;
+
    const int superblock_count = nb / superblock_size;

    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
        int i = superblock_ix * superblock_size;

-         acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 0 );
-         acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 1 );
-         acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 2 );
-         acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 3 );
-         acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 4 );
-         acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 5 );
-         acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 6 );
-         acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 7 );
+         acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 0 );
+         acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 2 );
+         acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 4 );
+         acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 6 );
+         acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 8 );
+         acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 10 );
+         acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 12 );
+         acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 14 );
    }

    // Remainders
-     for (int i = superblock_count * superblock_size; i < nb; ++i) {
-         acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
+     for (int i = superblock_count * superblock_size; i < nb; i += 2) {
+         acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
    }

    // Horizontal sum of all lanes of the accumulator
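The loop is still unrolled over two independent accumulators (acc0/acc1) so that consecutive FMAs do not serialize on one register; with superblock_size raised to 16, each iteration issues eight calls that now cover two blocks apiece, and the remainder loop likewise advances by two blocks per call, which effectively assumes an even number of Q4_0 blocks per row. The context comment above refers to the horizontal reduction that follows this hunk; a minimal sketch of such a reduction (illustrative, not the committed code; `s` is the output parameter of ggml_vec_dot_q4_0) could be:

    // combine the two accumulators and sum all 16 lanes into the scalar result
    const __m512 acc = _mm512_add_ps( acc0, acc1 );
    *s = _mm512_reduce_add_ps( acc );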
@@ -10817,6 +10972,22 @@ int ggml_cpu_has_avx512(void) {
#endif
}

+ int ggml_cpu_has_avx512_vbmi(void) {
+ #if defined(__AVX512VBMI__)
+     return 1;
+ #else
+     return 0;
+ #endif
+ }
+
+ int ggml_cpu_has_avx512_vnni(void) {
+ #if defined(__AVX512VNNI__)
+     return 1;
+ #else
+     return 0;
+ #endif
+ }
+

int ggml_cpu_has_fma(void) {
#if defined(__FMA__)
    return 1;
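The two new helpers mirror the existing ggml_cpu_has_avx512() and let callers report at runtime whether the VBMI and VNNI fast paths were compiled in. A minimal, hypothetical usage sketch (assuming the matching declarations are exposed in ggml.h; the reporting format is up to the caller):

    #include <stdio.h>

    // Illustrative only: report which AVX-512 subsets this ggml build was compiled with.
    static void print_avx512_features(void) {
        printf("AVX512 = %d | AVX512_VBMI = %d | AVX512_VNNI = %d\n",
               ggml_cpu_has_avx512(), ggml_cpu_has_avx512_vbmi(), ggml_cpu_has_avx512_vnni());
    }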