diff --git a/benchmarks/bench-qsort.hpp b/benchmarks/bench-qsort.hpp index f95b05ba..a6e1c794 100644 --- a/benchmarks/bench-qsort.hpp +++ b/benchmarks/bench-qsort.hpp @@ -36,9 +36,49 @@ static void simdsort(benchmark::State &state, Args &&...args) } } +template +static void scalar_revsort(benchmark::State &state, Args &&...args) +{ + // Get args + auto args_tuple = std::make_tuple(std::move(args)...); + size_t arrsize = std::get<0>(args_tuple); + std::string arrtype = std::get<1>(args_tuple); + // set up array + std::vector arr = get_array(arrtype, arrsize); + std::vector arr_bkp = arr; + // benchmark + for (auto _ : state) { + std::sort(arr.rbegin(), arr.rend()); + state.PauseTiming(); + arr = arr_bkp; + state.ResumeTiming(); + } +} + +template +static void simd_revsort(benchmark::State &state, Args &&...args) +{ + // Get args + auto args_tuple = std::make_tuple(std::move(args)...); + size_t arrsize = std::get<0>(args_tuple); + std::string arrtype = std::get<1>(args_tuple); + // set up array + std::vector arr = get_array(arrtype, arrsize); + std::vector arr_bkp = arr; + // benchmark + for (auto _ : state) { + x86simdsort::qsort(arr.data(), arrsize, false, true); + state.PauseTiming(); + arr = arr_bkp; + state.ResumeTiming(); + } +} + #define BENCH_BOTH_QSORT(type) \ BENCH_SORT(simdsort, type) \ - BENCH_SORT(scalarsort, type) + BENCH_SORT(scalarsort, type) \ + BENCH_SORT(simd_revsort, type) \ + BENCH_SORT(scalar_revsort, type) BENCH_BOTH_QSORT(uint64_t) BENCH_BOTH_QSORT(int64_t) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 7700c9f4..345653d9 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -7,19 +7,21 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_qsort(arr, arrsize, hasnan); \ + avx2_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void qselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_qselect(arr, k, arrsize, hasnan); \ + avx2_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void partial_qsort( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_partial_qsort(arr, k, arrsize, hasnan); \ + avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 09caefb5..20095369 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -5,34 +5,50 @@ namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan) + void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan); + avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(uint16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { - avx512_qselect(arr, k, arrsize, hasnan); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(uint16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan) + void qsort(int16_t *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan); + avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(int16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { - avx512_qselect(arr, k, arrsize, hasnan); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(int16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 70f13daf..dad32b91 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -8,19 +8,26 @@ namespace xss { namespace avx512 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); // quickselect template - XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template - XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector @@ -33,19 +40,26 @@ namespace avx512 { namespace avx2 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); // quickselect template - XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template - XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector @@ -58,19 +72,26 @@ namespace avx2 { namespace scalar { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); // quickselect template - XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template - XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index a5348106..6afc7287 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -4,8 +4,10 @@ namespace xss { namespace utils { - /* O(1) permute array in place: stolen from - * http://www.davidespataro.it/apply-a-permutation-to-a-vector */ + /* + * O(1) permute array in place: stolen from + * http://www.davidespataro.it/apply-a-permutation-to-a-vector + */ template void apply_permutation_in_place(T *arr, std::vector arg) { @@ -21,40 +23,51 @@ namespace utils { arg[curr] = curr; } } -} // namespace utils - -namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan) + decltype(auto) get_cmp_func(bool hasnan, bool reverse) { + std::function cmp; if (hasnan) { - std::sort(arr, arr + arrsize, compare>()); + if (reverse == true) { cmp = compare>(); } + else { + cmp = compare>(); + } } else { - std::sort(arr, arr + arrsize); + if (reverse == true) { cmp = std::greater(); } + else { + cmp = std::less(); + } } + return cmp; } +} // namespace utils + +namespace scalar { template - void qselect(T *arr, size_t k, size_t arrsize, bool hasnan) + void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) { - if (hasnan) { - std::nth_element( - arr, arr + k, arr + arrsize, compare>()); - } - else { - std::nth_element(arr, arr + k, arr + arrsize); - } + std::sort(arr, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); } + template - void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { - if (hasnan) { - std::partial_sort( - arr, arr + k, arr + arrsize, compare>()); - } - else { - std::partial_sort(arr, arr + k, arr + arrsize); - } + std::nth_element(arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); + } + template + void + partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + { + std::partial_sort(arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); } template std::vector argsort(T *arr, size_t arrsize, bool hasnan) diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 11145e3a..4a1c2a9f 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -7,19 +7,21 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_qsort(arr, arrsize, hasnan); \ + avx512_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void qselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_qselect(arr, k, arrsize, hasnan); \ + avx512_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void partial_qsort( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_partial_qsort(arr, k, arrsize, hasnan); \ + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index e07de36f..b09a8393 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -5,19 +5,36 @@ namespace xss { namespace avx512 { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan) + void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan); + if (descending) { avx512_qsort(arr, size, hasnan); } + else { + avx512_qsort(arr, size, hasnan); + } } template <> - void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(_Float16 *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { - avx512_qselect(arr, k, arrsize, hasnan); + if (descending) { avx512_qselect(arr, k, arrsize, hasnan); } + else { + avx512_qselect(arr, k, arrsize, hasnan); + } } template <> - void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(_Float16 *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan); + if (descending) { avx512_partial_qsort(arr, k, arrsize, hasnan); } + else { + avx512_partial_qsort(arr, k, arrsize, hasnan); + } } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 8626b185..21c8b34f 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -57,29 +57,32 @@ namespace x86simdsort { #define CAT(a, b) CAT_(a, b) #define DECLARE_INTERNAL_qsort(TYPE) \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ template <> \ - void qsort(TYPE *arr, size_t arrsize, bool hasnan) \ + void qsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan); \ + (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool) \ + static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void qselect(TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + void qselect( \ + TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan); \ + (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ - static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool) \ + static void (*internal_partial_qsort##TYPE)( \ + TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void partial_qsort(TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + void partial_qsort( \ + TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan); \ + (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 4dfc6d4b..42d5247f 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -14,17 +14,24 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); +XSS_EXPORT_SYMBOL void +qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template -XSS_EXPORT_SYMBOL void -qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); +XSS_EXPORT_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template -XSS_EXPORT_SYMBOL void -partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); +XSS_EXPORT_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index c93c5c2a..ad4e99fc 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -85,7 +85,11 @@ struct avx2_vector { static reg_t zmm_max() { return _mm256_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? + } + static reg_t zmm_min() + { + return _mm256_set1_epi32(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1); @@ -251,6 +255,10 @@ struct avx2_vector { { return _mm256_set1_epi32(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_epi32(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1); @@ -405,6 +413,10 @@ struct avx2_vector { { return _mm256_set1_ps(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_ps(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1); diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 4028655c..c633b4b9 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -67,7 +67,11 @@ struct avx2_vector { static reg_t zmm_max() { return _mm256_set1_epi64x(type_max()); - } // TODO: this should broadcast bits as is? + } + static reg_t zmm_min() + { + return _mm256_set1_epi64x(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF); @@ -248,6 +252,10 @@ struct avx2_vector { { return _mm256_set1_epi64x(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_epi64x(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF); @@ -439,6 +447,10 @@ struct avx2_vector { { return _mm256_set1_pd(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_pd(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF); diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 8210ef40..15c7c91e 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -46,6 +46,10 @@ struct zmm_vector { { return _mm512_set1_epi16(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi16(type_min()); + } static opmask_t knot_opmask(opmask_t x) { return _knot_mask32(x); @@ -237,6 +241,10 @@ struct zmm_vector { { return _mm512_set1_epi16(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi16(type_min()); + } static opmask_t knot_opmask(opmask_t x) { return _knot_mask32(x); @@ -381,6 +389,10 @@ struct zmm_vector { { return _mm512_set1_epi16(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi16(type_min()); + } static opmask_t knot_opmask(opmask_t x) { @@ -548,42 +560,70 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0); } -X86_SIMD_SORT_INLINE void -avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false) +X86_SIMD_SORT_INLINE void avx512_qsort_fp16(uint16_t *arr, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { + using vtype = zmm_vector; + if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf, uint16_t>( arr, arrsize); } - qsort_, uint16_t>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + if (descending) { + qsort_, uint16_t>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + else { + qsort_, uint16_t>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } X86_SIMD_SORT_INLINE void avx512_qselect_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false) { + using vtype = zmm_vector; + arrsize_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { - qselect_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + if (descending) { + qselect_, uint16_t>( + arr, + k, + 0, + indx_last_elem, + 2 * (arrsize_t)log2(indx_last_elem)); + } + else { + qselect_, uint16_t>( + arr, + k, + 0, + indx_last_elem, + 2 * (arrsize_t)log2(indx_last_elem)); + } } } X86_SIMD_SORT_INLINE void avx512_partial_qsort_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false) { - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); - avx512_qsort_fp16(arr, k - 1); + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending); + avx512_qsort_fp16(arr, k - 1, descending); } #endif // AVX512_QSORT_16BIT diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 96fb965f..8b44e76e 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -58,6 +58,10 @@ struct zmm_vector { { return _mm512_set1_epi32(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi32(type_min()); + } static opmask_t knot_opmask(opmask_t x) { @@ -240,7 +244,11 @@ struct zmm_vector { static reg_t zmm_max() { return _mm512_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? + } + static reg_t zmm_min() + { + return _mm512_set1_epi32(type_min()); + } template static halfreg_t i64gather(__m512i index, void const *base) @@ -424,6 +432,10 @@ struct zmm_vector { { return _mm512_set1_ps(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_ps(type_min()); + } static opmask_t knot_opmask(opmask_t x) { diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 1cd4ca1c..68735c33 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -591,7 +591,11 @@ struct zmm_vector { static reg_t zmm_max() { return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? + } + static reg_t zmm_min() + { + return _mm512_set1_epi64(type_min()); + } static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) @@ -775,6 +779,10 @@ struct zmm_vector { { return _mm512_set1_epi64(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi64(type_min()); + } static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) @@ -963,6 +971,10 @@ struct zmm_vector { { return _mm512_set1_pd(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_pd(type_min()); + } static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 7710dc48..130e28a8 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -47,6 +47,10 @@ struct zmm_vector<_Float16> { { return _mm512_set1_ph(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_ph(type_min()); + } static opmask_t knot_opmask(opmask_t x) { return _knot_mask32(x); @@ -175,51 +179,86 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<_Float16>(_Float16 elem) } template <> -X86_SIMD_SORT_INLINE_ONLY void -replace_inf_with_nan(_Float16 *arr, arrsize_t size, arrsize_t nan_count) +X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, + arrsize_t size, + arrsize_t nan_count, + bool descending) { Fp16Bits val; val.i_ = 0x7c01; - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { - arr[ii] = val.f_; - nan_count -= 1; + + if (descending) { + for (arrsize_t ii = 0; nan_count > 0; ++ii) { + arr[ii] = val.f_; + nan_count -= 1; + } + } + else { + for (arrsize_t ii = size - 1; nan_count > 0; --ii) { + arr[ii] = val.f_; + nan_count -= 1; + } } } /* Specialized template function for _Float16 qsort_*/ -template <> +template X86_SIMD_SORT_INLINE_ONLY void avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan) { + using vtype = zmm_vector<_Float16>; + using comparator = + typename std::conditional, + Comparator>::type; + if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf, _Float16>( - arr, arrsize); + nan_count = replace_nan_with_inf(arr, arrsize); } - qsort_, _Float16>( + + qsort_( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } -template <> +template X86_SIMD_SORT_INLINE_ONLY void avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - arrsize_t indx_last_elem = arrsize - 1; + using vtype = zmm_vector<_Float16>; + using comparator = + typename std::conditional, + Comparator>::type; + + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if constexpr (descending) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } - if (indx_last_elem >= k) { - qselect_, _Float16>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + + if (index_first_elem <= k && index_last_elem >= k) { + qselect_(arr, + k, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize)); } } -template <> +template X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - avx512_qselect(arr, k - 1, arrsize, hasnan); - avx512_qsort(arr, k - 1, hasnan); + avx512_qselect(arr, k - 1, arrsize, hasnan); + avx512_qsort(arr, k - 1, hasnan); } #endif // AVX512FP16_QSORT_16BIT diff --git a/src/xss-common-comparators.hpp b/src/xss-common-comparators.hpp new file mode 100644 index 00000000..bd742cd4 --- /dev/null +++ b/src/xss-common-comparators.hpp @@ -0,0 +1,127 @@ +#ifndef XSS_COMMON_COMPARATORS +#define XSS_COMMON_COMPARATORS + +template +type_t prev_value(type_t value) +{ + // TODO this probably handles non-native float16 wrong + if constexpr (std::is_floating_point::value) { + return std::nextafter(value, -std::numeric_limits::infinity()); + } + else { + if (value > std::numeric_limits::min()) { return value - 1; } + else { + return value; + } + } +} + +template +type_t next_value(type_t value) +{ + // TODO this probably handles non-native float16 wrong + if constexpr (std::is_floating_point::value) { + return std::nextafter(value, std::numeric_limits::infinity()); + } + else { + if (value < std::numeric_limits::max()) { return value + 1; } + else { + return value; + } + } +} + +template +X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); + +template +struct Comparator { + using reg_t = typename vtype::reg_t; + using opmask_t = typename vtype::opmask_t; + using type_t = typename vtype::type_t; + + X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, + const type_t &b) + { + if constexpr (descend) { return comparison_func(b, a); } + else { + return comparison_func(a, b); + } + } + + X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) + { + if constexpr (descend) { return vtype::ge(b, a); } + else { + return vtype::ge(a, b); + } + } + + X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b) + { + if constexpr (descend) { ::COEX(b, a); } + else { + ::COEX(a, b); + } + } + + // Returns a vector of values that would be sorted as far right as possible + // For ascending order, this is the maximum possible value + X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec() + { + if constexpr (descend) { return vtype::zmm_min(); } + else { + return vtype::zmm_max(); + } + } + + // Returns the value that would be leftmost of the two when sorted + // For ascending order, that is the smaller value + X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger) + { + if constexpr (descend) { + UNUSED(smaller); + return larger; + } + else { + UNUSED(larger); + return smaller; + } + } + + // Returns the value that would be rightmost of the two when sorted + // For ascending order, that is the larger value + X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger) + { + if constexpr (descend) { + UNUSED(larger); + return smaller; + } + else { + UNUSED(smaller); + return larger; + } + } + + // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample + // Try just doing the next largest value greater than this seemingly very common value to seperate them out + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median) + { + if constexpr (descend) { return median; } + else { + return next_value(median); + } + } + + // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample + // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median) + { + if constexpr (descend) { return prev_value(median); } + else { + return median; + } + } +}; + +#endif // XSS_COMMON_COMPARATORS diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 3d5e20ea..02522b50 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -37,6 +37,7 @@ #include "xss-common-includes.h" #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" +#include "xss-common-comparators.hpp" template bool is_a_nan(T elem) @@ -98,17 +99,32 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size) } template -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(type_t *arr, arrsize_t size, arrsize_t nan_count) +X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, + arrsize_t size, + arrsize_t nan_count, + bool descending = false) { - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { - if constexpr (std::is_floating_point_v) { - arr[ii] = std::numeric_limits::quiet_NaN(); + if (descending) { + for (arrsize_t ii = 0; nan_count > 0; ++ii) { + if constexpr (std::is_floating_point_v) { + arr[ii] = std::numeric_limits::quiet_NaN(); + } + else { + arr[ii] = 0xFFFF; + } + nan_count -= 1; } - else { - arr[ii] = 0xFFFF; + } + else { + for (arrsize_t ii = size - 1; nan_count > 0; --ii) { + if constexpr (std::is_floating_point_v) { + arr[ii] = std::numeric_limits::quiet_NaN(); + } + else { + arr[ii] = 0xFFFF; + } + nan_count -= 1; } - nan_count -= 1; } } @@ -137,6 +153,26 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size) return size - count - 1; } +/* + * Sort all the NAN's to start of the array and return the index of the first elem + * in the array which is not a nan + */ +template +X86_SIMD_SORT_INLINE arrsize_t move_nans_to_start_of_array(T *arr, + arrsize_t size) +{ + arrsize_t count = 0; + + for (arrsize_t i = 0; i < size; i++) { + if (is_a_nan(arr[i])) { + std::swap(arr[count], arr[i]); + count++; + } + } + + return count; +} + template X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b) { @@ -181,6 +217,7 @@ int avx512_double_compressstore(type_t *left_addr, // Generic function dispatches to AVX2 or AVX512 code template X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, @@ -190,10 +227,11 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, reg_t &smallest_vec, reg_t &biggest_vec) { - typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); + typename vtype::opmask_t right_mask + = comparator::PartitionComparator(curr_vec, pivot_vec); - int amount_ge_pivot - = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); + int amount_ge_pivot = vtype::double_compressstore( + l_store, r_store, right_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); @@ -205,7 +243,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, * Parition an array based on the pivot and returns the index of the * first element that is greater than or equal to the pivot. */ -template +template X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, arrsize_t left, arrsize_t right, @@ -217,7 +255,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { *smallest = std::min(*smallest, arr[left], comparison_func); *biggest = std::max(*biggest, arr[left], comparison_func); - if (!comparison_func(arr[left], pivot)) { + if (!comparator::STDSortComparator(arr[left], pivot)) { std::swap(arr[left], arr[--right]); } else { @@ -238,13 +276,13 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, arrsize_t unpartitioned = right - left - vtype::numlanes; arrsize_t l_store = left; - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec, - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); @@ -277,34 +315,35 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, left += vtype::numlanes; } // partition the current vector and save it on both sides of the array - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - curr_vec, - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + curr_vec, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } /* partition and save vec_left and vec_right */ arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_left, - pivot_vec, - min_vec, - max_vec); + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, + vec_left, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; - amount_ge_pivot = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_right, - pivot_vec, - min_vec, - max_vec); + amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, + vec_right, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; @@ -314,6 +353,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, } template X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, @@ -324,19 +364,21 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, type_t *biggest) { if constexpr (num_unroll == 0) { - return partition(arr, left, right, pivot, smallest, biggest); + return partition( + arr, left, right, pivot, smallest, biggest); } /* Use regular partition for smaller arrays */ if (right - left < 3 * num_unroll * vtype::numlanes) { - return partition(arr, left, right, pivot, smallest, biggest); + return partition( + arr, left, right, pivot, smallest, biggest); } /* make array length divisible by vtype::numlanes, shortening the array */ for (int32_t i = ((right - left) % (vtype::numlanes)); i > 0; --i) { *smallest = std::min(*smallest, arr[left], comparison_func); *biggest = std::max(*biggest, arr[left], comparison_func); - if (!comparison_func(arr[left], pivot)) { + if (!comparator::STDSortComparator(arr[left], pivot)) { std::swap(arr[left], arr[--right]); } else { @@ -418,13 +460,13 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, * */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - curr_vec[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + curr_vec[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -433,25 +475,25 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, /* partition and save vec_left[num_unroll] and vec_right[num_unroll] */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_left[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec_left[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_right[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec_right[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -459,13 +501,13 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, /* partition and save vec_align[vecsToPartition] */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < vecsToPartition; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_align[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec_align[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -478,7 +520,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, template void sort_n(typename vtype::type_t *arr, int N); -template +template static void qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) { @@ -486,19 +528,20 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1, comparison_func); + std::sort(arr + left, arr + right + 1, comparator::STDSortComparator); return; } /* * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold */ if (right + 1 - left <= vtype::network_sort_threshold) { - sort_n( + sort_n( arr + left, (int32_t)(right + 1 - left)); return; } - auto pivot_result = get_pivot_smart(arr, left, right); + auto pivot_result + = get_pivot_smart(arr, left, right); type_t pivot = pivot_result.pivot; if (pivot_result.result == pivot_result_t::Sorted) { return; } @@ -506,18 +549,23 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index - = partition_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); + arrsize_t pivot_index = partition_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); if (pivot_result.result == pivot_result_t::Only2Values) { return; } - if (pivot != smallest) - qsort_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) qsort_(arr, pivot_index, right, max_iters - 1); + type_t leftmostValue = comparator::leftmost(smallest, biggest); + type_t rightmostValue = comparator::rightmost(smallest, biggest); + + if (pivot != leftmostValue) + qsort_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != rightmostValue) + qsort_(arr, pivot_index, right, max_iters - 1); } -template +template X86_SIMD_SORT_INLINE void qselect_(type_t *arr, arrsize_t pos, arrsize_t left, @@ -528,97 +576,144 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1, comparison_func); + std::sort(arr + left, arr + right + 1, comparator::STDSortComparator); return; } /* * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold */ if (right + 1 - left <= vtype::network_sort_threshold) { - sort_n( + sort_n( arr + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot(arr, left, right); + type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index - = partition_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); + arrsize_t pivot_index = partition_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_(arr, pos, pivot_index, right, max_iters - 1); + type_t leftmostValue = comparator::leftmost(smallest, biggest); + type_t rightmostValue = comparator::rightmost(smallest, biggest); + + if ((pivot != leftmostValue) && (pos < pivot_index)) + qselect_( + arr, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != rightmostValue) && (pos >= pivot_index)) + qselect_( + arr, pos, pivot_index, right, max_iters - 1); } // Quicksort routines: -template +template X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) { + using comparator = + typename std::conditional, + Comparator>::type; + if (arrsize > 1) { + arrsize_t nan_count = 0; if constexpr (std::is_floating_point_v) { - arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); - } - else { - UNUSED(hasnan); - qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } + + UNUSED(hasnan); + qsort_( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } // Quick select methods -template +template X86_SIMD_SORT_INLINE void xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - arrsize_t indx_last_elem = arrsize - 1; + using comparator = + typename std::conditional, + Comparator>::type; + + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; + if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if constexpr (descending) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); + } } } + UNUSED(hasnan); - if (indx_last_elem >= k) { - qselect_( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + if (index_first_elem <= k && index_last_elem >= k) { + qselect_(arr, + k, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize)); } } // Partial sort methods: -template +template X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - xss_qselect(arr, k - 1, arrsize, hasnan); - xss_qsort(arr, k - 1, hasnan); + xss_qselect(arr, k - 1, arrsize, hasnan); + xss_qsort(arr, k - 1, hasnan); } #define DEFINE_METHODS(ISA, VTYPE) \ template \ - X86_SIMD_SORT_INLINE void ISA##_qsort( \ - T *arr, arrsize_t size, bool hasnan = false) \ + X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ - xss_qsort(arr, size, hasnan); \ + if (descending) { xss_qsort(arr, size, hasnan); } \ + else { \ + xss_qsort(arr, size, hasnan); \ + } \ } \ template \ - X86_SIMD_SORT_INLINE void ISA##_qselect( \ - T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \ + arrsize_t k, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ - xss_qselect(arr, k, size, hasnan); \ + if (descending) { xss_qselect(arr, k, size, hasnan); } \ + else { \ + xss_qselect(arr, k, size, hasnan); \ + } \ } \ template \ - X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ - T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \ + arrsize_t k, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ - xss_partial_qsort(arr, k, size, hasnan); \ + if (descending) { \ + xss_partial_qsort(arr, k, size, hasnan); \ + } \ + else { \ + xss_partial_qsort(arr, k, size, hasnan); \ + } \ } DEFINE_METHODS(avx512, zmm_vector) diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index d883004a..dd299507 100644 --- a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -7,7 +7,10 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); -template +template X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) { if constexpr (numVecs == 1) { @@ -15,19 +18,19 @@ X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) return; } else if constexpr (numVecs == 2) { - COEX(regs[0], regs[1]); + comparator::COEX(regs[0], regs[1]); } else if constexpr (numVecs == 4) { - optimal_sort_4(regs); + optimal_sort_4(regs); } else if constexpr (numVecs == 8) { - optimal_sort_8(regs); + optimal_sort_8(regs); } else if constexpr (numVecs == 16) { - optimal_sort_16(regs); + optimal_sort_16(regs); } else if constexpr (numVecs == 32) { - optimal_sort_32(regs); + optimal_sort_32(regs); } else { static_assert(numVecs == -1, "should not reach here"); @@ -53,7 +56,11 @@ X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) * merge_n<8> = [a,a,a,a,b,b,b,b] */ -template +template X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) { using reg_t = typename vtype::reg_t; @@ -69,7 +76,7 @@ X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) for (int i = 0; i < numVecs; i++) { reg_t &v = reg[i]; reg_t rev = swizzle::template reverse_n(v); - COEX(rev, v); + comparator::COEX(rev, v); v = swizzle::template merge_n(v, rev); } } @@ -79,15 +86,16 @@ X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) for (int i = 0; i < numVecs; i++) { reg_t &v = reg[i]; reg_t swap = swizzle::template swap_n(v); - COEX(swap, v); + comparator::COEX(swap, v); v = swizzle::template merge_n(v, swap); } } - internal_merge_n_vec(reg); + internal_merge_n_vec(reg); } } template @@ -107,27 +115,30 @@ X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs) // Do compare exchanges X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs / 2; i++) { - COEX(regs[i], regs[numVecs - 1 - i]); + comparator::COEX(regs[i], regs[numVecs - 1 - i]); } - merge_substep_n_vec(regs); - merge_substep_n_vec(regs + numVecs / 2); + merge_substep_n_vec(regs); + merge_substep_n_vec(regs + + numVecs / 2); } template X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs) { // Do cross vector merges - merge_substep_n_vec(regs); + merge_substep_n_vec(regs); // Do internal vector merges - internal_merge_n_vec(regs); + internal_merge_n_vec(regs); } template @@ -138,30 +149,36 @@ X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs) return; } else { - merge_step_n_vec(regs); - merge_n_vec(regs); + merge_step_n_vec(regs); + merge_n_vec(regs); } } -template +template X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs) { /* Run the initial sorting network to sort the columns of the [numVecs x * num_lanes] matrix */ - bitonic_sort_n_vec(vecs); + bitonic_sort_n_vec(vecs); // Merge the vectors using bitonic merging networks - merge_n_vec(vecs); + merge_n_vec(vecs); } -template +template X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) { static_assert(numVecs > 0, "numVecs should be > 0"); if constexpr (numVecs > 1) { if (N * 2 <= numVecs * vtype::numlanes) { - sort_n_vec(arr, N); + sort_n_vec(arr, N); return; } } @@ -186,11 +203,12 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - vecs[i] = vtype::mask_loadu( - vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes); + vecs[i] = vtype::mask_loadu(comparator::rightmostPossibleVec(), + ioMasks[j], + arr + i * vtype::numlanes); } - sort_vectors(vecs); + sort_vectors(vecs); // Unmasked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) @@ -204,7 +222,7 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) } } -template +template X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) { constexpr int numVecs = maxN / vtype::numlanes; @@ -213,6 +231,6 @@ X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) static_assert(powerOfTwo == true && isMultiple == true, "maxN must be vtype::numlanes times a power of 2"); - sort_n_vec(arr, N); + sort_n_vec(arr, N); } #endif diff --git a/src/xss-optimal-networks.hpp b/src/xss-optimal-networks.hpp index bffe493d..e722b1f1 100644 --- a/src/xss-optimal-networks.hpp +++ b/src/xss-optimal-networks.hpp @@ -1,323 +1,328 @@ // All of these sources files are generated from the optimal networks described in // https://bertdobbelaere.github.io/sorting_networks.html -template -X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); - -template +template X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) { - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); - COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[1], vecs[2]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) { - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); - COEX(vecs[4], vecs[6]); - COEX(vecs[5], vecs[7]); - - COEX(vecs[0], vecs[4]); - COEX(vecs[1], vecs[5]); - COEX(vecs[2], vecs[6]); - COEX(vecs[3], vecs[7]); - - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); - COEX(vecs[4], vecs[5]); - COEX(vecs[6], vecs[7]); - - COEX(vecs[2], vecs[4]); - COEX(vecs[3], vecs[5]); - - COEX(vecs[1], vecs[4]); - COEX(vecs[3], vecs[6]); - - COEX(vecs[1], vecs[2]); - COEX(vecs[3], vecs[4]); - COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + + comparator::COEX(vecs[0], vecs[4]); + comparator::COEX(vecs[1], vecs[5]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[3], vecs[7]); + + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[5]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[3], vecs[6]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) { - COEX(vecs[0], vecs[13]); - COEX(vecs[1], vecs[12]); - COEX(vecs[2], vecs[15]); - COEX(vecs[3], vecs[14]); - COEX(vecs[4], vecs[8]); - COEX(vecs[5], vecs[6]); - COEX(vecs[7], vecs[11]); - COEX(vecs[9], vecs[10]); - - COEX(vecs[0], vecs[5]); - COEX(vecs[1], vecs[7]); - COEX(vecs[2], vecs[9]); - COEX(vecs[3], vecs[4]); - COEX(vecs[6], vecs[13]); - COEX(vecs[8], vecs[14]); - COEX(vecs[10], vecs[15]); - COEX(vecs[11], vecs[12]); - - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); - COEX(vecs[4], vecs[5]); - COEX(vecs[6], vecs[8]); - COEX(vecs[7], vecs[9]); - COEX(vecs[10], vecs[11]); - COEX(vecs[12], vecs[13]); - COEX(vecs[14], vecs[15]); - - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); - COEX(vecs[4], vecs[10]); - COEX(vecs[5], vecs[11]); - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); - COEX(vecs[12], vecs[14]); - COEX(vecs[13], vecs[15]); - - COEX(vecs[1], vecs[2]); - COEX(vecs[3], vecs[12]); - COEX(vecs[4], vecs[6]); - COEX(vecs[5], vecs[7]); - COEX(vecs[8], vecs[10]); - COEX(vecs[9], vecs[11]); - COEX(vecs[13], vecs[14]); - - COEX(vecs[1], vecs[4]); - COEX(vecs[2], vecs[6]); - COEX(vecs[5], vecs[8]); - COEX(vecs[7], vecs[10]); - COEX(vecs[9], vecs[13]); - COEX(vecs[11], vecs[14]); - - COEX(vecs[2], vecs[4]); - COEX(vecs[3], vecs[6]); - COEX(vecs[9], vecs[12]); - COEX(vecs[11], vecs[13]); - - COEX(vecs[3], vecs[5]); - COEX(vecs[6], vecs[8]); - COEX(vecs[7], vecs[9]); - COEX(vecs[10], vecs[12]); - - COEX(vecs[3], vecs[4]); - COEX(vecs[5], vecs[6]); - COEX(vecs[7], vecs[8]); - COEX(vecs[9], vecs[10]); - COEX(vecs[11], vecs[12]); - - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[0], vecs[13]); + comparator::COEX(vecs[1], vecs[12]); + comparator::COEX(vecs[2], vecs[15]); + comparator::COEX(vecs[3], vecs[14]); + comparator::COEX(vecs[4], vecs[8]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[11]); + comparator::COEX(vecs[9], vecs[10]); + + comparator::COEX(vecs[0], vecs[5]); + comparator::COEX(vecs[1], vecs[7]); + comparator::COEX(vecs[2], vecs[9]); + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[6], vecs[13]); + comparator::COEX(vecs[8], vecs[14]); + comparator::COEX(vecs[10], vecs[15]); + comparator::COEX(vecs[11], vecs[12]); + + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[8]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[12], vecs[13]); + comparator::COEX(vecs[14], vecs[15]); + + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[10]); + comparator::COEX(vecs[5], vecs[11]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[12], vecs[14]); + comparator::COEX(vecs[13], vecs[15]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[12]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + comparator::COEX(vecs[8], vecs[10]); + comparator::COEX(vecs[9], vecs[11]); + comparator::COEX(vecs[13], vecs[14]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[5], vecs[8]); + comparator::COEX(vecs[7], vecs[10]); + comparator::COEX(vecs[9], vecs[13]); + comparator::COEX(vecs[11], vecs[14]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[6]); + comparator::COEX(vecs[9], vecs[12]); + comparator::COEX(vecs[11], vecs[13]); + + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[6], vecs[8]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[10], vecs[12]); + + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[12]); + + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) { - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); - COEX(vecs[4], vecs[5]); - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); - COEX(vecs[10], vecs[11]); - COEX(vecs[12], vecs[13]); - COEX(vecs[14], vecs[15]); - COEX(vecs[16], vecs[17]); - COEX(vecs[18], vecs[19]); - COEX(vecs[20], vecs[21]); - COEX(vecs[22], vecs[23]); - COEX(vecs[24], vecs[25]); - COEX(vecs[26], vecs[27]); - COEX(vecs[28], vecs[29]); - COEX(vecs[30], vecs[31]); - - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); - COEX(vecs[4], vecs[6]); - COEX(vecs[5], vecs[7]); - COEX(vecs[8], vecs[10]); - COEX(vecs[9], vecs[11]); - COEX(vecs[12], vecs[14]); - COEX(vecs[13], vecs[15]); - COEX(vecs[16], vecs[18]); - COEX(vecs[17], vecs[19]); - COEX(vecs[20], vecs[22]); - COEX(vecs[21], vecs[23]); - COEX(vecs[24], vecs[26]); - COEX(vecs[25], vecs[27]); - COEX(vecs[28], vecs[30]); - COEX(vecs[29], vecs[31]); - - COEX(vecs[0], vecs[4]); - COEX(vecs[1], vecs[5]); - COEX(vecs[2], vecs[6]); - COEX(vecs[3], vecs[7]); - COEX(vecs[8], vecs[12]); - COEX(vecs[9], vecs[13]); - COEX(vecs[10], vecs[14]); - COEX(vecs[11], vecs[15]); - COEX(vecs[16], vecs[20]); - COEX(vecs[17], vecs[21]); - COEX(vecs[18], vecs[22]); - COEX(vecs[19], vecs[23]); - COEX(vecs[24], vecs[28]); - COEX(vecs[25], vecs[29]); - COEX(vecs[26], vecs[30]); - COEX(vecs[27], vecs[31]); - - COEX(vecs[0], vecs[8]); - COEX(vecs[1], vecs[9]); - COEX(vecs[2], vecs[10]); - COEX(vecs[3], vecs[11]); - COEX(vecs[4], vecs[12]); - COEX(vecs[5], vecs[13]); - COEX(vecs[6], vecs[14]); - COEX(vecs[7], vecs[15]); - COEX(vecs[16], vecs[24]); - COEX(vecs[17], vecs[25]); - COEX(vecs[18], vecs[26]); - COEX(vecs[19], vecs[27]); - COEX(vecs[20], vecs[28]); - COEX(vecs[21], vecs[29]); - COEX(vecs[22], vecs[30]); - COEX(vecs[23], vecs[31]); - - COEX(vecs[0], vecs[16]); - COEX(vecs[1], vecs[8]); - COEX(vecs[2], vecs[4]); - COEX(vecs[3], vecs[12]); - COEX(vecs[5], vecs[10]); - COEX(vecs[6], vecs[9]); - COEX(vecs[7], vecs[14]); - COEX(vecs[11], vecs[13]); - COEX(vecs[15], vecs[31]); - COEX(vecs[17], vecs[24]); - COEX(vecs[18], vecs[20]); - COEX(vecs[19], vecs[28]); - COEX(vecs[21], vecs[26]); - COEX(vecs[22], vecs[25]); - COEX(vecs[23], vecs[30]); - COEX(vecs[27], vecs[29]); - - COEX(vecs[1], vecs[2]); - COEX(vecs[3], vecs[5]); - COEX(vecs[4], vecs[8]); - COEX(vecs[6], vecs[22]); - COEX(vecs[7], vecs[11]); - COEX(vecs[9], vecs[25]); - COEX(vecs[10], vecs[12]); - COEX(vecs[13], vecs[14]); - COEX(vecs[17], vecs[18]); - COEX(vecs[19], vecs[21]); - COEX(vecs[20], vecs[24]); - COEX(vecs[23], vecs[27]); - COEX(vecs[26], vecs[28]); - COEX(vecs[29], vecs[30]); - - COEX(vecs[1], vecs[17]); - COEX(vecs[2], vecs[18]); - COEX(vecs[3], vecs[19]); - COEX(vecs[4], vecs[20]); - COEX(vecs[5], vecs[10]); - COEX(vecs[7], vecs[23]); - COEX(vecs[8], vecs[24]); - COEX(vecs[11], vecs[27]); - COEX(vecs[12], vecs[28]); - COEX(vecs[13], vecs[29]); - COEX(vecs[14], vecs[30]); - COEX(vecs[21], vecs[26]); - - COEX(vecs[3], vecs[17]); - COEX(vecs[4], vecs[16]); - COEX(vecs[5], vecs[21]); - COEX(vecs[6], vecs[18]); - COEX(vecs[7], vecs[9]); - COEX(vecs[8], vecs[20]); - COEX(vecs[10], vecs[26]); - COEX(vecs[11], vecs[23]); - COEX(vecs[13], vecs[25]); - COEX(vecs[14], vecs[28]); - COEX(vecs[15], vecs[27]); - COEX(vecs[22], vecs[24]); - - COEX(vecs[1], vecs[4]); - COEX(vecs[3], vecs[8]); - COEX(vecs[5], vecs[16]); - COEX(vecs[7], vecs[17]); - COEX(vecs[9], vecs[21]); - COEX(vecs[10], vecs[22]); - COEX(vecs[11], vecs[19]); - COEX(vecs[12], vecs[20]); - COEX(vecs[14], vecs[24]); - COEX(vecs[15], vecs[26]); - COEX(vecs[23], vecs[28]); - COEX(vecs[27], vecs[30]); - - COEX(vecs[2], vecs[5]); - COEX(vecs[7], vecs[8]); - COEX(vecs[9], vecs[18]); - COEX(vecs[11], vecs[17]); - COEX(vecs[12], vecs[16]); - COEX(vecs[13], vecs[22]); - COEX(vecs[14], vecs[20]); - COEX(vecs[15], vecs[19]); - COEX(vecs[23], vecs[24]); - COEX(vecs[26], vecs[29]); - - COEX(vecs[2], vecs[4]); - COEX(vecs[6], vecs[12]); - COEX(vecs[9], vecs[16]); - COEX(vecs[10], vecs[11]); - COEX(vecs[13], vecs[17]); - COEX(vecs[14], vecs[18]); - COEX(vecs[15], vecs[22]); - COEX(vecs[19], vecs[25]); - COEX(vecs[20], vecs[21]); - COEX(vecs[27], vecs[29]); - - COEX(vecs[5], vecs[6]); - COEX(vecs[8], vecs[12]); - COEX(vecs[9], vecs[10]); - COEX(vecs[11], vecs[13]); - COEX(vecs[14], vecs[16]); - COEX(vecs[15], vecs[17]); - COEX(vecs[18], vecs[20]); - COEX(vecs[19], vecs[23]); - COEX(vecs[21], vecs[22]); - COEX(vecs[25], vecs[26]); - - COEX(vecs[3], vecs[5]); - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); - COEX(vecs[10], vecs[12]); - COEX(vecs[11], vecs[14]); - COEX(vecs[13], vecs[16]); - COEX(vecs[15], vecs[18]); - COEX(vecs[17], vecs[20]); - COEX(vecs[19], vecs[21]); - COEX(vecs[22], vecs[23]); - COEX(vecs[24], vecs[25]); - COEX(vecs[26], vecs[28]); - - COEX(vecs[3], vecs[4]); - COEX(vecs[5], vecs[6]); - COEX(vecs[7], vecs[8]); - COEX(vecs[9], vecs[10]); - COEX(vecs[11], vecs[12]); - COEX(vecs[13], vecs[14]); - COEX(vecs[15], vecs[16]); - COEX(vecs[17], vecs[18]); - COEX(vecs[19], vecs[20]); - COEX(vecs[21], vecs[22]); - COEX(vecs[23], vecs[24]); - COEX(vecs[25], vecs[26]); - COEX(vecs[27], vecs[28]); + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[12], vecs[13]); + comparator::COEX(vecs[14], vecs[15]); + comparator::COEX(vecs[16], vecs[17]); + comparator::COEX(vecs[18], vecs[19]); + comparator::COEX(vecs[20], vecs[21]); + comparator::COEX(vecs[22], vecs[23]); + comparator::COEX(vecs[24], vecs[25]); + comparator::COEX(vecs[26], vecs[27]); + comparator::COEX(vecs[28], vecs[29]); + comparator::COEX(vecs[30], vecs[31]); + + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + comparator::COEX(vecs[8], vecs[10]); + comparator::COEX(vecs[9], vecs[11]); + comparator::COEX(vecs[12], vecs[14]); + comparator::COEX(vecs[13], vecs[15]); + comparator::COEX(vecs[16], vecs[18]); + comparator::COEX(vecs[17], vecs[19]); + comparator::COEX(vecs[20], vecs[22]); + comparator::COEX(vecs[21], vecs[23]); + comparator::COEX(vecs[24], vecs[26]); + comparator::COEX(vecs[25], vecs[27]); + comparator::COEX(vecs[28], vecs[30]); + comparator::COEX(vecs[29], vecs[31]); + + comparator::COEX(vecs[0], vecs[4]); + comparator::COEX(vecs[1], vecs[5]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[3], vecs[7]); + comparator::COEX(vecs[8], vecs[12]); + comparator::COEX(vecs[9], vecs[13]); + comparator::COEX(vecs[10], vecs[14]); + comparator::COEX(vecs[11], vecs[15]); + comparator::COEX(vecs[16], vecs[20]); + comparator::COEX(vecs[17], vecs[21]); + comparator::COEX(vecs[18], vecs[22]); + comparator::COEX(vecs[19], vecs[23]); + comparator::COEX(vecs[24], vecs[28]); + comparator::COEX(vecs[25], vecs[29]); + comparator::COEX(vecs[26], vecs[30]); + comparator::COEX(vecs[27], vecs[31]); + + comparator::COEX(vecs[0], vecs[8]); + comparator::COEX(vecs[1], vecs[9]); + comparator::COEX(vecs[2], vecs[10]); + comparator::COEX(vecs[3], vecs[11]); + comparator::COEX(vecs[4], vecs[12]); + comparator::COEX(vecs[5], vecs[13]); + comparator::COEX(vecs[6], vecs[14]); + comparator::COEX(vecs[7], vecs[15]); + comparator::COEX(vecs[16], vecs[24]); + comparator::COEX(vecs[17], vecs[25]); + comparator::COEX(vecs[18], vecs[26]); + comparator::COEX(vecs[19], vecs[27]); + comparator::COEX(vecs[20], vecs[28]); + comparator::COEX(vecs[21], vecs[29]); + comparator::COEX(vecs[22], vecs[30]); + comparator::COEX(vecs[23], vecs[31]); + + comparator::COEX(vecs[0], vecs[16]); + comparator::COEX(vecs[1], vecs[8]); + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[12]); + comparator::COEX(vecs[5], vecs[10]); + comparator::COEX(vecs[6], vecs[9]); + comparator::COEX(vecs[7], vecs[14]); + comparator::COEX(vecs[11], vecs[13]); + comparator::COEX(vecs[15], vecs[31]); + comparator::COEX(vecs[17], vecs[24]); + comparator::COEX(vecs[18], vecs[20]); + comparator::COEX(vecs[19], vecs[28]); + comparator::COEX(vecs[21], vecs[26]); + comparator::COEX(vecs[22], vecs[25]); + comparator::COEX(vecs[23], vecs[30]); + comparator::COEX(vecs[27], vecs[29]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[4], vecs[8]); + comparator::COEX(vecs[6], vecs[22]); + comparator::COEX(vecs[7], vecs[11]); + comparator::COEX(vecs[9], vecs[25]); + comparator::COEX(vecs[10], vecs[12]); + comparator::COEX(vecs[13], vecs[14]); + comparator::COEX(vecs[17], vecs[18]); + comparator::COEX(vecs[19], vecs[21]); + comparator::COEX(vecs[20], vecs[24]); + comparator::COEX(vecs[23], vecs[27]); + comparator::COEX(vecs[26], vecs[28]); + comparator::COEX(vecs[29], vecs[30]); + + comparator::COEX(vecs[1], vecs[17]); + comparator::COEX(vecs[2], vecs[18]); + comparator::COEX(vecs[3], vecs[19]); + comparator::COEX(vecs[4], vecs[20]); + comparator::COEX(vecs[5], vecs[10]); + comparator::COEX(vecs[7], vecs[23]); + comparator::COEX(vecs[8], vecs[24]); + comparator::COEX(vecs[11], vecs[27]); + comparator::COEX(vecs[12], vecs[28]); + comparator::COEX(vecs[13], vecs[29]); + comparator::COEX(vecs[14], vecs[30]); + comparator::COEX(vecs[21], vecs[26]); + + comparator::COEX(vecs[3], vecs[17]); + comparator::COEX(vecs[4], vecs[16]); + comparator::COEX(vecs[5], vecs[21]); + comparator::COEX(vecs[6], vecs[18]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[8], vecs[20]); + comparator::COEX(vecs[10], vecs[26]); + comparator::COEX(vecs[11], vecs[23]); + comparator::COEX(vecs[13], vecs[25]); + comparator::COEX(vecs[14], vecs[28]); + comparator::COEX(vecs[15], vecs[27]); + comparator::COEX(vecs[22], vecs[24]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[3], vecs[8]); + comparator::COEX(vecs[5], vecs[16]); + comparator::COEX(vecs[7], vecs[17]); + comparator::COEX(vecs[9], vecs[21]); + comparator::COEX(vecs[10], vecs[22]); + comparator::COEX(vecs[11], vecs[19]); + comparator::COEX(vecs[12], vecs[20]); + comparator::COEX(vecs[14], vecs[24]); + comparator::COEX(vecs[15], vecs[26]); + comparator::COEX(vecs[23], vecs[28]); + comparator::COEX(vecs[27], vecs[30]); + + comparator::COEX(vecs[2], vecs[5]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[18]); + comparator::COEX(vecs[11], vecs[17]); + comparator::COEX(vecs[12], vecs[16]); + comparator::COEX(vecs[13], vecs[22]); + comparator::COEX(vecs[14], vecs[20]); + comparator::COEX(vecs[15], vecs[19]); + comparator::COEX(vecs[23], vecs[24]); + comparator::COEX(vecs[26], vecs[29]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[6], vecs[12]); + comparator::COEX(vecs[9], vecs[16]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[13], vecs[17]); + comparator::COEX(vecs[14], vecs[18]); + comparator::COEX(vecs[15], vecs[22]); + comparator::COEX(vecs[19], vecs[25]); + comparator::COEX(vecs[20], vecs[21]); + comparator::COEX(vecs[27], vecs[29]); + + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[8], vecs[12]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[13]); + comparator::COEX(vecs[14], vecs[16]); + comparator::COEX(vecs[15], vecs[17]); + comparator::COEX(vecs[18], vecs[20]); + comparator::COEX(vecs[19], vecs[23]); + comparator::COEX(vecs[21], vecs[22]); + comparator::COEX(vecs[25], vecs[26]); + + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[10], vecs[12]); + comparator::COEX(vecs[11], vecs[14]); + comparator::COEX(vecs[13], vecs[16]); + comparator::COEX(vecs[15], vecs[18]); + comparator::COEX(vecs[17], vecs[20]); + comparator::COEX(vecs[19], vecs[21]); + comparator::COEX(vecs[22], vecs[23]); + comparator::COEX(vecs[24], vecs[25]); + comparator::COEX(vecs[26], vecs[28]); + + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[12]); + comparator::COEX(vecs[13], vecs[14]); + comparator::COEX(vecs[15], vecs[16]); + comparator::COEX(vecs[17], vecs[18]); + comparator::COEX(vecs[19], vecs[20]); + comparator::COEX(vecs[21], vecs[22]); + comparator::COEX(vecs[23], vecs[24]); + comparator::COEX(vecs[25], vecs[26]); + comparator::COEX(vecs[27], vecs[28]); } diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index 59dc0489..6ce0b887 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -2,6 +2,7 @@ #define XSS_PIVOT_SELECTION #include "xss-network-qsort.hpp" +#include "xss-common-comparators.hpp" enum class pivot_result_t : int { Normal, Sorted, Only2Values }; @@ -19,21 +20,6 @@ struct pivot_results { } }; -template -type_t next_value(type_t value) -{ - // TODO this probably handles non-native float16 wrong - if constexpr (std::is_floating_point::value) { - return std::nextafter(value, std::numeric_limits::infinity()); - } - else { - if (value < std::numeric_limits::max()) { return value + 1; } - else { - return value; - } - } -} - template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); @@ -98,14 +84,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr, return data[vtype::numlanes / 2]; } -template +template X86_SIMD_SORT_INLINE pivot_results get_pivot_near_constant(type_t *arr, type_t commonValue, const arrsize_t left, const arrsize_t right); -template +template X86_SIMD_SORT_INLINE pivot_results get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) { @@ -127,7 +113,9 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) } // Sort the samples - sort_vectors(vecs); + // Note that this intentionally uses the AscendingComparator + // instead of the provided comparator + sort_vectors, numVecs>(vecs); type_t samples[N]; for (int i = 0; i < numVecs; i++) { @@ -141,21 +129,24 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) if (smallest == largest) { // We have a very unlucky sample, or the array is constant / near constant // Run a special function meant to deal with this situation - return get_pivot_near_constant(arr, median, left, right); + return get_pivot_near_constant( + arr, median, left, right); } else if (median != smallest && median != largest) { // We have a normal sample; use it's median return pivot_results(median); } else if (median == smallest) { - // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample - // Try just doing the next largest value greater than this seemingly very common value to seperate them out - return pivot_results(next_value(median)); + // We will either return the median or the next value larger than the median, + // depending on the comparator (see xss-common-comparators.hpp for more details) + return pivot_results( + comparator::choosePivotMedianIsSmallest(median)); } else if (median == largest) { - // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample - // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - return pivot_results(median); + // We will either return the median or the next value smaller than the median, + // depending on the comparator (see xss-common-comparators.hpp for more details) + return pivot_results( + comparator::choosePivotMedianIsLargest(median)); } else { // Should be unreachable @@ -167,7 +158,7 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) } // Handles the case where we seem to have a near-constant array, since our sample of the array was constant -template +template X86_SIMD_SORT_INLINE pivot_results get_pivot_near_constant(type_t *arr, type_t commonValue, @@ -228,9 +219,11 @@ get_pivot_near_constant(type_t *arr, if (index == right + 1) { // The array contains only 2 values // We must pick the larger one, else the right partition is empty - // We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the larger value + // (note that larger is determined using the provided comparator, so it might actually be the smaller one) + // We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the chosen value // TODO this logic now assumes we use greater than or equal to specifically when partitioning, might be worth noting that somewhere - type_t pivot = std::max(value1, commonValue, comparison_func); + type_t pivot + = std::max(value1, commonValue, comparator::STDSortComparator); return pivot_results(pivot, pivot_result_t::Only2Values); } diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index 9638387f..4fdb87fc 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -46,12 +46,22 @@ template void IS_ARR_PARTITIONED(std::vector arr, size_t k, T true_kth, - std::string type) + std::string type, + bool descending = false) { - auto cmp_eq = compare>(); - auto cmp_less = compare>(); - auto cmp_leq = compare>(); - auto cmp_geq = compare>(); + std::function cmp_eq, cmp_less, cmp_leq, cmp_geq; + cmp_eq = compare>(); + + if (!descending) { + cmp_less = compare>(); + cmp_leq = compare>(); + cmp_geq = compare>(); + } + else { + cmp_less = compare>(); + cmp_leq = compare>(); + cmp_geq = compare>(); + } // 1) arr[k] == sorted[k]; use memcmp to handle nan if (!cmp_eq(arr[k], true_kth)) { diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index d1428ef8..5d4ba587 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -27,18 +27,44 @@ class simdsort : public ::testing::Test { TYPED_TEST_SUITE_P(simdsort); -TYPED_TEST_P(simdsort, test_qsort) +TYPED_TEST_P(simdsort, test_qsort_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { - std::vector arr = get_array(type, size); + std::vector basearr = get_array(type, size); + + // Ascending order + std::vector arr = basearr; std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); x86simdsort::qsort(arr.data(), arr.size(), hasnan); IS_SORTED(sortedarr, arr, type); + + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_qsort_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + std::vector basearr = get_array(type, size); + + // Descending order + std::vector arr = basearr; + std::vector sortedarr = arr; + std::sort(sortedarr.begin(), + sortedarr.end(), + compare>()); + x86simdsort::qsort(arr.data(), arr.size(), hasnan, true); + IS_SORTED(sortedarr, arr, type); + arr.clear(); sortedarr.clear(); } @@ -63,13 +89,16 @@ TYPED_TEST_P(simdsort, test_argsort) } } -TYPED_TEST_P(simdsort, test_qselect) +TYPED_TEST_P(simdsort, test_qselect_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { size_t k = rand() % size; - std::vector arr = get_array(type, size); + std::vector basearr = get_array(type, size); + + // Ascending order + std::vector arr = basearr; std::vector sortedarr = arr; std::nth_element(sortedarr.begin(), sortedarr.begin() + k, @@ -77,6 +106,31 @@ TYPED_TEST_P(simdsort, test_qselect) compare>()); x86simdsort::qselect(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); + + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_qselect_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + size_t k = rand() % size; + std::vector basearr = get_array(type, size); + + // Descending order + std::vector arr = basearr; + std::vector sortedarr = arr; + std::nth_element(sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare>()); + x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); + arr.clear(); sortedarr.clear(); } @@ -103,20 +157,48 @@ TYPED_TEST_P(simdsort, test_argselect) } } -TYPED_TEST_P(simdsort, test_partial_qsort) +TYPED_TEST_P(simdsort, test_partial_qsort_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { // k should be at least 1 size_t k = std::max((size_t)1, rand() % size); - std::vector arr = get_array(type, size); + std::vector basearr = get_array(type, size); + + // Ascending order + std::vector arr = basearr; std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); + + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_partial_qsort_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + // k should be at least 1 + size_t k = std::max((size_t)1, rand() % size); + std::vector basearr = get_array(type, size); + + // Descending order + std::vector arr = basearr; + std::vector sortedarr = arr; + std::sort(sortedarr.begin(), + sortedarr.end(), + compare>()); + x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan, true); + IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); + arr.clear(); sortedarr.clear(); } @@ -157,11 +239,14 @@ TYPED_TEST_P(simdsort, test_comparator) } REGISTER_TYPED_TEST_SUITE_P(simdsort, - test_qsort, + test_qsort_ascending, + test_qsort_descending, test_argsort, test_argselect, - test_qselect, - test_partial_qsort, + test_qselect_ascending, + test_qselect_descending, + test_partial_qsort_ascending, + test_partial_qsort_descending, test_comparator); using QSortTestTypes = testing::Types