@@ -39,12 +39,22 @@ struct ymm_vector<float> {
3939 {
4040 return _mm256_set1_ps (type_max ());
4141 }
42-
4342 static zmmi_t
4443 seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
4544 {
4645 return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
4746 }
47+ static zmm_t set (type_t v1,
48+ type_t v2,
49+ type_t v3,
50+ type_t v4,
51+ type_t v5,
52+ type_t v6,
53+ type_t v7,
54+ type_t v8)
55+ {
56+ return _mm256_set_ps (v1, v2, v3, v4, v5, v6, v7, v8);
57+ }
4858 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
4959 {
5060 return _kxor_mask8 (x, y);
@@ -80,10 +90,16 @@ struct ymm_vector<float> {
8090 {
8191 return _mm512_mask_i64gather_ps (src, mask, index, base, scale);
8292 }
83- template <int scale>
84- static zmm_t i64gather (__m512i index, void const *base)
93+ static zmm_t i64gather (type_t *arr, int64_t *ind)
8594 {
86- return _mm512_i64gather_ps (index, base, scale);
95+ return set (arr[ind[7 ]],
96+ arr[ind[6 ]],
97+ arr[ind[5 ]],
98+ arr[ind[4 ]],
99+ arr[ind[3 ]],
100+ arr[ind[2 ]],
101+ arr[ind[1 ]],
102+ arr[ind[0 ]]);
87103 }
88104 static zmm_t loadu (void const *mem)
89105 {
@@ -189,6 +205,17 @@ struct ymm_vector<uint32_t> {
189205 {
190206 return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
191207 }
208+ static zmm_t set (type_t v1,
209+ type_t v2,
210+ type_t v3,
211+ type_t v4,
212+ type_t v5,
213+ type_t v6,
214+ type_t v7,
215+ type_t v8)
216+ {
217+ return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
218+ }
192219 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
193220 {
194221 return _kxor_mask8 (x, y);
@@ -215,10 +242,16 @@ struct ymm_vector<uint32_t> {
215242 {
216243 return _mm512_mask_i64gather_epi32 (src, mask, index, base, scale);
217244 }
218- template <int scale>
219- static zmm_t i64gather (__m512i index, void const *base)
245+ static zmm_t i64gather (type_t *arr, int64_t *ind)
220246 {
221- return _mm512_i64gather_epi32 (index, base, scale);
247+ return set (arr[ind[7 ]],
248+ arr[ind[6 ]],
249+ arr[ind[5 ]],
250+ arr[ind[4 ]],
251+ arr[ind[3 ]],
252+ arr[ind[2 ]],
253+ arr[ind[1 ]],
254+ arr[ind[0 ]]);
222255 }
223256 static zmm_t loadu (void const *mem)
224257 {
@@ -318,6 +351,17 @@ struct ymm_vector<int32_t> {
318351 {
319352 return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
320353 }
354+ static zmm_t set (type_t v1,
355+ type_t v2,
356+ type_t v3,
357+ type_t v4,
358+ type_t v5,
359+ type_t v6,
360+ type_t v7,
361+ type_t v8)
362+ {
363+ return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
364+ }
321365 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
322366 {
323367 return _kxor_mask8 (x, y);
@@ -344,10 +388,16 @@ struct ymm_vector<int32_t> {
344388 {
345389 return _mm512_mask_i64gather_epi32 (src, mask, index, base, scale);
346390 }
347- template <int scale>
348- static zmm_t i64gather (__m512i index, void const *base)
391+ static zmm_t i64gather (type_t *arr, int64_t *ind)
349392 {
350- return _mm512_i64gather_epi32 (index, base, scale);
393+ return set (arr[ind[7 ]],
394+ arr[ind[6 ]],
395+ arr[ind[5 ]],
396+ arr[ind[4 ]],
397+ arr[ind[3 ]],
398+ arr[ind[2 ]],
399+ arr[ind[1 ]],
400+ arr[ind[0 ]]);
351401 }
352402 static zmm_t loadu (void const *mem)
353403 {
@@ -448,6 +498,17 @@ struct zmm_vector<int64_t> {
448498 {
449499 return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
450500 }
501+ static zmm_t set (type_t v1,
502+ type_t v2,
503+ type_t v3,
504+ type_t v4,
505+ type_t v5,
506+ type_t v6,
507+ type_t v7,
508+ type_t v8)
509+ {
510+ return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
511+ }
451512 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
452513 {
453514 return _kxor_mask8 (x, y);
@@ -474,10 +535,16 @@ struct zmm_vector<int64_t> {
474535 {
475536 return _mm512_mask_i64gather_epi64 (src, mask, index, base, scale);
476537 }
477- template <int scale>
478- static zmm_t i64gather (__m512i index, void const *base)
538+ static zmm_t i64gather (type_t *arr, int64_t *ind)
479539 {
480- return _mm512_i64gather_epi64 (index, base, scale);
540+ return set (arr[ind[7 ]],
541+ arr[ind[6 ]],
542+ arr[ind[5 ]],
543+ arr[ind[4 ]],
544+ arr[ind[3 ]],
545+ arr[ind[2 ]],
546+ arr[ind[1 ]],
547+ arr[ind[0 ]]);
481548 }
482549 static zmm_t loadu (void const *mem)
483550 {
@@ -566,16 +633,33 @@ struct zmm_vector<uint64_t> {
566633 {
567634 return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
568635 }
636+ static zmm_t set (type_t v1,
637+ type_t v2,
638+ type_t v3,
639+ type_t v4,
640+ type_t v5,
641+ type_t v6,
642+ type_t v7,
643+ type_t v8)
644+ {
645+ return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
646+ }
569647 template <int scale>
570648 static zmm_t
571649 mask_i64gather (zmm_t src, opmask_t mask, __m512i index, void const *base)
572650 {
573651 return _mm512_mask_i64gather_epi64 (src, mask, index, base, scale);
574652 }
575- template <int scale>
576- static zmm_t i64gather (__m512i index, void const *base)
653+ static zmm_t i64gather (type_t *arr, int64_t *ind)
577654 {
578- return _mm512_i64gather_epi64 (index, base, scale);
655+ return set (arr[ind[7 ]],
656+ arr[ind[6 ]],
657+ arr[ind[5 ]],
658+ arr[ind[4 ]],
659+ arr[ind[3 ]],
660+ arr[ind[2 ]],
661+ arr[ind[1 ]],
662+ arr[ind[0 ]]);
579663 }
580664 static opmask_t knot_opmask (opmask_t x)
581665 {
@@ -666,13 +750,22 @@ struct zmm_vector<double> {
666750 {
667751 return _mm512_set1_pd (type_max ());
668752 }
669-
670753 static zmmi_t
671754 seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
672755 {
673756 return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
674757 }
675-
758+ static zmm_t set (type_t v1,
759+ type_t v2,
760+ type_t v3,
761+ type_t v4,
762+ type_t v5,
763+ type_t v6,
764+ type_t v7,
765+ type_t v8)
766+ {
767+ return _mm512_set_pd (v1, v2, v3, v4, v5, v6, v7, v8);
768+ }
676769 static zmm_t maskz_loadu (opmask_t mask, void const *mem)
677770 {
678771 return _mm512_maskz_loadu_pd (mask, mem);
@@ -704,10 +797,16 @@ struct zmm_vector<double> {
704797 {
705798 return _mm512_mask_i64gather_pd (src, mask, index, base, scale);
706799 }
707- template <int scale>
708- static zmm_t i64gather (__m512i index, void const *base)
800+ static zmm_t i64gather (type_t *arr, int64_t *ind)
709801 {
710- return _mm512_i64gather_pd (index, base, scale);
802+ return set (arr[ind[7 ]],
803+ arr[ind[6 ]],
804+ arr[ind[5 ]],
805+ arr[ind[4 ]],
806+ arr[ind[3 ]],
807+ arr[ind[2 ]],
808+ arr[ind[1 ]],
809+ arr[ind[0 ]]);
711810 }
712811 static zmm_t loadu (void const *mem)
713812 {
@@ -794,15 +893,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
794893 // median of 8
795894 int64_t size = (right - left) / 8 ;
796895 using zmm_t = typename vtype::zmm_t ;
797- __m512i rand_index = _mm512_set_epi64 (left + size,
798- left + 2 * size,
799- left + 3 * size,
800- left + 4 * size,
801- left + 5 * size,
802- left + 6 * size,
803- left + 7 * size,
804- left + 8 * size);
805- zmm_t rand_vec = vtype::template i64gather<sizeof (type_t )>(rand_index, arr);
896+ zmm_t rand_vec = vtype::set (arr[left + size],
897+ arr[left + 2 * size],
898+ arr[left + 3 * size],
899+ arr[left + 4 * size],
900+ arr[left + 5 * size],
901+ arr[left + 6 * size],
902+ arr[left + 7 * size],
903+ arr[left + 8 * size]);
806904 // pivot will never be a nan, since there are no nan's!
807905 zmm_t sort = sort_zmm_64bit<vtype>(rand_vec);
808906 return ((type_t *)&sort)[4 ];
0 commit comments