@@ -612,91 +612,71 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
612612}
613613
614614// Quicksort routines:
615- template <typename vtype, typename T>
616- X86_SIMD_SORT_INLINE void
617- xss_qsort (T *arr, arrsize_t arrsize, bool hasnan, bool descending)
615+ template <typename vtype, typename T, bool descending = false >
616+ X86_SIMD_SORT_INLINE void xss_qsort (T *arr, arrsize_t arrsize, bool hasnan)
618617{
618+ using comparator =
619+ typename std::conditional<descending,
620+ DescendingComparator<vtype>,
621+ AscendingComparator<vtype>>::type;
622+
619623 if (arrsize > 1 ) {
624+ arrsize_t nan_count = 0 ;
620625 if constexpr (std::is_floating_point_v<T>) {
621- arrsize_t nan_count = 0 ;
622626 if (UNLIKELY (hasnan)) {
623627 nan_count = replace_nan_with_inf<vtype>(arr, arrsize);
624628 }
625- if (descending) {
626- qsort_<vtype, DescendingComparator<vtype>, T>(
627- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
628- }
629- else {
630- qsort_<vtype, AscendingComparator<vtype>, T>(
631- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
632- }
633- replace_inf_with_nan (arr, arrsize, nan_count, descending);
634- }
635- else {
636- UNUSED (hasnan);
637- if (descending) {
638- qsort_<vtype, DescendingComparator<vtype>, T>(
639- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
640- }
641- else {
642- qsort_<vtype, AscendingComparator<vtype>, T>(
643- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
644- }
645629 }
630+
631+ UNUSED (hasnan);
632+ qsort_<vtype, comparator, T>(
633+ arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
634+
635+ replace_inf_with_nan (arr, arrsize, nan_count, descending);
646636 }
647637}
648638
649639// Quick select methods
650- template <typename vtype, typename T>
651- X86_SIMD_SORT_INLINE void xss_qselect (
652- T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending )
640+ template <typename vtype, typename T, bool descending = false >
641+ X86_SIMD_SORT_INLINE void
642+ xss_qselect ( T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
653643{
654- if (descending) {
655- arrsize_t index_first_elem = 0 ;
656- if constexpr (std::is_floating_point_v<T>) {
657- if (UNLIKELY (hasnan)) {
658- index_first_elem = move_nans_to_start_of_array (arr, arrsize);
659- }
660- }
644+ using comparator =
645+ typename std::conditional<descending,
646+ DescendingComparator<vtype>,
647+ AscendingComparator<vtype>>::type;
661648
662- arrsize_t size_without_nans = arrsize - index_first_elem;
649+ arrsize_t index_first_elem = 0 ;
650+ arrsize_t index_last_elem = arrsize - 1 ;
663651
664- UNUSED (hasnan);
665- if (index_first_elem <= k) {
666- qselect_<vtype, DescendingComparator<vtype>, T>(
667- arr,
668- k,
669- index_first_elem,
670- arrsize - 1 ,
671- 2 * (arrsize_t )log2 (size_without_nans));
672- }
673- }
674- else {
675- arrsize_t indx_last_elem = arrsize - 1 ;
676- if constexpr (std::is_floating_point_v<T>) {
677- if (UNLIKELY (hasnan)) {
678- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
652+ if constexpr (std::is_floating_point_v<T>) {
653+ if (UNLIKELY (hasnan)) {
654+ if constexpr (descending) {
655+ index_first_elem = move_nans_to_start_of_array (arr, arrsize);
656+ }
657+ else {
658+ index_last_elem = move_nans_to_end_of_array (arr, arrsize);
679659 }
680- }
681- UNUSED (hasnan);
682- if (indx_last_elem >= k) {
683- qselect_<vtype, AscendingComparator<vtype>, T>(
684- arr,
685- k,
686- 0 ,
687- indx_last_elem,
688- 2 * (arrsize_t )log2 (indx_last_elem));
689660 }
690661 }
662+
663+ UNUSED (hasnan);
664+ if (index_first_elem <= k && index_last_elem >= k) {
665+ qselect_<vtype, comparator, T>(arr,
666+ k,
667+ index_first_elem,
668+ index_last_elem,
669+ 2 * (arrsize_t )log2 (arrsize));
670+ }
691671}
692672
693673// Partial sort methods:
694- template <typename vtype, typename T>
695- X86_SIMD_SORT_INLINE void xss_partial_qsort (
696- T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending )
674+ template <typename vtype, typename T, bool descending = false >
675+ X86_SIMD_SORT_INLINE void
676+ xss_partial_qsort ( T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
697677{
698- xss_qselect<vtype, T>(arr, k - 1 , arrsize, hasnan, descending );
699- xss_qsort<vtype, T>(arr, k - 1 , hasnan, descending );
678+ xss_qselect<vtype, T, descending >(arr, k - 1 , arrsize, hasnan);
679+ xss_qsort<vtype, T, descending >(arr, k - 1 , hasnan);
700680}
701681
702682#define DEFINE_METHODS (ISA, VTYPE ) \
@@ -706,7 +686,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(
706686 bool hasnan = false , \
707687 bool descending = false ) \
708688 { \
709- xss_qsort<VTYPE, T>(arr, size, hasnan, descending); \
689+ if (descending) { xss_qsort<VTYPE, T, true >(arr, size, hasnan); } \
690+ else { \
691+ xss_qsort<VTYPE, T, false >(arr, size, hasnan); \
692+ } \
710693 } \
711694 template <typename T> \
712695 X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \
@@ -715,7 +698,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(
715698 bool hasnan = false , \
716699 bool descending = false ) \
717700 { \
718- xss_qselect<VTYPE, T>(arr, k, size, hasnan, descending); \
701+ if (descending) { xss_qselect<VTYPE, T, true >(arr, k, size, hasnan); } \
702+ else { \
703+ xss_qselect<VTYPE, T, false >(arr, k, size, hasnan); \
704+ } \
719705 } \
720706 template <typename T> \
721707 X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \
@@ -724,7 +710,12 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(
724710 bool hasnan = false , \
725711 bool descending = false ) \
726712 { \
727- xss_partial_qsort<VTYPE, T>(arr, k, size, hasnan, descending); \
713+ if (descending) { \
714+ xss_partial_qsort<VTYPE, T, true >(arr, k, size, hasnan); \
715+ } \
716+ else { \
717+ xss_partial_qsort<VTYPE, T, false >(arr, k, size, hasnan); \
718+ } \
728719 }
729720
730721DEFINE_METHODS (avx512, zmm_vector<T>)
0 commit comments