#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4

- namespace xss {
- namespace avx2 {
-
- // Assumes ymm is bitonic and performs a recursive half cleaner
- template <typename vtype, typename reg_t = typename vtype::reg_t>
- X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm)
- {
-
-     const typename vtype::opmask_t oxAA = _mm256_set_epi32(
-             0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
-     const typename vtype::opmask_t oxCC = _mm256_set_epi32(
-             0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
-     const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
-             0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);
-
-     // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
-     ymm = cmp_merge<vtype>(
-             ymm,
-             vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm),
-             oxF0);
-     // 2) half_cleaner[4]
-     ymm = cmp_merge<vtype>(
-             ymm,
-             vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm),
-             oxCC);
-     // 3) half_cleaner[1]
-     ymm = cmp_merge<vtype>(
-             ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
-     return ymm;
- }
-
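For reference (illustrative only, not part of the diff): a minimal scalar model of the three half-cleaner stages removed above, assuming a bitonic 8-element input. The compare distances 4, 2, 1 mirror the oxF0, oxCC, and oxAA masked compare-exchanges.

#include <algorithm>
#include <array>

// Scalar sketch of a bitonic merge on 8 elements: each pass compare-exchanges
// lanes i and i+dist for every i whose dist bit is clear.
inline void bitonic_merge_8_scalar(std::array<int, 8> &v)
{
    for (int dist : {4, 2, 1}) {
        for (int i = 0; i < 8; ++i) {
            if ((i & dist) == 0 && v[i] > v[i + dist]) std::swap(v[i], v[i + dist]);
        }
    }
}
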
/*
 * Assumes ymm is random and performs a full sorting network defined in
 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
@@ -85,7 +54,7 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
struct avx2_32bit_swizzle_ops;

template <>
- struct ymm_vector<int32_t> {
+ struct avx2_vector<int32_t> {
    using type_t = int32_t;
    using reg_t = __m256i;
    using ymmi_t = __m256i;
@@ -231,13 +200,9 @@ struct ymm_vector<int32_t> {
    {
        _mm256_storeu_si256((__m256i *)mem, x);
    }
-     static reg_t bitonic_merge(reg_t x)
-     {
-         return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
-     }
    static reg_t sort_vec(reg_t x)
    {
-         return sort_ymm_32bit<ymm_vector<type_t>>(x);
+         return sort_ymm_32bit<avx2_vector<type_t>>(x);
    }
    static reg_t cast_from(__m256i v){
        return v;
@@ -247,7 +212,7 @@ struct ymm_vector<int32_t> {
    }
};
template <>
- struct ymm_vector<uint32_t> {
+ struct avx2_vector<uint32_t> {
    using type_t = uint32_t;
    using reg_t = __m256i;
    using ymmi_t = __m256i;
@@ -378,13 +343,9 @@ struct ymm_vector<uint32_t> {
    {
        _mm256_storeu_si256((__m256i *)mem, x);
    }
-     static reg_t bitonic_merge(reg_t x)
-     {
-         return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
-     }
    static reg_t sort_vec(reg_t x)
    {
-         return sort_ymm_32bit<ymm_vector<type_t>>(x);
+         return sort_ymm_32bit<avx2_vector<type_t>>(x);
    }
    static reg_t cast_from(__m256i v){
        return v;
@@ -394,7 +355,7 @@ struct ymm_vector<uint32_t> {
    }
};
template <>
- struct ymm_vector<float> {
+ struct avx2_vector<float> {
    using type_t = float;
    using reg_t = __m256;
    using ymmi_t = __m256i;
@@ -440,6 +401,19 @@ struct ymm_vector<float> {
    {
        return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ));
    }
+     static opmask_t get_partial_loadmask(int size)
+     {
+         return (0x0001 << size) - 0x0001;
+     }
+     template <int type>
+     static opmask_t fpclass(reg_t x)
+     {
+         if constexpr (type == (0x01 | 0x80)) {
+             return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q));
+         }
+         else {
+             static_assert(type == (0x01 | 0x80), "should not reach here");
+         }
+     }
    template <int scale>
    static reg_t
    mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
@@ -533,13 +507,9 @@ struct ymm_vector<float> {
    {
        _mm256_storeu_ps((float *)mem, x);
    }
-     static reg_t bitonic_merge(reg_t x)
-     {
-         return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(x);
-     }
    static reg_t sort_vec(reg_t x)
    {
-         return sort_ymm_32bit<ymm_vector<type_t>>(x);
+         return sort_ymm_32bit<avx2_vector<type_t>>(x);
    }
    static reg_t cast_from(__m256i v){
        return _mm256_castsi256_ps(v);
@@ -549,32 +519,6 @@ struct ymm_vector<float> {
    }
};
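
The get_partial_loadmask and fpclass hooks added above are small: the former builds a bitmask with the low `size` bits set ((1 << size) - 1), and the latter flags NaN lanes via an unordered self-compare, since only type == 0x01 | 0x80 (quiet | signaling NaN) is supported. A minimal sketch of that NaN test in raw AVX2 intrinsics (illustrative only, assuming AVX2 + POPCNT):

#include <immintrin.h>

// A lane is NaN exactly when it compares unordered with itself.
static inline int count_nan_lanes(__m256 x)
{
    __m256 nanmask = _mm256_cmp_ps(x, x, _CMP_UNORD_Q);
    return _mm_popcnt_u32((unsigned)_mm256_movemask_ps(nanmask));
}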

- inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize)
- {
-     arrsize_t nan_count = 0;
-     __mmask8 loadmask = 0xFF;
-     while (arrsize > 0) {
-         if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
-         __m256 in_ymm = ymm_vector<float>::maskz_loadu(loadmask, arr);
-         __m256i nanmask = _mm256_castps_si256(
-                 _mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ));
-         nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask));
-         ymm_vector<float>::mask_storeu(arr, nanmask, YMM_MAX_FLOAT);
-         arr += 8;
-         arrsize -= 8;
-     }
-     return nan_count;
- }
-
- X86_SIMD_SORT_INLINE void
- replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count)
- {
-     for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) {
-         arr[ii] = std::nan("1");
-         nan_count -= 1;
-     }
- }
-
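The two helpers removed above implement the usual NaN strategy for float sorts: NaNs are counted and overwritten with a maximal sentinel (YMM_MAX_FLOAT) so they sink to the end, and the trailing nan_count slots get NaN written back after the sort. A minimal scalar model of that round trip (illustrative only; std::sort stands in for the vectorized kernel):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>

void sort_floats_nan_last(float *arr, std::size_t size)
{
    // Pass 1: count NaNs and replace them with the largest float.
    std::size_t nan_count = 0;
    for (std::size_t i = 0; i < size; ++i) {
        if (std::isnan(arr[i])) {
            arr[i] = std::numeric_limits<float>::max();
            ++nan_count;
        }
    }
    std::sort(arr, arr + size);
    // Pass 2: the displaced values occupy the last nan_count slots; restore NaN there.
    for (std::size_t i = size - nan_count; i < size; ++i)
        arr[i] = std::nanf("1");
}
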
struct avx2_32bit_swizzle_ops {
    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
@@ -635,7 +579,4 @@ struct avx2_32bit_swizzle_ops{
        return vtype::cast_from(v1);
    }
};
-
- } // namespace avx2
- } // namespace xss
#endif