@@ -91,6 +91,17 @@ template <typename T, typename V> struct identity<T, intel::maximum<V>> {
91
91
static constexpr T value = std::numeric_limits<T>::lowest();
92
92
};
93
93
94
+ template <typename T>
95
+ using native_op_list =
96
+ type_list<intel::plus<T>, intel::bit_or<T>, intel::bit_xor<T>,
97
+ intel::bit_and<T>, intel::maximum<T>, intel::minimum<T>>;
98
+
99
+ template <typename T, typename BinaryOperation> struct is_native_op {
100
+ static constexpr bool value =
101
+ is_contained<BinaryOperation, native_op_list<T>>::value ||
102
+ is_contained<BinaryOperation, native_op_list<void >>::value;
103
+ };
104
+
94
105
template <typename Group, typename Ptr, class Function >
95
106
Function for_each (Group g, Ptr first, Ptr last, Function f) {
96
107
#ifdef __SYCL_DEVICE_ONLY__
@@ -114,6 +125,7 @@ Function for_each(Group g, Ptr first, Ptr last, Function f) {
114
125
115
126
namespace intel {
116
127
128
+ // EnableIf shorthands for algorithms that depend only on type
117
129
template <typename T>
118
130
using EnableIfIsScalarArithmetic = cl::sycl::detail::enable_if_t <
119
131
cl::sycl::detail::is_scalar_arithmetic<T>::value, T>;
@@ -126,6 +138,28 @@ template <typename Ptr, typename T>
126
138
using EnableIfIsPointer =
127
139
cl::sycl::detail::enable_if_t <cl::sycl::detail::is_pointer<Ptr>::value, T>;
128
140
141
+ // EnableIf shorthands for algorithms that depend on type and an operator
142
+ template <typename T, typename BinaryOperation>
143
+ using EnableIfIsScalarArithmeticNativeOp = cl::sycl::detail::enable_if_t <
144
+ cl::sycl::detail::is_scalar_arithmetic<T>::value &&
145
+ cl::sycl::detail::is_native_op<T, BinaryOperation>::value,
146
+ T>;
147
+
148
+ template <typename T, typename BinaryOperation>
149
+ using EnableIfIsVectorArithmeticNativeOp = cl::sycl::detail::enable_if_t <
150
+ cl::sycl::detail::is_vector_arithmetic<T>::value &&
151
+ cl::sycl::detail::is_native_op<T, BinaryOperation>::value,
152
+ T>;
153
+
154
+ // TODO: Lift TriviallyCopyable restriction eventually
155
+ template <typename T, typename BinaryOperation>
156
+ using EnableIfIsNonNativeOp = cl::sycl::detail::enable_if_t <
157
+ (!cl::sycl::detail::is_scalar_arithmetic<T>::value &&
158
+ !cl::sycl::detail::is_vector_arithmetic<T>::value &&
159
+ std::is_trivially_copyable<T>::value) ||
160
+ !cl::sycl::detail::is_native_op<T, BinaryOperation>::value,
161
+ T>;
162
+
129
163
template <typename Group> bool all_of (Group, bool pred) {
130
164
static_assert (sycl::detail::is_generic_group<Group>::value,
131
165
" Group algorithms only support the sycl::group and "
@@ -363,7 +397,8 @@ EnableIfIsVectorArithmetic<T> broadcast(Group g, T x) {
363
397
}
364
398
365
399
template <typename Group, typename T, class BinaryOperation >
366
- EnableIfIsScalarArithmetic<T> reduce (Group, T x, BinaryOperation binary_op) {
400
+ EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
401
+ reduce (Group, T x, BinaryOperation binary_op) {
367
402
static_assert (sycl::detail::is_generic_group<Group>::value,
368
403
" Group algorithms only support the sycl::group and "
369
404
" intel::sub_group class." );
@@ -384,7 +419,8 @@ EnableIfIsScalarArithmetic<T> reduce(Group, T x, BinaryOperation binary_op) {
384
419
}
385
420
386
421
template <typename Group, typename T, class BinaryOperation >
387
- EnableIfIsVectorArithmetic<T> reduce (Group g, T x, BinaryOperation binary_op) {
422
+ EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
423
+ reduce (Group g, T x, BinaryOperation binary_op) {
388
424
static_assert (sycl::detail::is_generic_group<Group>::value,
389
425
" Group algorithms only support the sycl::group and "
390
426
" intel::sub_group class." );
@@ -402,9 +438,25 @@ EnableIfIsVectorArithmetic<T> reduce(Group g, T x, BinaryOperation binary_op) {
402
438
return result;
403
439
}
404
440
441
+ template <typename Group, typename T, class BinaryOperation >
442
+ EnableIfIsNonNativeOp<T, BinaryOperation> reduce (Group g, T x,
443
+ BinaryOperation op) {
444
+ static_assert (sycl::detail::is_sub_group<Group>::value,
445
+ " reduce algorithm with user-defined types and operators"
446
+ " only supports intel::sub_group class." );
447
+ T result = x;
448
+ for (int mask = 1 ; mask < g.get_max_local_range ()[0 ]; mask *= 2 ) {
449
+ T tmp = g.shuffle_xor (result, id<1 >(mask));
450
+ if ((g.get_local_id ()[0 ] ^ mask) < g.get_local_range ()[0 ]) {
451
+ result = op (result, tmp);
452
+ }
453
+ }
454
+ return g.shuffle (result, 0 );
455
+ }
456
+
405
457
template <typename Group, typename V, typename T, class BinaryOperation >
406
- EnableIfIsScalarArithmetic<T> reduce (Group g, V x, T init,
407
- BinaryOperation binary_op) {
458
+ EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
459
+ reduce (Group g, V x, T init, BinaryOperation binary_op) {
408
460
static_assert (sycl::detail::is_generic_group<Group>::value,
409
461
" Group algorithms only support the sycl::group and "
410
462
" intel::sub_group class." );
@@ -424,8 +476,8 @@ EnableIfIsScalarArithmetic<T> reduce(Group g, V x, T init,
424
476
}
425
477
426
478
template <typename Group, typename V, typename T, class BinaryOperation >
427
- EnableIfIsVectorArithmetic<T> reduce (Group g, V x, T init,
428
- BinaryOperation binary_op) {
479
+ EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
480
+ reduce (Group g, V x, T init, BinaryOperation binary_op) {
429
481
static_assert (sycl::detail::is_generic_group<Group>::value,
430
482
" Group algorithms only support the sycl::group and "
431
483
" intel::sub_group class." );
@@ -449,6 +501,22 @@ EnableIfIsVectorArithmetic<T> reduce(Group g, V x, T init,
449
501
#endif
450
502
}
451
503
504
+ template <typename Group, typename V, typename T, class BinaryOperation >
505
+ EnableIfIsNonNativeOp<T, BinaryOperation> reduce (Group g, V x, T init,
506
+ BinaryOperation op) {
507
+ static_assert (sycl::detail::is_sub_group<Group>::value,
508
+ " reduce algorithm with user-defined types and operators"
509
+ " only supports intel::sub_group class." );
510
+ T result = x;
511
+ for (int mask = 1 ; mask < g.get_max_local_range ()[0 ]; mask *= 2 ) {
512
+ T tmp = g.shuffle_xor (result, id<1 >(mask));
513
+ if ((g.get_local_id ()[0 ] ^ mask) < g.get_local_range ()[0 ]) {
514
+ result = op (result, tmp);
515
+ }
516
+ }
517
+ return g.shuffle (op (init, result), 0 );
518
+ }
519
+
452
520
template <typename Group, typename Ptr, class BinaryOperation >
453
521
EnableIfIsPointer<Ptr, typename Ptr::element_type>
454
522
reduce (Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
@@ -509,8 +577,8 @@ EnableIfIsPointer<Ptr, T> reduce(Group g, Ptr first, Ptr last, T init,
509
577
}
510
578
511
579
template <typename Group, typename T, class BinaryOperation >
512
- EnableIfIsScalarArithmetic<T> exclusive_scan (Group, T x,
513
- BinaryOperation binary_op) {
580
+ EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
581
+ exclusive_scan (Group, T x, BinaryOperation binary_op) {
514
582
static_assert (sycl::detail::is_generic_group<Group>::value,
515
583
" Group algorithms only support the sycl::group and "
516
584
" intel::sub_group class." );
@@ -530,8 +598,8 @@ EnableIfIsScalarArithmetic<T> exclusive_scan(Group, T x,
530
598
}
531
599
532
600
template <typename Group, typename T, class BinaryOperation >
533
- EnableIfIsVectorArithmetic<T> exclusive_scan (Group g, T x,
534
- BinaryOperation binary_op) {
601
+ EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
602
+ exclusive_scan (Group g, T x, BinaryOperation binary_op) {
535
603
static_assert (sycl::detail::is_generic_group<Group>::value,
536
604
" Group algorithms only support the sycl::group and "
537
605
" intel::sub_group class." );
@@ -550,8 +618,8 @@ EnableIfIsVectorArithmetic<T> exclusive_scan(Group g, T x,
550
618
}
551
619
552
620
template <typename Group, typename V, typename T, class BinaryOperation >
553
- EnableIfIsVectorArithmetic<T> exclusive_scan (Group g, V x, T init,
554
- BinaryOperation binary_op) {
621
+ EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
622
+ exclusive_scan (Group g, V x, T init, BinaryOperation binary_op) {
555
623
static_assert (sycl::detail::is_generic_group<Group>::value,
556
624
" Group algorithms only support the sycl::group and "
557
625
" intel::sub_group class." );
@@ -570,8 +638,8 @@ EnableIfIsVectorArithmetic<T> exclusive_scan(Group g, V x, T init,
570
638
}
571
639
572
640
template <typename Group, typename V, typename T, class BinaryOperation >
573
- EnableIfIsScalarArithmetic<T> exclusive_scan (Group g, V x, T init,
574
- BinaryOperation binary_op) {
641
+ EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
642
+ exclusive_scan (Group g, V x, T init, BinaryOperation binary_op) {
575
643
static_assert (sycl::detail::is_generic_group<Group>::value,
576
644
" Group algorithms only support the sycl::group and "
577
645
" intel::sub_group class." );
@@ -663,8 +731,8 @@ EnableIfIsPointer<InPtr, OutPtr> exclusive_scan(Group g, InPtr first,
663
731
}
664
732
665
733
template <typename Group, typename T, class BinaryOperation >
666
- EnableIfIsVectorArithmetic<T> inclusive_scan (Group g, T x,
667
- BinaryOperation binary_op) {
734
+ EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
735
+ inclusive_scan (Group g, T x, BinaryOperation binary_op) {
668
736
static_assert (sycl::detail::is_generic_group<Group>::value,
669
737
" Group algorithms only support the sycl::group and "
670
738
" intel::sub_group class." );
@@ -683,8 +751,8 @@ EnableIfIsVectorArithmetic<T> inclusive_scan(Group g, T x,
683
751
}
684
752
685
753
template <typename Group, typename T, class BinaryOperation >
686
- EnableIfIsScalarArithmetic<T> inclusive_scan (Group, T x,
687
- BinaryOperation binary_op) {
754
+ EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
755
+ inclusive_scan (Group, T x, BinaryOperation binary_op) {
688
756
static_assert (sycl::detail::is_generic_group<Group>::value,
689
757
" Group algorithms only support the sycl::group and "
690
758
" intel::sub_group class." );
@@ -704,7 +772,7 @@ EnableIfIsScalarArithmetic<T> inclusive_scan(Group, T x,
704
772
}
705
773
706
774
template <typename Group, typename V, class BinaryOperation , typename T>
707
- EnableIfIsScalarArithmetic<T >
775
+ EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation >
708
776
inclusive_scan (Group g, V x, BinaryOperation binary_op, T init) {
709
777
static_assert (sycl::detail::is_generic_group<Group>::value,
710
778
" Group algorithms only support the sycl::group and "
@@ -727,7 +795,7 @@ inclusive_scan(Group g, V x, BinaryOperation binary_op, T init) {
727
795
}
728
796
729
797
template <typename Group, typename V, class BinaryOperation , typename T>
730
- EnableIfIsVectorArithmetic<T >
798
+ EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation >
731
799
inclusive_scan (Group g, V x, BinaryOperation binary_op, T init) {
732
800
static_assert (sycl::detail::is_generic_group<Group>::value,
733
801
" Group algorithms only support the sycl::group and "
0 commit comments