Skip to content

Commit 323f247

Browse files
author
Raghuveer Devulapalli
committed
Use scalar emulation of gather instruction for arg methods
1 parent 0890de5 commit 323f247

File tree

3 files changed

+152
-60
lines changed

3 files changed

+152
-60
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N)
8585
typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
8686
argzmm_t argzmm1 = argtype::loadu(arg);
8787
argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
88-
zmm_t arrzmm1 = vtype::template i64gather<sizeof(type_t)>(argzmm1, arr);
88+
zmm_t arrzmm1 = vtype::i64gather(arr, arg);
8989
zmm_t arrzmm2 = vtype::template mask_i64gather<sizeof(type_t)>(
9090
vtype::zmm_max(), load_mask, argzmm2, arr);
9191
arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
@@ -111,7 +111,7 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N)
111111
#pragma GCC unroll 2
112112
for (int ii = 0; ii < 2; ++ii) {
113113
argzmm[ii] = argtype::loadu(arg + 8 * ii);
114-
arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
114+
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
115115
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
116116
}
117117

@@ -154,7 +154,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
154154
#pragma GCC unroll 4
155155
for (int ii = 0; ii < 4; ++ii) {
156156
argzmm[ii] = argtype::loadu(arg + 8 * ii);
157-
arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
157+
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
158158
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
159159
}
160160

@@ -206,7 +206,7 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
206206
//#pragma GCC unroll 8
207207
// for (int ii = 0; ii < 8; ++ii) {
208208
// argzmm[ii] = argtype::loadu(arg + 8*ii);
209-
// arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
209+
// arrzmm[ii] = vtype::i64gather(arr, arg + 8*ii);
210210
// arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
211211
// }
212212
//
@@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr,
257257
// median of 8
258258
int64_t size = (right - left) / 8;
259259
using zmm_t = typename vtype::zmm_t;
260-
// TODO: Use gather here too:
261-
__m512i rand_index = _mm512_set_epi64(arg[left + size],
262-
arg[left + 2 * size],
263-
arg[left + 3 * size],
264-
arg[left + 4 * size],
265-
arg[left + 5 * size],
266-
arg[left + 6 * size],
267-
arg[left + 7 * size],
268-
arg[left + 8 * size]);
269-
zmm_t rand_vec
270-
= vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
260+
zmm_t rand_vec = vtype::set(arr[arg[left + size]],
261+
arr[arg[left + 2 * size]],
262+
arr[arg[left + 3 * size]],
263+
arr[arg[left + 4 * size]],
264+
arr[arg[left + 5 * size]],
265+
arr[arg[left + 6 * size]],
266+
arr[arg[left + 7 * size]],
267+
arr[arg[left + 8 * size]]);
271268
// pivot will never be a NaN, since there are no NaNs!
272269
zmm_t sort = sort_zmm_64bit<vtype>(rand_vec);
273270
return ((type_t *)&sort)[4];

src/avx512-64bit-common.h

