@@ -256,6 +256,15 @@ struct zmm_vector<float> {
256256 {
257257 return _mm512_cmp_ps_mask (x, y, _CMP_GE_OQ);
258258 }
259+ static opmask_t get_partial_loadmask (int size)
260+ {
261+ return (0x0001 << size) - 0x0001 ;
262+ }
263+ template <int type>
264+ static opmask_t fpclass (zmm_t x)
265+ {
266+ return _mm512_fpclass_ps_mask (x, type);
267+ }
259268 template <int scale>
260269 static ymm_t i64gather (__m512i index, void const *base)
261270 {
@@ -279,6 +288,10 @@ struct zmm_vector<float> {
279288 {
280289 return _mm512_mask_compressstoreu_ps (mem, mask, x);
281290 }
291+ static zmm_t maskz_loadu (opmask_t mask, void const *mem)
292+ {
293+ return _mm512_maskz_loadu_ps (mask, mem);
294+ }
282295 static zmm_t mask_loadu (zmm_t x, opmask_t mask, void const *mem)
283296 {
284297 return _mm512_mask_loadu_ps (x, mask, mem);
@@ -689,95 +702,53 @@ static void qselect_32bit_(type_t *arr,
689702 qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
690703}
691704
692- X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf (float *arr, int64_t arrsize)
693- {
694- int64_t nan_count = 0 ;
695- __mmask16 loadmask = 0xFFFF ;
696- while (arrsize > 0 ) {
697- if (arrsize < 16 ) { loadmask = (0x0001 << arrsize) - 0x0001 ; }
698- __m512 in_zmm = _mm512_maskz_loadu_ps (loadmask, arr);
699- __mmask16 nanmask = _mm512_cmp_ps_mask (in_zmm, in_zmm, _CMP_NEQ_UQ);
700- nan_count += _mm_popcnt_u32 ((int32_t )nanmask);
701- _mm512_mask_storeu_ps (arr, nanmask, ZMM_MAX_FLOAT);
702- arr += 16 ;
703- arrsize -= 16 ;
704- }
705- return nan_count;
706- }
707-
708- X86_SIMD_SORT_INLINE void
709- replace_inf_with_nan (float *arr, int64_t arrsize, int64_t nan_count)
710- {
711- for (int64_t ii = arrsize - 1 ; nan_count > 0 ; --ii) {
712- arr[ii] = std::nanf (" 1" );
713- nan_count -= 1 ;
714- }
715- }
716-
705+ /* Specialized template function for 32-bit qselect_ funcs*/
717706template <>
718- void avx512_qselect<int32_t >(int32_t *arr,
719- int64_t k,
720- int64_t arrsize,
721- bool hasnan)
707+ void qselect_<zmm_vector<int32_t >>(
708+ int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
722709{
723- if (arrsize > 1 ) {
724- qselect_32bit_<zmm_vector<int32_t >, int32_t >(
725- arr, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
726- }
710+ qselect_32bit_<zmm_vector<int32_t >>(arr, k, left, right, maxiters);
727711}
728712
729713template <>
730- void avx512_qselect<uint32_t >(uint32_t *arr,
731- int64_t k,
732- int64_t arrsize,
733- bool hasnan)
714+ void qselect_<zmm_vector<uint32_t >>(
715+ uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
734716{
735- if (arrsize > 1 ) {
736- qselect_32bit_<zmm_vector<uint32_t >, uint32_t >(
737- arr, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
738- }
717+ qselect_32bit_<zmm_vector<uint32_t >>(arr, k, left, right, maxiters);
739718}
740719
741720template <>
742- void avx512_qselect<float >(float *arr, int64_t k, int64_t arrsize, bool hasnan)
721+ void qselect_<zmm_vector<float >>(
722+ float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
743723{
744- int64_t indx_last_elem = arrsize - 1 ;
745- if (UNLIKELY (hasnan)) {
746- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
747- }
748- if (indx_last_elem >= k) {
749- qselect_32bit_<zmm_vector<float >, float >(
750- arr, k, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
751- }
724+ qselect_32bit_<zmm_vector<float >>(arr, k, left, right, maxiters);
752725}
753726
727+ /* Specialized template function for 32-bit qsort_ funcs*/
754728template <>
755- void avx512_qsort<int32_t >(int32_t *arr, int64_t arrsize)
729+ void qsort_<zmm_vector<int32_t >>(int32_t *arr,
730+ int64_t left,
731+ int64_t right,
732+ int64_t maxiters)
756733{
757- if (arrsize > 1 ) {
758- qsort_32bit_<zmm_vector<int32_t >, int32_t >(
759- arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
760- }
734+ qsort_32bit_<zmm_vector<int32_t >>(arr, left, right, maxiters);
761735}
762736
763737template <>
764- void avx512_qsort<uint32_t >(uint32_t *arr, int64_t arrsize)
738+ void qsort_<zmm_vector<uint32_t >>(uint32_t *arr,
739+ int64_t left,
740+ int64_t right,
741+ int64_t maxiters)
765742{
766- if (arrsize > 1 ) {
767- qsort_32bit_<zmm_vector<uint32_t >, uint32_t >(
768- arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
769- }
743+ qsort_32bit_<zmm_vector<uint32_t >>(arr, left, right, maxiters);
770744}
771745
772746template <>
773- void avx512_qsort<float >(float *arr, int64_t arrsize)
747+ void qsort_<zmm_vector<float >>(float *arr,
748+ int64_t left,
749+ int64_t right,
750+ int64_t maxiters)
774751{
775- if (arrsize > 1 ) {
776- int64_t nan_count = replace_nan_with_inf (arr, arrsize);
777- qsort_32bit_<zmm_vector<float >, float >(
778- arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
779- replace_inf_with_nan (arr, arrsize, nan_count);
780- }
752+ qsort_32bit_<zmm_vector<float >>(arr, left, right, maxiters);
781753}
782-
783754#endif // AVX512_QSORT_32BIT
0 commit comments