@@ -270,57 +270,21 @@ inline void argsort_64bit_(type_t *arr,
270270 argsort_64bit_<vtype>(arr, arg, pivot_index, right, max_iters - 1 );
271271}
272272
273- template <>
274- void avx512_argsort<double >(double *arr, int64_t *arg, int64_t arrsize)
275- {
276- if (arrsize > 1 ) {
277- argsort_64bit_<zmm_vector<double >, double >(
278- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
279- }
280- }
281-
282- template <>
283- std::vector<int64_t > avx512_argsort<double >(double *arr, int64_t arrsize)
284- {
285- std::vector<int64_t > indices (arrsize);
286- std::iota (indices.begin (), indices.end (), 0 );
287- avx512_argsort<double >(arr, indices.data (), arrsize);
288- return indices;
289- }
290-
291- template <>
292- void avx512_argsort<uint64_t >(uint64_t *arr, int64_t *arg, int64_t arrsize)
293- {
294- if (arrsize > 1 ) {
295- argsort_64bit_<zmm_vector<uint64_t >, uint64_t >(
296- arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
297- }
298- }
299-
300- template <>
301- std::vector<int64_t > avx512_argsort<uint64_t >(uint64_t *arr, int64_t arrsize)
302- {
303- std::vector<int64_t > indices (arrsize);
304- std::iota (indices.begin (), indices.end (), 0 );
305- avx512_argsort<uint64_t >(arr, indices.data (), arrsize);
306- return indices;
307- }
308-
309- template <>
310- void avx512_argsort<int64_t >(int64_t *arr, int64_t *arg, int64_t arrsize)
273+ template <typename T>
274+ void avx512_argsort (T* arr, int64_t *arg, int64_t arrsize)
311275{
312276 if (arrsize > 1 ) {
313- argsort_64bit_<zmm_vector<int64_t >, int64_t >(
277+ argsort_64bit_<zmm_vector<T> >(
314278 arr, arg, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
315279 }
316280}
317281
318- template <>
319- std::vector<int64_t > avx512_argsort< int64_t >( int64_t * arr, int64_t arrsize)
282+ template <typename T >
283+ std::vector<int64_t > avx512_argsort (T* arr, int64_t arrsize)
320284{
321285 std::vector<int64_t > indices (arrsize);
322286 std::iota (indices.begin (), indices.end (), 0 );
323- avx512_argsort<int64_t >(arr, indices.data (), arrsize);
287+ avx512_argsort<T >(arr, indices.data (), arrsize);
324288 return indices;
325289}
326290
0 commit comments