@@ -179,51 +179,101 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<_Float16>(_Float16 elem)
179179}
180180
181181template <>
182- X86_SIMD_SORT_INLINE_ONLY void
183- replace_inf_with_nan (_Float16 *arr, arrsize_t size, arrsize_t nan_count)
182+ X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan (_Float16 *arr,
183+ arrsize_t size,
184+ arrsize_t nan_count,
185+ bool descending)
184186{
185187 Fp16Bits val;
186188 val.i_ = 0x7c01 ;
187- for (arrsize_t ii = size - 1 ; nan_count > 0 ; --ii) {
188- arr[ii] = val.f_ ;
189- nan_count -= 1 ;
189+
190+ if (descending) {
191+ for (arrsize_t ii = 0 ; nan_count > 0 ; ++ii) {
192+ arr[ii] = val.f_ ;
193+ nan_count -= 1 ;
194+ }
195+ }
196+ else {
197+ for (arrsize_t ii = size - 1 ; nan_count > 0 ; --ii) {
198+ arr[ii] = val.f_ ;
199+ nan_count -= 1 ;
200+ }
190201 }
191202}
192203/* Specialized template function for _Float16 qsort_*/
193204template <>
194205X86_SIMD_SORT_INLINE_ONLY void
195- avx512_qsort (_Float16 *arr, arrsize_t arrsize, bool hasnan)
206+ avx512_qsort (_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending )
196207{
208+ using vtype = zmm_vector<_Float16>;
209+
197210 if (arrsize > 1 ) {
198211 arrsize_t nan_count = 0 ;
199212 if (UNLIKELY (hasnan)) {
200- nan_count = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(
201- arr, arrsize);
213+ nan_count = replace_nan_with_inf<vtype, _Float16>(arr, arrsize);
214+ }
215+ if (descending) {
216+ qsort_<vtype, DescendingComparator<vtype>, _Float16>(
217+ arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
218+ }
219+ else {
220+ qsort_<vtype, AscendingComparator<vtype>, _Float16>(
221+ arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
202222 }
203- qsort_<zmm_vector<_Float16>, _Float16>(
204- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize));
205- replace_inf_with_nan (arr, arrsize, nan_count);
223+ replace_inf_with_nan (arr, arrsize, nan_count, descending);
206224 }
207225}
208226
209227template <>
210- X86_SIMD_SORT_INLINE_ONLY void
211- avx512_qselect (_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
228+ X86_SIMD_SORT_INLINE_ONLY void avx512_qselect (_Float16 *arr,
229+ arrsize_t k,
230+ arrsize_t arrsize,
231+ bool hasnan,
232+ bool descending)
212233{
213- arrsize_t indx_last_elem = arrsize - 1 ;
214- if (UNLIKELY (hasnan)) {
215- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
234+ using vtype = zmm_vector<_Float16>;
235+
236+ if (descending) {
237+ arrsize_t index_first_elem = 0 ;
238+ if (UNLIKELY (hasnan)) {
239+ index_first_elem = move_nans_to_start_of_array (arr, arrsize);
240+ }
241+
242+ arrsize_t size_without_nans = arrsize - index_first_elem;
243+
244+ if (index_first_elem <= k) {
245+ qselect_<vtype, DescendingComparator<vtype>, _Float16>(
246+ arr,
247+ k,
248+ index_first_elem,
249+ arrsize - 1 ,
250+ 2 * (arrsize_t )log2 (size_without_nans));
251+ }
216252 }
217- if (indx_last_elem >= k) {
218- qselect_<zmm_vector<_Float16>, _Float16>(
219- arr, k, 0 , indx_last_elem, 2 * (arrsize_t )log2 (indx_last_elem));
253+ else {
254+ arrsize_t indx_last_elem = arrsize - 1 ;
255+ if (UNLIKELY (hasnan)) {
256+ indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
257+ }
258+
259+ if (indx_last_elem >= k) {
260+ qselect_<vtype, AscendingComparator<vtype>, _Float16>(
261+ arr,
262+ k,
263+ 0 ,
264+ indx_last_elem,
265+ 2 * (arrsize_t )log2 (indx_last_elem));
266+ }
220267 }
221268}
222269template <>
223- X86_SIMD_SORT_INLINE_ONLY void
224- avx512_partial_qsort (_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
270+ X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort (_Float16 *arr,
271+ arrsize_t k,
272+ arrsize_t arrsize,
273+ bool hasnan,
274+ bool descending)
225275{
226- avx512_qselect (arr, k - 1 , arrsize, hasnan);
227- avx512_qsort (arr, k - 1 , hasnan);
276+ avx512_qselect (arr, k - 1 , arrsize, hasnan, descending );
277+ avx512_qsort (arr, k - 1 , hasnan, descending );
228278}
229279#endif // AVX512FP16_QSORT_16BIT
0 commit comments