Skip to content

Commit 70c4bce

Browse files
committed
Fixed fp16 code issues
1 parent a0a929b commit 70c4bce

File tree

1 file changed

+73
-23
lines changed

1 file changed

+73
-23
lines changed

src/avx512fp16-16bit-qsort.hpp

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -179,51 +179,101 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<_Float16>(_Float16 elem)
179179
}
180180

181181
template <>
182-
X86_SIMD_SORT_INLINE_ONLY void
183-
replace_inf_with_nan(_Float16 *arr, arrsize_t size, arrsize_t nan_count)
182+
X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr,
183+
arrsize_t size,
184+
arrsize_t nan_count,
185+
bool descending)
184186
{
185187
Fp16Bits val;
186188
val.i_ = 0x7c01;
187-
for (arrsize_t ii = size - 1; nan_count > 0; --ii) {
188-
arr[ii] = val.f_;
189-
nan_count -= 1;
189+
190+
if (descending) {
191+
for (arrsize_t ii = 0; nan_count > 0; ++ii) {
192+
arr[ii] = val.f_;
193+
nan_count -= 1;
194+
}
195+
}
196+
else {
197+
for (arrsize_t ii = size - 1; nan_count > 0; --ii) {
198+
arr[ii] = val.f_;
199+
nan_count -= 1;
200+
}
190201
}
191202
}
192203
/* Specialized template function for _Float16 qsort_*/
193204
template <>
194205
X86_SIMD_SORT_INLINE_ONLY void
195-
avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan)
206+
avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending)
196207
{
208+
using vtype = zmm_vector<_Float16>;
209+
197210
if (arrsize > 1) {
198211
arrsize_t nan_count = 0;
199212
if (UNLIKELY(hasnan)) {
200-
nan_count = replace_nan_with_inf<zmm_vector<_Float16>, _Float16>(
201-
arr, arrsize);
213+
nan_count = replace_nan_with_inf<vtype, _Float16>(arr, arrsize);
214+
}
215+
if (descending) {
216+
qsort_<vtype, DescendingComparator<vtype>, _Float16>(
217+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
218+
}
219+
else {
220+
qsort_<vtype, AscendingComparator<vtype>, _Float16>(
221+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
202222
}
203-
qsort_<zmm_vector<_Float16>, _Float16>(
204-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
205-
replace_inf_with_nan(arr, arrsize, nan_count);
223+
replace_inf_with_nan(arr, arrsize, nan_count, descending);
206224
}
207225
}
208226

209227
template <>
210-
X86_SIMD_SORT_INLINE_ONLY void
211-
avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
228+
X86_SIMD_SORT_INLINE_ONLY void avx512_qselect(_Float16 *arr,
229+
arrsize_t k,
230+
arrsize_t arrsize,
231+
bool hasnan,
232+
bool descending)
212233
{
213-
arrsize_t indx_last_elem = arrsize - 1;
214-
if (UNLIKELY(hasnan)) {
215-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
234+
using vtype = zmm_vector<_Float16>;
235+
236+
if (descending) {
237+
arrsize_t index_first_elem = 0;
238+
if (UNLIKELY(hasnan)) {
239+
index_first_elem = move_nans_to_start_of_array(arr, arrsize);
240+
}
241+
242+
arrsize_t size_without_nans = arrsize - index_first_elem;
243+
244+
if (index_first_elem <= k) {
245+
qselect_<vtype, DescendingComparator<vtype>, _Float16>(
246+
arr,
247+
k,
248+
index_first_elem,
249+
arrsize - 1,
250+
2 * (arrsize_t)log2(size_without_nans));
251+
}
216252
}
217-
if (indx_last_elem >= k) {
218-
qselect_<zmm_vector<_Float16>, _Float16>(
219-
arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem));
253+
else {
254+
arrsize_t indx_last_elem = arrsize - 1;
255+
if (UNLIKELY(hasnan)) {
256+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
257+
}
258+
259+
if (indx_last_elem >= k) {
260+
qselect_<vtype, AscendingComparator<vtype>, _Float16>(
261+
arr,
262+
k,
263+
0,
264+
indx_last_elem,
265+
2 * (arrsize_t)log2(indx_last_elem));
266+
}
220267
}
221268
}
222269
template <>
223-
X86_SIMD_SORT_INLINE_ONLY void
224-
avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
270+
X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr,
271+
arrsize_t k,
272+
arrsize_t arrsize,
273+
bool hasnan,
274+
bool descending)
225275
{
226-
avx512_qselect(arr, k - 1, arrsize, hasnan);
227-
avx512_qsort(arr, k - 1, hasnan);
276+
avx512_qselect(arr, k - 1, arrsize, hasnan, descending);
277+
avx512_qsort(arr, k - 1, hasnan, descending);
228278
}
229279
#endif // AVX512FP16_QSORT_16BIT

0 commit comments

Comments
 (0)