Skip to content

Commit 4e41904

Browse files
committed
Simplified logic for ascending/descending comparators
1 parent aadca4c commit 4e41904

File tree

3 files changed

+106
-121
lines changed

3 files changed

+106
-121
lines changed

lib/x86simdsort-spr.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ namespace avx512 {
77
template <>
88
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
99
{
10-
avx512_qsort(arr, size, hasnan, descending);
10+
if (descending) { avx512_qsort<true>(arr, size, hasnan); }
11+
else {
12+
avx512_qsort<false>(arr, size, hasnan);
13+
}
1114
}
1215
template <>
1316
void qselect(_Float16 *arr,
@@ -16,7 +19,10 @@ namespace avx512 {
1619
bool hasnan,
1720
bool descending)
1821
{
19-
avx512_qselect(arr, k, arrsize, hasnan, descending);
22+
if (descending) { avx512_qselect<true>(arr, k, arrsize, hasnan); }
23+
else {
24+
avx512_qselect<false>(arr, k, arrsize, hasnan);
25+
}
2026
}
2127
template <>
2228
void partial_qsort(_Float16 *arr,
@@ -25,7 +31,10 @@ namespace avx512 {
2531
bool hasnan,
2632
bool descending)
2733
{
28-
avx512_partial_qsort(arr, k, arrsize, hasnan, descending);
34+
if (descending) { avx512_partial_qsort<true>(arr, k, arrsize, hasnan); }
35+
else {
36+
avx512_partial_qsort<false>(arr, k, arrsize, hasnan);
37+
}
2938
}
3039
} // namespace avx512
3140
} // namespace xss

src/avx512fp16-16bit-qsort.hpp

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -201,79 +201,64 @@ X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr,
201201
}
202202
}
203203
/* Specialized template function for _Float16 qsort_*/
204-
template <>
204+
template <bool descending = false>
205205
X86_SIMD_SORT_INLINE_ONLY void
206-
avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending)
206+
avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan)
207207
{
208208
using vtype = zmm_vector<_Float16>;
209+
using comparator =
210+
typename std::conditional<descending,
211+
DescendingComparator<vtype>,
212+
AscendingComparator<vtype>>::type;
209213

210214
if (arrsize > 1) {
211215
arrsize_t nan_count = 0;
212216
if (UNLIKELY(hasnan)) {
213-
nan_count = replace_nan_with_inf<vtype, _Float16>(arr, arrsize);
214-
}
215-
if (descending) {
216-
qsort_<vtype, DescendingComparator<vtype>, _Float16>(
217-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
218-
}
219-
else {
220-
qsort_<vtype, AscendingComparator<vtype>, _Float16>(
221-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
217+
nan_count = replace_nan_with_inf<vtype>(arr, arrsize);
222218
}
219+
220+
qsort_<vtype, comparator, _Float16>(
221+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
222+
223223
replace_inf_with_nan(arr, arrsize, nan_count, descending);
224224
}
225225
}
226226

227-
template <>
228-
X86_SIMD_SORT_INLINE_ONLY void avx512_qselect(_Float16 *arr,
229-
arrsize_t k,
230-
arrsize_t arrsize,
231-
bool hasnan,
232-
bool descending)
227+
template <bool descending = false>
228+
X86_SIMD_SORT_INLINE_ONLY void
229+
avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
233230
{
234231
using vtype = zmm_vector<_Float16>;
232+
using comparator =
233+
typename std::conditional<descending,
234+
DescendingComparator<vtype>,
235+
AscendingComparator<vtype>>::type;
235236

236-
if (descending) {
237-
arrsize_t index_first_elem = 0;
238-
if (UNLIKELY(hasnan)) {
237+
arrsize_t index_first_elem = 0;
238+
arrsize_t index_last_elem = arrsize - 1;
239+
240+
if (UNLIKELY(hasnan)) {
241+
if constexpr (descending) {
239242
index_first_elem = move_nans_to_start_of_array(arr, arrsize);
240243
}
241-
242-
arrsize_t size_without_nans = arrsize - index_first_elem;
243-
244-
if (index_first_elem <= k) {
245-
qselect_<vtype, DescendingComparator<vtype>, _Float16>(
246-
arr,
247-
k,
248-
index_first_elem,
249-
arrsize - 1,
250-
2 * (arrsize_t)log2(size_without_nans));
244+
else {
245+
index_last_elem = move_nans_to_end_of_array(arr, arrsize);
251246
}
252247
}
253-
else {
254-
arrsize_t indx_last_elem = arrsize - 1;
255-
if (UNLIKELY(hasnan)) {
256-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
257-
}
258248

259-
if (indx_last_elem >= k) {
260-
qselect_<vtype, AscendingComparator<vtype>, _Float16>(
261-
arr,
262-
k,
263-
0,
264-
indx_last_elem,
265-
2 * (arrsize_t)log2(indx_last_elem));
266-
}
249+
if (index_first_elem <= k && index_last_elem >= k) {
250+
qselect_<vtype, comparator, _Float16>(arr,
251+
k,
252+
index_first_elem,
253+
index_last_elem,
254+
2 * (arrsize_t)log2(arrsize));
267255
}
268256
}
269-
template <>
270-
X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr,
271-
arrsize_t k,
272-
arrsize_t arrsize,
273-
bool hasnan,
274-
bool descending)
257+
template <bool descending = false>
258+
X86_SIMD_SORT_INLINE_ONLY void
259+
avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
275260
{
276-
avx512_qselect(arr, k - 1, arrsize, hasnan, descending);
277-
avx512_qsort(arr, k - 1, hasnan, descending);
261+
avx512_qselect<descending>(arr, k - 1, arrsize, hasnan);
262+
avx512_qsort<descending>(arr, k - 1, hasnan);
278263
}
279264
#endif // AVX512FP16_QSORT_16BIT

src/xss-common-qsort.h

Lines changed: 58 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -612,91 +612,71 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
612612
}
613613

614614
// Quicksort routines:
615-
template <typename vtype, typename T>
616-
X86_SIMD_SORT_INLINE void
617-
xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending)
615+
template <typename vtype, typename T, bool descending = false>
616+
X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
618617
{
618+
using comparator =
619+
typename std::conditional<descending,
620+
DescendingComparator<vtype>,
621+
AscendingComparator<vtype>>::type;
622+
619623
if (arrsize > 1) {
624+
arrsize_t nan_count = 0;
620625
if constexpr (std::is_floating_point_v<T>) {
621-
arrsize_t nan_count = 0;
622626
if (UNLIKELY(hasnan)) {
623627
nan_count = replace_nan_with_inf<vtype>(arr, arrsize);
624628
}
625-
if (descending) {
626-
qsort_<vtype, DescendingComparator<vtype>, T>(
627-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
628-
}
629-
else {
630-
qsort_<vtype, AscendingComparator<vtype>, T>(
631-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
632-
}
633-
replace_inf_with_nan(arr, arrsize, nan_count, descending);
634-
}
635-
else {
636-
UNUSED(hasnan);
637-
if (descending) {
638-
qsort_<vtype, DescendingComparator<vtype>, T>(
639-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
640-
}
641-
else {
642-
qsort_<vtype, AscendingComparator<vtype>, T>(
643-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
644-
}
645629
}
630+
631+
UNUSED(hasnan);
632+
qsort_<vtype, comparator, T>(
633+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
634+
635+
replace_inf_with_nan(arr, arrsize, nan_count, descending);
646636
}
647637
}
648638

649639
// Quick select methods
650-
template <typename vtype, typename T>
651-
X86_SIMD_SORT_INLINE void xss_qselect(
652-
T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending)
640+
template <typename vtype, typename T, bool descending = false>
641+
X86_SIMD_SORT_INLINE void
642+
xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
653643
{
654-
if (descending) {
655-
arrsize_t index_first_elem = 0;
656-
if constexpr (std::is_floating_point_v<T>) {
657-
if (UNLIKELY(hasnan)) {
658-
index_first_elem = move_nans_to_start_of_array(arr, arrsize);
659-
}
660-
}
644+
using comparator =
645+
typename std::conditional<descending,
646+
DescendingComparator<vtype>,
647+
AscendingComparator<vtype>>::type;
661648

662-
arrsize_t size_without_nans = arrsize - index_first_elem;
649+
arrsize_t index_first_elem = 0;
650+
arrsize_t index_last_elem = arrsize - 1;
663651

664-
UNUSED(hasnan);
665-
if (index_first_elem <= k) {
666-
qselect_<vtype, DescendingComparator<vtype>, T>(
667-
arr,
668-
k,
669-
index_first_elem,
670-
arrsize - 1,
671-
2 * (arrsize_t)log2(size_without_nans));
672-
}
673-
}
674-
else {
675-
arrsize_t indx_last_elem = arrsize - 1;
676-
if constexpr (std::is_floating_point_v<T>) {
677-
if (UNLIKELY(hasnan)) {
678-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
652+
if constexpr (std::is_floating_point_v<T>) {
653+
if (UNLIKELY(hasnan)) {
654+
if constexpr (descending) {
655+
index_first_elem = move_nans_to_start_of_array(arr, arrsize);
656+
}
657+
else {
658+
index_last_elem = move_nans_to_end_of_array(arr, arrsize);
679659
}
680-
}
681-
UNUSED(hasnan);
682-
if (indx_last_elem >= k) {
683-
qselect_<vtype, AscendingComparator<vtype>, T>(
684-
arr,
685-
k,
686-
0,
687-
indx_last_elem,
688-
2 * (arrsize_t)log2(indx_last_elem));
689660
}
690661
}
662+
663+
UNUSED(hasnan);
664+
if (index_first_elem <= k && index_last_elem >= k) {
665+
qselect_<vtype, comparator, T>(arr,
666+
k,
667+
index_first_elem,
668+
index_last_elem,
669+
2 * (arrsize_t)log2(arrsize));
670+
}
691671
}
692672

693673
// Partial sort methods:
694-
template <typename vtype, typename T>
695-
X86_SIMD_SORT_INLINE void xss_partial_qsort(
696-
T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending)
674+
template <typename vtype, typename T, bool descending = false>
675+
X86_SIMD_SORT_INLINE void
676+
xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
697677
{
698-
xss_qselect<vtype, T>(arr, k - 1, arrsize, hasnan, descending);
699-
xss_qsort<vtype, T>(arr, k - 1, hasnan, descending);
678+
xss_qselect<vtype, T, descending>(arr, k - 1, arrsize, hasnan);
679+
xss_qsort<vtype, T, descending>(arr, k - 1, hasnan);
700680
}
701681

702682
#define DEFINE_METHODS(ISA, VTYPE) \
@@ -706,7 +686,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(
706686
bool hasnan = false, \
707687
bool descending = false) \
708688
{ \
709-
xss_qsort<VTYPE, T>(arr, size, hasnan, descending); \
689+
if (descending) { xss_qsort<VTYPE, T, true>(arr, size, hasnan); } \
690+
else { \
691+
xss_qsort<VTYPE, T, false>(arr, size, hasnan); \
692+
} \
710693
} \
711694
template <typename T> \
712695
X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \
@@ -715,7 +698,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(
715698
bool hasnan = false, \
716699
bool descending = false) \
717700
{ \
718-
xss_qselect<VTYPE, T>(arr, k, size, hasnan, descending); \
701+
if (descending) { xss_qselect<VTYPE, T, true>(arr, k, size, hasnan); } \
702+
else { \
703+
xss_qselect<VTYPE, T, false>(arr, k, size, hasnan); \
704+
} \
719705
} \
720706
template <typename T> \
721707
X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \
@@ -724,7 +710,12 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort(
724710
bool hasnan = false, \
725711
bool descending = false) \
726712
{ \
727-
xss_partial_qsort<VTYPE, T>(arr, k, size, hasnan, descending); \
713+
if (descending) { \
714+
xss_partial_qsort<VTYPE, T, true>(arr, k, size, hasnan); \
715+
} \
716+
else { \
717+
xss_partial_qsort<VTYPE, T, false>(arr, k, size, hasnan); \
718+
} \
728719
}
729720

730721
DEFINE_METHODS(avx512, zmm_vector<T>)

0 commit comments

Comments
 (0)