 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -1353,7 +1355,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
         const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
         const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
-        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
         __m256i sumi = _mm256_setzero_si256();
 
@@ -1421,7 +1423,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
 
         // sumf += -dmin * summs in 32bits*8
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
 
         const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
         const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
@@ -1493,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         // sumf += dall * isum - dmin * summs in 32bits
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
@@ -1644,8 +1646,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         summs += dmin * smin;
 
         const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
-        const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
-        const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+        const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+        const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
 
         const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8 + 0));
         const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8 + 32));
@@ -1709,10 +1711,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
         const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
 
-        const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
-        const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
-        const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
-        const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+        const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
 
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
@@ -1917,7 +1919,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
         const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
         const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
-        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
         // high bit
         const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
@@ -2128,7 +2130,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         // multiply with block scale and accumulate
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
 
     }
@@ -2303,13 +2305,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         aux16[0] = a & 0x0f0f;
         aux16[1] = (a >> 4) & 0x0f0f;
 
-        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
-        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+        const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
 
         memcpy(&aux64, x[i].hmask, 8);
 
         const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
-        __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
         __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
         q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
         q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
@@ -2318,7 +2320,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
 
         // prepare low and high bits
-        const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+        const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
         const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
         const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
 
@@ -2429,7 +2431,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
         p16_0 = _mm_add_epi32(p16_0, p16_2);
         p16_1 = _mm_add_epi32(p16_1, p16_3);
-        __m256i p16 = _mm256_set_m128i(p16_1, p16_0);
+        __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
 
         // multiply with block scale and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
@@ -2620,7 +2622,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
 
         const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
-        const __m256i scales = _mm256_set_m128i(sc128, sc128);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
         __m256i sumi = _mm256_setzero_si256();
 
@@ -2727,7 +2729,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         __m256 vd = _mm256_set1_ps(d);
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
     }
@@ -2968,11 +2970,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
         const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
-        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
 
         const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
         const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
-        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
 
     }
 
@@ -3160,7 +3162,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         summs += dmin * _mm_extract_epi32(hsum, 0);
 
         const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
-        const __m256i scales = _mm256_set_m128i(sc128, sc128);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
         const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
         __m256i hmask = mone;
@@ -3299,7 +3301,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         __m256 vd = _mm256_set1_ps(d);
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
     }
@@ -3462,13 +3464,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
         const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
-        const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
-        const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+        const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
 
         int64_t aux64;
         memcpy(&aux64, x[i].qh, 8);
         const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
-        const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+        const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
 
         const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
         const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
@@ -3543,7 +3545,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
         const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
 
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
 
     }
 
@@ -3925,7 +3927,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         }
 
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
@@ -4083,8 +4085,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
         const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
 
-        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
-        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
 
         const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
         const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
@@ -4177,7 +4179,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
         sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
 
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
     }
 
     *s = hsum_float_8(acc);
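
For reference, a minimal standalone sketch (not part of the patch) of what the new MM256_SET_M128I(a, b) macro computes: it places the 128-bit value b in the low half and a in the high half of a 256-bit register, the same result as _mm256_set_m128i(a, b), but built only from _mm256_castsi128_si256 and _mm256_insertf128_si256, presumably so the file also builds with toolchains that do not provide _mm256_set_m128i. The main() harness and test values below are illustrative only and assume an AVX2-capable compiler.

// sketch.c: demonstrate that MM256_SET_M128I(hi, lo) packs lo into bits 0..127 and hi into bits 128..255
#include <immintrin.h>
#include <stdio.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main(void) {
    const __m128i lo = _mm_set1_epi32(1);          // expected in the low 128 bits
    const __m128i hi = _mm_set1_epi32(2);          // expected in the high 128 bits
    const __m256i v  = MM256_SET_M128I(hi, lo);

    int out[8];
    _mm256_storeu_si256((__m256i *)out, v);        // unaligned store of the 8 lanes
    for (int k = 0; k < 8; ++k) printf("%d ", out[k]);
    printf("\n");                                  // prints: 1 1 1 1 2 2 2 2
    return 0;
}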