Lines changed: 128 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,22 @@ struct ymm_vector<float> {
3939
{
4040
return _mm256_set1_ps(type_max());
4141
}
42-
4342
static zmmi_t
4443
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
4544
{
4645
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
4746
}
47+
static zmm_t set(type_t v1,
48+
type_t v2,
49+
type_t v3,
50+
type_t v4,
51+
type_t v5,
52+
type_t v6,
53+
type_t v7,
54+
type_t v8)
55+
{
56+
return _mm256_set_ps(v1, v2, v3, v4, v5, v6, v7, v8);
57+
}
4858
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
4959
{
5060
return _kxor_mask8(x, y);
@@ -80,10 +90,16 @@ struct ymm_vector<float> {
8090
{
8191
return _mm512_mask_i64gather_ps(src, mask, index, base, scale);
8292
}
83-
template <int scale>
84-
static zmm_t i64gather(__m512i index, void const *base)
93+
static zmm_t i64gather(type_t *arr, int64_t *ind)
8594
{
86-
return _mm512_i64gather_ps(index, base, scale);
95+
return set(arr[ind[7]],
96+
arr[ind[6]],
97+
arr[ind[5]],
98+
arr[ind[4]],
99+
arr[ind[3]],
100+
arr[ind[2]],
101+
arr[ind[1]],
102+
arr[ind[0]]);
87103
}
88104
static zmm_t loadu(void const *mem)
89105
{
@@ -189,6 +205,17 @@ struct ymm_vector<uint32_t> {
189205
{
190206
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
191207
}
208+
static zmm_t set(type_t v1,
209+
type_t v2,
210+
type_t v3,
211+
type_t v4,
212+
type_t v5,
213+
type_t v6,
214+
type_t v7,
215+
type_t v8)
216+
{
217+
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
218+
}
192219
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
193220
{
194221
return _kxor_mask8(x, y);
@@ -215,10 +242,16 @@ struct ymm_vector<uint32_t> {
215242
{
216243
return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
217244
}
218-
template <int scale>
219-
static zmm_t i64gather(__m512i index, void const *base)
245+
static zmm_t i64gather(type_t *arr, int64_t *ind)
220246
{
221-
return _mm512_i64gather_epi32(index, base, scale);
247+
return set(arr[ind[7]],
248+
arr[ind[6]],
249+
arr[ind[5]],
250+
arr[ind[4]],
251+
arr[ind[3]],
252+
arr[ind[2]],
253+
arr[ind[1]],
254+
arr[ind[0]]);
222255
}
223256
static zmm_t loadu(void const *mem)
224257
{
@@ -318,6 +351,17 @@ struct ymm_vector<int32_t> {
318351
{
319352
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
320353
}
354+
static zmm_t set(type_t v1,
355+
type_t v2,
356+
type_t v3,
357+
type_t v4,
358+
type_t v5,
359+
type_t v6,
360+
type_t v7,
361+
type_t v8)
362+
{
363+
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
364+
}
321365
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
322366
{
323367
return _kxor_mask8(x, y);
@@ -344,10 +388,16 @@ struct ymm_vector<int32_t> {
344388
{
345389
return _mm512_mask_i64gather_epi32(src, mask, index, base, scale);
346390
}
347-
template <int scale>
348-
static zmm_t i64gather(__m512i index, void const *base)
391+
static zmm_t i64gather(type_t *arr, int64_t *ind)
349392
{
350-
return _mm512_i64gather_epi32(index, base, scale);
393+
return set(arr[ind[7]],
394+
arr[ind[6]],
395+
arr[ind[5]],
396+
arr[ind[4]],
397+
arr[ind[3]],
398+
arr[ind[2]],
399+
arr[ind[1]],
400+
arr[ind[0]]);
351401
}
352402
static zmm_t loadu(void const *mem)
353403
{
@@ -448,6 +498,17 @@ struct zmm_vector<int64_t> {
448498
{
449499
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
450500
}
501+
static zmm_t set(type_t v1,
502+
type_t v2,
503+
type_t v3,
504+
type_t v4,
505+
type_t v5,
506+
type_t v6,
507+
type_t v7,
508+
type_t v8)
509+
{
510+
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
511+
}
451512
static opmask_t kxor_opmask(opmask_t x, opmask_t y)
452513
{
453514
return _kxor_mask8(x, y);
@@ -474,10 +535,16 @@ struct zmm_vector<int64_t> {
474535
{
475536
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
476537
}
477-
template <int scale>
478-
static zmm_t i64gather(__m512i index, void const *base)
538+
static zmm_t i64gather(type_t *arr, int64_t *ind)
479539
{
480-
return _mm512_i64gather_epi64(index, base, scale);
540+
return set(arr[ind[7]],
541+
arr[ind[6]],
542+
arr[ind[5]],
543+
arr[ind[4]],
544+
arr[ind[3]],
545+
arr[ind[2]],
546+
arr[ind[1]],
547+
arr[ind[0]]);
481548
}
482549
static zmm_t loadu(void const *mem)
483550
{
@@ -566,16 +633,33 @@ struct zmm_vector<uint64_t> {
566633
{
567634
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
568635
}
636+
static zmm_t set(type_t v1,
637+
type_t v2,
638+
type_t v3,
639+
type_t v4,
640+
type_t v5,
641+
type_t v6,
642+
type_t v7,
643+
type_t v8)
644+
{
645+
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
646+
}
569647
template <int scale>
570648
static zmm_t
571649
mask_i64gather(zmm_t src, opmask_t mask, __m512i index, void const *base)
572650
{
573651
return _mm512_mask_i64gather_epi64(src, mask, index, base, scale);
574652
}
575-
template <int scale>
576-
static zmm_t i64gather(__m512i index, void const *base)
653+
static zmm_t i64gather(type_t *arr, int64_t *ind)
577654
{
578-
return _mm512_i64gather_epi64(index, base, scale);
655+
return set(arr[ind[7]],
656+
arr[ind[6]],
657+
arr[ind[5]],
658+
arr[ind[4]],
659+
arr[ind[3]],
660+
arr[ind[2]],
661+
arr[ind[1]],
662+
arr[ind[0]]);
579663
}
580664
static opmask_t knot_opmask(opmask_t x)
581665
{
@@ -666,13 +750,22 @@ struct zmm_vector<double> {
666750
{
667751
return _mm512_set1_pd(type_max());
668752
}
669-
670753
static zmmi_t
671754
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
672755
{
673756
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
674757
}
675-
758+
static zmm_t set(type_t v1,
759+
type_t v2,
760+
type_t v3,
761+
type_t v4,
762+
type_t v5,
763+
type_t v6,
764+
type_t v7,
765+
type_t v8)
766+
{
767+
return _mm512_set_pd(v1, v2, v3, v4, v5, v6, v7, v8);
768+
}
676769
static zmm_t maskz_loadu(opmask_t mask, void const *mem)
677770
{
678771
return _mm512_maskz_loadu_pd(mask, mem);
@@ -704,10 +797,16 @@ struct zmm_vector<double> {
704797
{
705798
return _mm512_mask_i64gather_pd(src, mask, index, base, scale);
706799
}
707-
template <int scale>
708-
static zmm_t i64gather(__m512i index, void const *base)
800+
static zmm_t i64gather(type_t *arr, int64_t *ind)
709801
{
710-
return _mm512_i64gather_pd(index, base, scale);
802+
return set(arr[ind[7]],
803+
arr[ind[6]],
804+
arr[ind[5]],
805+
arr[ind[4]],
806+
arr[ind[3]],
807+
arr[ind[2]],
808+
arr[ind[1]],
809+
arr[ind[0]]);
711810
}
712811
static zmm_t loadu(void const *mem)
713812
{
@@ -794,15 +893,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
794893
// median of 8
795894
int64_t size = (right - left) / 8;
796895
using zmm_t = typename vtype::zmm_t;
797-
__m512i rand_index = _mm512_set_epi64(left + size,
798-
left + 2 * size,
799-
left + 3 * size,
800-
left + 4 * size,
801-
left + 5 * size,
802-
left + 6 * size,
803-
left + 7 * size,
804-
left + 8 * size);
805-
zmm_t rand_vec = vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
896+
zmm_t rand_vec = vtype::set(arr[left + size],
897+
arr[left + 2 * size],
898+
arr[left + 3 * size],
899+
arr[left + 4 * size],
900+
arr[left + 5 * size],
901+
arr[left + 6 * size],
902+
arr[left + 7 * size],
903+
arr[left + 8 * size]);
806904
// pivot will never be a NaN, since there are no NaNs!
807905
zmm_t sort = sort_zmm_64bit<vtype>(rand_vec);
808906
return ((type_t *)&sort)[4];

src/avx512-common-argsort.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ static inline int64_t partition_avx512(type_t *arr,
7575

7676
if (right - left == vtype::numlanes) {
7777
argzmm_t argvec = argtype::loadu(arg + left);
78-
zmm_t vec = vtype::template i64gather<sizeof(type_t)>(argvec, arr);
78+
zmm_t vec = vtype::i64gather(arr, arg + left);
7979
int32_t amount_gt_pivot = partition_vec<vtype>(arg,
8080
left,
8181
left + vtype::numlanes,
@@ -91,11 +91,9 @@ static inline int64_t partition_avx512(type_t *arr,
9191

9292
// first and last vtype::numlanes values are partitioned at the end
9393
argzmm_t argvec_left = argtype::loadu(arg + left);
94-
zmm_t vec_left
95-
= vtype::template i64gather<sizeof(type_t)>(argvec_left, arr);
94+
zmm_t vec_left = vtype::i64gather(arr, arg + left);
9695
argzmm_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes));
97-
zmm_t vec_right
98-
= vtype::template i64gather<sizeof(type_t)>(argvec_right, arr);
96+
zmm_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes));
9997
// store points of the vectors
10098
int64_t r_store = right - vtype::numlanes;
10199
int64_t l_store = left;
@@ -113,11 +111,11 @@ static inline int64_t partition_avx512(type_t *arr,
113111
if ((r_store + vtype::numlanes) - right < left - l_store) {
114112
right -= vtype::numlanes;
115113
arg_vec = argtype::loadu(arg + right);
116-
curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
114+
curr_vec = vtype::i64gather(arr, arg + right);
117115
}
118116
else {
119117
arg_vec = argtype::loadu(arg + left);
120-
curr_vec = vtype::template i64gather<sizeof(type_t)>(arg_vec, arr);
118+
curr_vec = vtype::i64gather(arr, arg + left);
121119
left += vtype::numlanes;
122120
}
123121
// partition the current vector and save it on both sides of the array
@@ -201,12 +199,11 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
201199
#pragma GCC unroll 8
202200
for (int ii = 0; ii < num_unroll; ++ii) {
203201
argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii);
204-
vec_left[ii] = vtype::template i64gather<sizeof(type_t)>(
205-
argvec_left[ii], arr);
202+
vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii);
206203
argvec_right[ii] = argtype::loadu(
207204
arg + (right - vtype::numlanes * (num_unroll - ii)));
208-
vec_right[ii] = vtype::template i64gather<sizeof(type_t)>(
209-
argvec_right[ii], arr);
205+
vec_right[ii] = vtype::i64gather(
206+
arr, arg + (right - vtype::numlanes * (num_unroll - ii)));
210207
}
211208
// store points of the vectors
212209
int64_t r_store = right - vtype::numlanes;
@@ -228,16 +225,16 @@ static inline int64_t partition_avx512_unrolled(type_t *arr,
228225
for (int ii = 0; ii < num_unroll; ++ii) {
229226
arg_vec[ii]
230227
= argtype::loadu(arg + right + ii * vtype::numlanes);
231-
curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
232-
arg_vec[ii], arr);
228+
curr_vec[ii] = vtype::i64gather(
229+
arr, arg + right + ii * vtype::numlanes);
233230
}
234231
}
235232
else {
236233
#pragma GCC unroll 8
237234
for (int ii = 0; ii < num_unroll; ++ii) {
238235
arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes);
239-
curr_vec[ii] = vtype::template i64gather<sizeof(type_t)>(
240-
arg_vec[ii], arr);
236+
curr_vec[ii] = vtype::i64gather(
237+
arr, arg + left + ii * vtype::numlanes);
241238
}
242239
left += num_unroll * vtype::numlanes;
243240
}

0 commit comments

Comments
 (0)