Skip to content

Commit 728429a

Browse files
authored
[SYCL] Enable user-defined sub-group reductions (#2106)
Restricted to TriviallyCopyable types for now. Signed-off-by: John Pennycook <[email protected]> --- This is a partial implementation of the functionality proposed by @psteinbrecher in #1947, enabling custom sub-group reductions for trivially copyable types. I cannot see any good way to implement user-defined reductions using `group` instead of `sub_group` today, or to extend support for arbitrary types, but this is a step in the right direction. Enabling the other algorithms for user-defined types and operators (with the same trivially copyable restriction) could follow the same approach. Starting with reductions in this PR because they are the most common request.
1 parent ba25952 commit 728429a

File tree

2 files changed

+183
-20
lines changed

2 files changed

+183
-20
lines changed

sycl/include/CL/sycl/intel/group_algorithm.hpp

Lines changed: 88 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ template <typename T, typename V> struct identity<T, intel::maximum<V>> {
9191
static constexpr T value = std::numeric_limits<T>::lowest();
9292
};
9393

94+
template <typename T>
95+
using native_op_list =
96+
type_list<intel::plus<T>, intel::bit_or<T>, intel::bit_xor<T>,
97+
intel::bit_and<T>, intel::maximum<T>, intel::minimum<T>>;
98+
99+
template <typename T, typename BinaryOperation> struct is_native_op {
100+
static constexpr bool value =
101+
is_contained<BinaryOperation, native_op_list<T>>::value ||
102+
is_contained<BinaryOperation, native_op_list<void>>::value;
103+
};
104+
94105
template <typename Group, typename Ptr, class Function>
95106
Function for_each(Group g, Ptr first, Ptr last, Function f) {
96107
#ifdef __SYCL_DEVICE_ONLY__
@@ -114,6 +125,7 @@ Function for_each(Group g, Ptr first, Ptr last, Function f) {
114125

115126
namespace intel {
116127

128+
// EnableIf shorthands for algorithms that depend only on type
117129
template <typename T>
118130
using EnableIfIsScalarArithmetic = cl::sycl::detail::enable_if_t<
119131
cl::sycl::detail::is_scalar_arithmetic<T>::value, T>;
@@ -126,6 +138,28 @@ template <typename Ptr, typename T>
126138
using EnableIfIsPointer =
127139
cl::sycl::detail::enable_if_t<cl::sycl::detail::is_pointer<Ptr>::value, T>;
128140

141+
// EnableIf shorthands for algorithms that depend on type and an operator
142+
template <typename T, typename BinaryOperation>
143+
using EnableIfIsScalarArithmeticNativeOp = cl::sycl::detail::enable_if_t<
144+
cl::sycl::detail::is_scalar_arithmetic<T>::value &&
145+
cl::sycl::detail::is_native_op<T, BinaryOperation>::value,
146+
T>;
147+
148+
template <typename T, typename BinaryOperation>
149+
using EnableIfIsVectorArithmeticNativeOp = cl::sycl::detail::enable_if_t<
150+
cl::sycl::detail::is_vector_arithmetic<T>::value &&
151+
cl::sycl::detail::is_native_op<T, BinaryOperation>::value,
152+
T>;
153+
154+
// TODO: Lift TriviallyCopyable restriction eventually
155+
template <typename T, typename BinaryOperation>
156+
using EnableIfIsNonNativeOp = cl::sycl::detail::enable_if_t<
157+
(!cl::sycl::detail::is_scalar_arithmetic<T>::value &&
158+
!cl::sycl::detail::is_vector_arithmetic<T>::value &&
159+
std::is_trivially_copyable<T>::value) ||
160+
!cl::sycl::detail::is_native_op<T, BinaryOperation>::value,
161+
T>;
162+
129163
template <typename Group> bool all_of(Group, bool pred) {
130164
static_assert(sycl::detail::is_generic_group<Group>::value,
131165
"Group algorithms only support the sycl::group and "
@@ -363,7 +397,8 @@ EnableIfIsVectorArithmetic<T> broadcast(Group g, T x) {
363397
}
364398

365399
template <typename Group, typename T, class BinaryOperation>
366-
EnableIfIsScalarArithmetic<T> reduce(Group, T x, BinaryOperation binary_op) {
400+
EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
401+
reduce(Group, T x, BinaryOperation binary_op) {
367402
static_assert(sycl::detail::is_generic_group<Group>::value,
368403
"Group algorithms only support the sycl::group and "
369404
"intel::sub_group class.");
@@ -384,7 +419,8 @@ EnableIfIsScalarArithmetic<T> reduce(Group, T x, BinaryOperation binary_op) {
384419
}
385420

386421
template <typename Group, typename T, class BinaryOperation>
387-
EnableIfIsVectorArithmetic<T> reduce(Group g, T x, BinaryOperation binary_op) {
422+
EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
423+
reduce(Group g, T x, BinaryOperation binary_op) {
388424
static_assert(sycl::detail::is_generic_group<Group>::value,
389425
"Group algorithms only support the sycl::group and "
390426
"intel::sub_group class.");
@@ -402,9 +438,25 @@ EnableIfIsVectorArithmetic<T> reduce(Group g, T x, BinaryOperation binary_op) {
402438
return result;
403439
}
404440

441+
template <typename Group, typename T, class BinaryOperation>
442+
EnableIfIsNonNativeOp<T, BinaryOperation> reduce(Group g, T x,
443+
BinaryOperation op) {
444+
static_assert(sycl::detail::is_sub_group<Group>::value,
445+
"reduce algorithm with user-defined types and operators"
446+
"only supports intel::sub_group class.");
447+
T result = x;
448+
for (int mask = 1; mask < g.get_max_local_range()[0]; mask *= 2) {
449+
T tmp = g.shuffle_xor(result, id<1>(mask));
450+
if ((g.get_local_id()[0] ^ mask) < g.get_local_range()[0]) {
451+
result = op(result, tmp);
452+
}
453+
}
454+
return g.shuffle(result, 0);
455+
}
456+
405457
template <typename Group, typename V, typename T, class BinaryOperation>
406-
EnableIfIsScalarArithmetic<T> reduce(Group g, V x, T init,
407-
BinaryOperation binary_op) {
458+
EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
459+
reduce(Group g, V x, T init, BinaryOperation binary_op) {
408460
static_assert(sycl::detail::is_generic_group<Group>::value,
409461
"Group algorithms only support the sycl::group and "
410462
"intel::sub_group class.");
@@ -424,8 +476,8 @@ EnableIfIsScalarArithmetic<T> reduce(Group g, V x, T init,
424476
}
425477

426478
template <typename Group, typename V, typename T, class BinaryOperation>
427-
EnableIfIsVectorArithmetic<T> reduce(Group g, V x, T init,
428-
BinaryOperation binary_op) {
479+
EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
480+
reduce(Group g, V x, T init, BinaryOperation binary_op) {
429481
static_assert(sycl::detail::is_generic_group<Group>::value,
430482
"Group algorithms only support the sycl::group and "
431483
"intel::sub_group class.");
@@ -449,6 +501,22 @@ EnableIfIsVectorArithmetic<T> reduce(Group g, V x, T init,
449501
#endif
450502
}
451503

504+
template <typename Group, typename V, typename T, class BinaryOperation>
505+
EnableIfIsNonNativeOp<T, BinaryOperation> reduce(Group g, V x, T init,
506+
BinaryOperation op) {
507+
static_assert(sycl::detail::is_sub_group<Group>::value,
508+
"reduce algorithm with user-defined types and operators"
509+
"only supports intel::sub_group class.");
510+
T result = x;
511+
for (int mask = 1; mask < g.get_max_local_range()[0]; mask *= 2) {
512+
T tmp = g.shuffle_xor(result, id<1>(mask));
513+
if ((g.get_local_id()[0] ^ mask) < g.get_local_range()[0]) {
514+
result = op(result, tmp);
515+
}
516+
}
517+
return g.shuffle(op(init, result), 0);
518+
}
519+
452520
template <typename Group, typename Ptr, class BinaryOperation>
453521
EnableIfIsPointer<Ptr, typename Ptr::element_type>
454522
reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
@@ -509,8 +577,8 @@ EnableIfIsPointer<Ptr, T> reduce(Group g, Ptr first, Ptr last, T init,
509577
}
510578

511579
template <typename Group, typename T, class BinaryOperation>
512-
EnableIfIsScalarArithmetic<T> exclusive_scan(Group, T x,
513-
BinaryOperation binary_op) {
580+
EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
581+
exclusive_scan(Group, T x, BinaryOperation binary_op) {
514582
static_assert(sycl::detail::is_generic_group<Group>::value,
515583
"Group algorithms only support the sycl::group and "
516584
"intel::sub_group class.");
@@ -530,8 +598,8 @@ EnableIfIsScalarArithmetic<T> exclusive_scan(Group, T x,
530598
}
531599

532600
template <typename Group, typename T, class BinaryOperation>
533-
EnableIfIsVectorArithmetic<T> exclusive_scan(Group g, T x,
534-
BinaryOperation binary_op) {
601+
EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
602+
exclusive_scan(Group g, T x, BinaryOperation binary_op) {
535603
static_assert(sycl::detail::is_generic_group<Group>::value,
536604
"Group algorithms only support the sycl::group and "
537605
"intel::sub_group class.");
@@ -550,8 +618,8 @@ EnableIfIsVectorArithmetic<T> exclusive_scan(Group g, T x,
550618
}
551619

552620
template <typename Group, typename V, typename T, class BinaryOperation>
553-
EnableIfIsVectorArithmetic<T> exclusive_scan(Group g, V x, T init,
554-
BinaryOperation binary_op) {
621+
EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
622+
exclusive_scan(Group g, V x, T init, BinaryOperation binary_op) {
555623
static_assert(sycl::detail::is_generic_group<Group>::value,
556624
"Group algorithms only support the sycl::group and "
557625
"intel::sub_group class.");
@@ -570,8 +638,8 @@ EnableIfIsVectorArithmetic<T> exclusive_scan(Group g, V x, T init,
570638
}
571639

572640
template <typename Group, typename V, typename T, class BinaryOperation>
573-
EnableIfIsScalarArithmetic<T> exclusive_scan(Group g, V x, T init,
574-
BinaryOperation binary_op) {
641+
EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
642+
exclusive_scan(Group g, V x, T init, BinaryOperation binary_op) {
575643
static_assert(sycl::detail::is_generic_group<Group>::value,
576644
"Group algorithms only support the sycl::group and "
577645
"intel::sub_group class.");
@@ -663,8 +731,8 @@ EnableIfIsPointer<InPtr, OutPtr> exclusive_scan(Group g, InPtr first,
663731
}
664732

665733
template <typename Group, typename T, class BinaryOperation>
666-
EnableIfIsVectorArithmetic<T> inclusive_scan(Group g, T x,
667-
BinaryOperation binary_op) {
734+
EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
735+
inclusive_scan(Group g, T x, BinaryOperation binary_op) {
668736
static_assert(sycl::detail::is_generic_group<Group>::value,
669737
"Group algorithms only support the sycl::group and "
670738
"intel::sub_group class.");
@@ -683,8 +751,8 @@ EnableIfIsVectorArithmetic<T> inclusive_scan(Group g, T x,
683751
}
684752

685753
template <typename Group, typename T, class BinaryOperation>
686-
EnableIfIsScalarArithmetic<T> inclusive_scan(Group, T x,
687-
BinaryOperation binary_op) {
754+
EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
755+
inclusive_scan(Group, T x, BinaryOperation binary_op) {
688756
static_assert(sycl::detail::is_generic_group<Group>::value,
689757
"Group algorithms only support the sycl::group and "
690758
"intel::sub_group class.");
@@ -704,7 +772,7 @@ EnableIfIsScalarArithmetic<T> inclusive_scan(Group, T x,
704772
}
705773

706774
template <typename Group, typename V, class BinaryOperation, typename T>
707-
EnableIfIsScalarArithmetic<T>
775+
EnableIfIsScalarArithmeticNativeOp<T, BinaryOperation>
708776
inclusive_scan(Group g, V x, BinaryOperation binary_op, T init) {
709777
static_assert(sycl::detail::is_generic_group<Group>::value,
710778
"Group algorithms only support the sycl::group and "
@@ -727,7 +795,7 @@ inclusive_scan(Group g, V x, BinaryOperation binary_op, T init) {
727795
}
728796

729797
template <typename Group, typename V, class BinaryOperation, typename T>
730-
EnableIfIsVectorArithmetic<T>
798+
EnableIfIsVectorArithmeticNativeOp<T, BinaryOperation>
731799
inclusive_scan(Group g, V x, BinaryOperation binary_op, T init) {
732800
static_assert(sycl::detail::is_generic_group<Group>::value,
733801
"Group algorithms only support the sycl::group and "
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// UNSUPPORTED: cuda
2+
// CUDA compilation and runtime do not yet support sub-groups.
3+
//
4+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -std=c++14 %s -o %t.out
5+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-targets=%sycl_triple -std=c++14 -D SG_GPU %s -o %t_gpu.out
6+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
7+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
8+
// RUN: %GPU_RUN_PLACEHOLDER %t_gpu.out
9+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
10+
11+
#include "helper.hpp"
12+
#include <CL/sycl.hpp>
13+
#include <complex>
14+
15+
using namespace cl::sycl;
16+
17+
template <typename T, class BinaryOperation>
18+
void check_op(queue &Queue, T init, BinaryOperation op, bool skip_init = false,
19+
size_t G = 240, size_t L = 60) {
20+
try {
21+
nd_range<1> NdRange(G, L);
22+
buffer<T> buf(G);
23+
buffer<size_t> sgsizebuf(1);
24+
Queue.submit([&](handler &cgh) {
25+
auto sgsizeacc = sgsizebuf.get_access<access::mode::read_write>(cgh);
26+
auto acc = buf.template get_access<access::mode::read_write>(cgh);
27+
cgh.parallel_for(
28+
NdRange, [=](nd_item<1> NdItem) {
29+
intel::sub_group sg = NdItem.get_sub_group();
30+
if (skip_init) {
31+
acc[NdItem.get_global_id(0)] =
32+
reduce(sg, T(NdItem.get_global_id(0)), op);
33+
} else {
34+
acc[NdItem.get_global_id(0)] =
35+
reduce(sg, T(NdItem.get_global_id(0)), init, op);
36+
}
37+
if (NdItem.get_global_id(0) == 0)
38+
sgsizeacc[0] = sg.get_max_local_range()[0];
39+
});
40+
});
41+
auto acc = buf.template get_access<access::mode::read_write>();
42+
auto sgsizeacc = sgsizebuf.get_access<access::mode::read_write>();
43+
size_t sg_size = sgsizeacc[0];
44+
int WGid = -1, SGid = 0;
45+
T result = init;
46+
for (int j = 0; j < G; j++) {
47+
if (j % L % sg_size == 0) {
48+
SGid++;
49+
result = init;
50+
for (int i = j; (i % L && i % L % sg_size) || (i == j); i++) {
51+
result = op(result, T(i));
52+
}
53+
}
54+
if (j % L == 0) {
55+
WGid++;
56+
SGid = 0;
57+
}
58+
std::string name =
59+
std::string("reduce_") + typeid(BinaryOperation).name();
60+
exit_if_not_equal(acc[j], result, name.c_str());
61+
}
62+
} catch (exception e) {
63+
std::cout << "SYCL exception caught: " << e.what();
64+
exit(1);
65+
}
66+
}
67+
68+
int main() {
69+
queue Queue;
70+
if (!Queue.get_device().has_extension("cl_intel_subgroups")) {
71+
std::cout << "Skipping test\n";
72+
return 0;
73+
}
74+
75+
size_t G = 240;
76+
size_t L = 60;
77+
78+
// Test user-defined type
79+
// Use complex as a proxy for this
80+
using UDT = std::complex<float>;
81+
check_op<UDT>(Queue, UDT(L, L), intel::plus<UDT>(), false, G, L);
82+
check_op<UDT>(Queue, UDT(0, 0), intel::plus<UDT>(), true, G, L);
83+
84+
// Test user-defined operator
85+
auto UDOp = [=](const auto &lhs, const auto &rhs) { return lhs + rhs; };
86+
check_op<int>(Queue, int(L), UDOp, false, G, L);
87+
check_op<int>(Queue, int(0), UDOp, true, G, L);
88+
89+
// Test both user-defined type and operator
90+
check_op<UDT>(Queue, UDT(L, L), UDOp, false, G, L);
91+
check_op<UDT>(Queue, UDT(0, 0), UDOp, true, G, L);
92+
93+
std::cout << "Test passed." << std::endl;
94+
return 0;
95+
}

0 commit comments

Comments
 (0)