@@ -45,12 +45,22 @@ struct ymm_vector<float> {
4545 {
4646 return _mm256_set1_ps (type_max ());
4747 }
48-
4948 static zmmi_t
5049 seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
5150 {
5251 return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
5352 }
53+ static reg_t set (type_t v1,
54+ type_t v2,
55+ type_t v3,
56+ type_t v4,
57+ type_t v5,
58+ type_t v6,
59+ type_t v7,
60+ type_t v8)
61+ {
62+ return _mm256_set_ps (v1, v2, v3, v4, v5, v6, v7, v8);
63+ }
5464 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
5565 {
5666 return _kxor_mask8 (x, y);
@@ -86,10 +96,16 @@ struct ymm_vector<float> {
8696 {
8797 return _mm512_mask_i64gather_ps (src, mask, index, base, scale);
8898 }
89- template <int scale>
90- static reg_t i64gather (__m512i index, void const *base)
99+ static reg_t i64gather (type_t *arr, int64_t *ind)
91100 {
92- return _mm512_i64gather_ps (index, base, scale);
101+ return set (arr[ind[7 ]],
102+ arr[ind[6 ]],
103+ arr[ind[5 ]],
104+ arr[ind[4 ]],
105+ arr[ind[3 ]],
106+ arr[ind[2 ]],
107+ arr[ind[1 ]],
108+ arr[ind[0 ]]);
93109 }
94110 static reg_t loadu (void const *mem)
95111 {
@@ -195,6 +211,17 @@ struct ymm_vector<uint32_t> {
195211 {
196212 return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
197213 }
214+ static reg_t set (type_t v1,
215+ type_t v2,
216+ type_t v3,
217+ type_t v4,
218+ type_t v5,
219+ type_t v6,
220+ type_t v7,
221+ type_t v8)
222+ {
223+ return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
224+ }
198225 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
199226 {
200227 return _kxor_mask8 (x, y);
@@ -221,10 +248,16 @@ struct ymm_vector<uint32_t> {
221248 {
222249 return _mm512_mask_i64gather_epi32 (src, mask, index, base, scale);
223250 }
224- template <int scale>
225- static reg_t i64gather (__m512i index, void const *base)
251+ static reg_t i64gather (type_t *arr, int64_t *ind)
226252 {
227- return _mm512_i64gather_epi32 (index, base, scale);
253+ return set (arr[ind[7 ]],
254+ arr[ind[6 ]],
255+ arr[ind[5 ]],
256+ arr[ind[4 ]],
257+ arr[ind[3 ]],
258+ arr[ind[2 ]],
259+ arr[ind[1 ]],
260+ arr[ind[0 ]]);
228261 }
229262 static reg_t loadu (void const *mem)
230263 {
@@ -324,6 +357,17 @@ struct ymm_vector<int32_t> {
324357 {
325358 return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
326359 }
360+ static reg_t set (type_t v1,
361+ type_t v2,
362+ type_t v3,
363+ type_t v4,
364+ type_t v5,
365+ type_t v6,
366+ type_t v7,
367+ type_t v8)
368+ {
369+ return _mm256_set_epi32 (v1, v2, v3, v4, v5, v6, v7, v8);
370+ }
327371 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
328372 {
329373 return _kxor_mask8 (x, y);
@@ -350,10 +394,16 @@ struct ymm_vector<int32_t> {
350394 {
351395 return _mm512_mask_i64gather_epi32 (src, mask, index, base, scale);
352396 }
353- template <int scale>
354- static reg_t i64gather (__m512i index, void const *base)
397+ static reg_t i64gather (type_t *arr, int64_t *ind)
355398 {
356- return _mm512_i64gather_epi32 (index, base, scale);
399+ return set (arr[ind[7 ]],
400+ arr[ind[6 ]],
401+ arr[ind[5 ]],
402+ arr[ind[4 ]],
403+ arr[ind[3 ]],
404+ arr[ind[2 ]],
405+ arr[ind[1 ]],
406+ arr[ind[0 ]]);
357407 }
358408 static reg_t loadu (void const *mem)
359409 {
@@ -456,6 +506,17 @@ struct zmm_vector<int64_t> {
456506 {
457507 return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
458508 }
509+ static reg_t set (type_t v1,
510+ type_t v2,
511+ type_t v3,
512+ type_t v4,
513+ type_t v5,
514+ type_t v6,
515+ type_t v7,
516+ type_t v8)
517+ {
518+ return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
519+ }
459520 static opmask_t kxor_opmask (opmask_t x, opmask_t y)
460521 {
461522 return _kxor_mask8 (x, y);
@@ -482,10 +543,16 @@ struct zmm_vector<int64_t> {
482543 {
483544 return _mm512_mask_i64gather_epi64 (src, mask, index, base, scale);
484545 }
485- template <int scale>
486- static reg_t i64gather (__m512i index, void const *base)
546+ static reg_t i64gather (type_t *arr, int64_t *ind)
487547 {
488- return _mm512_i64gather_epi64 (index, base, scale);
548+ return set (arr[ind[7 ]],
549+ arr[ind[6 ]],
550+ arr[ind[5 ]],
551+ arr[ind[4 ]],
552+ arr[ind[3 ]],
553+ arr[ind[2 ]],
554+ arr[ind[1 ]],
555+ arr[ind[0 ]]);
489556 }
490557 static reg_t loadu (void const *mem)
491558 {
@@ -589,16 +656,33 @@ struct zmm_vector<uint64_t> {
589656 {
590657 return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
591658 }
659+ static reg_t set (type_t v1,
660+ type_t v2,
661+ type_t v3,
662+ type_t v4,
663+ type_t v5,
664+ type_t v6,
665+ type_t v7,
666+ type_t v8)
667+ {
668+ return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
669+ }
592670 template <int scale>
593671 static reg_t
594672 mask_i64gather (reg_t src, opmask_t mask, __m512i index, void const *base)
595673 {
596674 return _mm512_mask_i64gather_epi64 (src, mask, index, base, scale);
597675 }
598- template <int scale>
599- static reg_t i64gather (__m512i index, void const *base)
676+ static reg_t i64gather (type_t *arr, int64_t *ind)
600677 {
601- return _mm512_i64gather_epi64 (index, base, scale);
678+ return set (arr[ind[7 ]],
679+ arr[ind[6 ]],
680+ arr[ind[5 ]],
681+ arr[ind[4 ]],
682+ arr[ind[3 ]],
683+ arr[ind[2 ]],
684+ arr[ind[1 ]],
685+ arr[ind[0 ]]);
602686 }
603687 static opmask_t knot_opmask (opmask_t x)
604688 {
@@ -704,13 +788,22 @@ struct zmm_vector<double> {
704788 {
705789 return _mm512_set1_pd (type_max ());
706790 }
707-
708791 static zmmi_t
709792 seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
710793 {
711794 return _mm512_set_epi64 (v1, v2, v3, v4, v5, v6, v7, v8);
712795 }
713-
796+ static reg_t set (type_t v1,
797+ type_t v2,
798+ type_t v3,
799+ type_t v4,
800+ type_t v5,
801+ type_t v6,
802+ type_t v7,
803+ type_t v8)
804+ {
805+ return _mm512_set_pd (v1, v2, v3, v4, v5, v6, v7, v8);
806+ }
714807 static reg_t maskz_loadu (opmask_t mask, void const *mem)
715808 {
716809 return _mm512_maskz_loadu_pd (mask, mem);
@@ -742,10 +835,16 @@ struct zmm_vector<double> {
742835 {
743836 return _mm512_mask_i64gather_pd (src, mask, index, base, scale);
744837 }
745- template <int scale>
746- static reg_t i64gather (__m512i index, void const *base)
838+ static reg_t i64gather (type_t *arr, int64_t *ind)
747839 {
748- return _mm512_i64gather_pd (index, base, scale);
840+ return set (arr[ind[7 ]],
841+ arr[ind[6 ]],
842+ arr[ind[5 ]],
843+ arr[ind[4 ]],
844+ arr[ind[3 ]],
845+ arr[ind[2 ]],
846+ arr[ind[1 ]],
847+ arr[ind[0 ]]);
749848 }
750849 static reg_t loadu (void const *mem)
751850 {
@@ -841,7 +940,6 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm)
841940template <typename vtype, typename reg_t = typename vtype::reg_t >
842941X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_64bit (reg_t zmm)
843942{
844-
845943 // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7
846944 zmm = cmp_merge<vtype>(
847945 zmm,
0 commit comments