Skip to content

Commit 7e9819d

Browse files
[SYCL] Refactor builtins implementation (#11956)
See `builtins_preview.hpp` for the outline of the new design. This PR changes the implementation under `-fpreview-breaking-changes` and removes reliance on a python builtins generator script. Suggested reading/review order: `builtins_preview.hpp`, `helper_macros.hpp`, `host_helper_macros.hpp`, then headers implementing user-visible side with the library implementation `sycl/source/builtins/*_functions.cpp` last.
1 parent 5bb9a44 commit 7e9819d

23 files changed

+2874
-27
lines changed

sycl/include/sycl/builtins.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212

1313
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1414

15-
// Include the generated builtins.
16-
#include <sycl/builtins_marray_gen.hpp>
17-
#include <sycl/builtins_scalar_gen.hpp>
18-
#include <sycl/builtins_vector_gen.hpp>
15+
#include <sycl/builtins_preview.hpp>
1916

2017
#else // __INTEL_PREVIEW_BREAKING_CHANGES
2118

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
//==------------------- builtins_preview.hpp -------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Implement SYCL builtin functions. This implementation is mainly driven by the
10+
// requirement of not including <cmath> anywhere in the SYCL headers (i.e. from
11+
// within <sycl/sycl.hpp>), because it pollutes global namespace. Note that we
12+
// can't avoid that using MSVC's STL as the pollution happens even from
13+
// <vector>/<string> and other headers that have to be included per the SYCL
14+
// specification. As such, an alternative approach might be to use math
15+
// intrinsics with GCC/clang-based compilers and use <cmath> when using MSVC as
16+
// a host compiler. That hasn't been tried/investigated.
17+
//
18+
// Current implementation splits builtins into several files following the SYCL
19+
// 2020 (revision 8) split into common/math/geometric/relational/etc. functions.
20+
// For each set, the implementation is split into a user-visible
21+
// include/sycl/detail/builtins/*_functions.inc providing full device-side
22+
// implementation as well as defining user-visible APIs and defining ABI
23+
// implemented under source/builtins/*_functions.cpp for the host side. We
24+
// provide both scalar/vector overloads through symbols in the SYCL runtime
25+
// library due to the <cmath> limitation above (for scalars) and due to
26+
// performance reasons for vector overloads (to be able to benefit from
27+
// vectorization).
28+
//
29+
// Providing declaration for the host side symbols contained in the library
30+
// comes with its own challenges. One is compilation time - blindly providing
31+
// all those declarations takes significant time (about 10% slowdown for
32+
// "clang++ -fsycl" when compiling just "#include <sycl/sycl.hpp>"). Another
33+
// issue is that return type for templates is part of the mangling (and as such
34+
// SFINAE requirements too). To overcome that we structure host side
35+
// implementation roughly like this (in most cases):
36+
//
37+
// math_functions.cpp exports:
38+
// float sycl::__sin_impl(float);
39+
// float1 sycl::__sin_impl(float1);
40+
// float2 sycl::__sin_impl(float2);
41+
// ...
42+
// /* same for other types */
43+
//
44+
// math_functions.inc provides an implementation based on the following idea
45+
// (in ::sycl namespace):
46+
// float sin(float x) {
47+
// extern float __sin_impl(float);
48+
// return __sin_impl(x);
49+
// }
50+
// template <typename T>
51+
// enable_if_valid_type<T> sin(T x) {
52+
// if constexpr (marray_or_swizzle) {
53+
// ...
54+
// call sycl::sin(vector_or_scalar)
55+
// } else {
56+
// extern T __sin_impl(T);
57+
// return __sin_impl(x);
58+
// }
59+
// }
60+
// That way we avoid having the full set of explicit declaration for the symbols
61+
// in the library and instead only pay with compile time when those template
62+
// instantiations actually happen.
63+
64+
#pragma once
65+
66+
#include <sycl/builtins_utils_vec.hpp>
67+
68+
namespace sycl {
69+
inline namespace _V1 {
70+
namespace detail {
71+
// True when all of Ts belong to one category (all scalar arithmetic, all
// marray, or all vec/swizzle) and each of them has the same
// num_elements<>::value as the first type in the pack.
template <typename... Ts>
72+
inline constexpr bool builtin_same_shape_v =
73+
((... && is_scalar_arithmetic_v<Ts>) || (... && is_marray_v<Ts>) ||
74+
(... && is_vec_or_swizzle_v<Ts>)) &&
75+
(... && (num_elements<Ts>::value ==
76+
num_elements<typename first_type<Ts...>::type>::value));
77+
78+
// True when Ts satisfy builtin_same_shape_v and all collapse to one identical
// type once simplify_if_swizzle_t has been applied to each of them.
template <typename... Ts>
79+
inline constexpr bool builtin_same_or_swizzle_v =
80+
// Use builtin_same_shape_v to filter out types unrelated to builtins.
81+
builtin_same_shape_v<Ts...> && all_same_v<simplify_if_swizzle_t<Ts>...>;
82+
83+
namespace builtins {
84+
#ifdef __SYCL_DEVICE_ONLY__
85+
// Converts a SYCL-level argument into the representation handed to the
// underlying builtin (compiled only under __SYCL_DEVICE_ONLY__). Dispatch is on
// the cv/ref-stripped type of T:
//   * vec        -> its vector_t bit_cast to an ext_vector_type of converted
//                   elements (or to the bare converted element when size is 1)
//   * half       -> half_impl::BIsRepresentationT
//   * multi_ptr  -> recursive conversion of get_decorated()
//   * scalar     -> ConvertToOpenCLType_t of the type
//   * raw ptr    -> pointer to the converted element type, decorated with the
//                   address space reported by deduce_AS
//   * swizzle    -> materialized via simplify_if_swizzle_t, then converted
//   * otherwise  -> forwarded unchanged
template <typename T> auto convert_arg(T &&x) {
86+
using no_cv_ref = std::remove_cv_t<std::remove_reference_t<T>>;
87+
if constexpr (is_vec_v<no_cv_ref>) {
88+
using elem_type = get_elem_type_t<no_cv_ref>;
89+
// The element type is converted by the same rules as a standalone argument.
using converted_elem_type =
90+
decltype(convert_arg(std::declval<elem_type>()));
91+
92+
constexpr auto N = no_cv_ref::size();
93+
// Single-element vectors map to the converted element type itself rather
// than to an ext_vector_type of length one.
using result_type = std::conditional_t<N == 1, converted_elem_type,
94+
converted_elem_type
95+
__attribute__((ext_vector_type(N)))>;
96+
// TODO: We should have this bit_cast impl inside vec::convert.
97+
return bit_cast<result_type>(static_cast<typename no_cv_ref::vector_t>(x));
98+
} else if constexpr (std::is_same_v<no_cv_ref, half>)
99+
return static_cast<half_impl::BIsRepresentationT>(x);
100+
else if constexpr (is_multi_ptr_v<no_cv_ref>) {
101+
return convert_arg(x.get_decorated());
102+
} else if constexpr (is_scalar_arithmetic_v<no_cv_ref>) {
103+
// E.g. on linux: long long -> int64_t (long), or char -> int8_t (signed
104+
// char) and same for unsigned; Windows has long/long long reversed.
105+
// TODO: Inline this scalar impl.
106+
return static_cast<ConvertToOpenCLType_t<no_cv_ref>>(x);
107+
} else if constexpr (std::is_pointer_v<no_cv_ref>) {
108+
// Strip the decoration from the pointee, convert the pointee type, then
// re-apply the decoration for the address space deduced from the pointer.
using elem_type = remove_decoration_t<std::remove_pointer_t<no_cv_ref>>;
109+
using converted_elem_type =
110+
decltype(convert_arg(std::declval<elem_type>()));
111+
using result_type =
112+
typename DecoratedType<converted_elem_type,
113+
deduce_AS<no_cv_ref>::value>::type *;
114+
return reinterpret_cast<result_type>(x);
115+
} else if constexpr (is_swizzle_v<no_cv_ref>) {
116+
return convert_arg(simplify_if_swizzle_t<no_cv_ref>{x});
117+
} else {
118+
// TODO: should it be unreachable? What can it be?
119+
return std::forward<T>(x);
120+
}
121+
}
122+
123+
// Converts the raw value produced by a builtin for return as RetTy: when RetTy
// is a vec the value is bit_cast to RetTy::vector_t, otherwise it is forwarded
// unchanged.
template <typename RetTy, typename T> auto convert_result(T &&x) {
124+
if constexpr (is_vec_v<RetTy>) {
125+
return bit_cast<typename RetTy::vector_t>(x);
126+
} else {
127+
return std::forward<T>(x);
128+
}
129+
}
130+
#endif
131+
} // namespace builtins
132+
133+
// Applies F across marray arguments and collects the results in an marray
// whose element type is that of F applied to single elements and whose size is
// taken from the first argument type.
//
// Elements are processed two at a time: every argument is passed through
// to_vec2(x, I * 2), F is invoked on the results, and the partial result is
// copied byte-wise into Res starting at element I * 2. When the size is odd,
// the last element is computed by a separate element-wise call.
// NOTE(review): to_vec2 is defined elsewhere; presumably it builds a 2-element
// vec from elements I * 2 and I * 2 + 1 - confirm. The memcpy relies on
// PartialRes being laid out as consecutive ret_elem_type elements.
template <typename FuncTy, typename... Ts>
134+
auto builtin_marray_impl(FuncTy F, const Ts &...x) {
135+
using ret_elem_type = decltype(F(x[0]...));
136+
using T = typename first_type<Ts...>::type;
137+
marray<ret_elem_type, T::size()> Res;
138+
constexpr auto N = T::size();
139+
for (size_t I = 0; I < N / 2; ++I) {
140+
auto PartialRes = F(to_vec2(x, I * 2)...);
141+
std::memcpy(&Res[I * 2], &PartialRes, sizeof(decltype(PartialRes)));
142+
}
143+
// Odd size: the trailing element was not covered by the pairwise loop.
if (N % 2)
144+
Res[N - 1] = F(x[N - 1]...);
145+
return Res;
146+
}
147+
148+
// Default host-side dispatch for a builtin implemented by F: if any argument
// is an marray the call goes through builtin_marray_impl; otherwise F is
// invoked with every argument converted via simplify_if_swizzle_t<Ts>{x}.
template <typename FuncTy, typename... Ts>
149+
auto builtin_default_host_impl(FuncTy F, const Ts &...x) {
150+
// We implement support for marray/swizzle in the headers and export symbols
151+
// for scalars/vector from the library binary. The reason is that scalar
152+
// implementations mostly depend on <cmath> which pollutes global namespace,
153+
// so we can't unconditionally include it from the SYCL headers. Vector
154+
// overloads have to be implemented in the library next to scalar overloads in
155+
// order to be vectorizable.
156+
if constexpr ((... || is_marray_v<Ts>)) {
157+
return builtin_marray_impl(F, x...);
158+
} else {
159+
return F(simplify_if_swizzle_t<Ts>{x}...);
160+
}
161+
}
162+
163+
// Implements a non-scalar builtin through element-wise calls of F. The first
// argument type selects the path:
//   * vec or swizzle: a vec<ret_elem_type, T::size()> is filled by calling
//     F(x[idx]...) for each index via loop<>;
//   * otherwise the type must be an marray (enforced by static_assert) and the
//     work is delegated to builtin_marray_impl.
template <typename FuncTy, typename... Ts>
164+
auto builtin_delegate_to_scalar(FuncTy F, const Ts &...x) {
165+
using T = typename first_type<Ts...>::type;
166+
if constexpr (is_vec_or_swizzle_v<T>) {
167+
using ret_elem_type = decltype(F(x[0]...));
168+
// TODO: using r{} to avoid Werror. Not sure if ok.
169+
vec<ret_elem_type, T::size()> r{};
170+
loop<T::size()>([&](auto idx) { r[idx] = F(x[idx]...); });
171+
return r;
172+
} else {
173+
static_assert(is_marray_v<T>);
174+
return builtin_marray_impl(F, x...);
175+
}
176+
}
177+
178+
// Element-type checker: true when get_elem_type_t<T> is one of the listed
// floating-point (float, double, half) or integer types.
template <typename T>
179+
struct any_elem_type
180+
: std::bool_constant<check_type_in_v<
181+
get_elem_type_t<T>, float, double, half, char, signed char, short,
182+
int, long, long long, unsigned char, unsigned short, unsigned int,
183+
unsigned long, unsigned long long>> {};
184+
// Element-type checker: true when get_elem_type_t<T> is float, double or half.
template <typename T>
185+
struct fp_elem_type
186+
: std::bool_constant<
187+
check_type_in_v<get_elem_type_t<T>, float, double, half>> {};
188+
// Element-type checker: true only when get_elem_type_t<T> is float.
template <typename T>
189+
struct float_elem_type
190+
: std::bool_constant<check_type_in_v<get_elem_type_t<T>, float>> {};
191+
// Element-type checker: true when get_elem_type_t<T> is one of the listed
// signed or unsigned integer types (char through unsigned long long).
template <typename T>
192+
struct integer_elem_type
193+
: std::bool_constant<
194+
check_type_in_v<get_elem_type_t<T>, char, signed char, short, int,
195+
long, long long, unsigned char, unsigned short,
196+
unsigned int, unsigned long, unsigned long long>> {};
197+
// Element-type checker: true when get_elem_type_t<T> is int32_t or uint32_t.
template <typename T>
198+
struct suint32_elem_type
199+
: std::bool_constant<
200+
check_type_in_v<get_elem_type_t<T>, int32_t, uint32_t>> {};
201+
202+
// Type-trait wrapper around builtin_same_shape_v so that it can be passed as a
// template template argument (e.g. as an extra condition of builtin_enable).
template <typename... Ts>
203+
struct same_basic_shape : std::bool_constant<builtin_same_shape_v<Ts...>> {};
204+
205+
// True when Ts have the same basic shape and, in addition, all of them share
// one element type as reported by get_elem_type_t.
template <typename... Ts>
206+
struct same_elem_type : std::bool_constant<same_basic_shape<Ts...>::value &&
207+
all_same_v<get_elem_type_t<Ts>...>> {
208+
};
209+
210+
// Shape checkers applied to a single type.
// any_shape accepts every type.
template <typename> struct any_shape : std::true_type {};
211+
212+
// scalar_only accepts only types for which is_scalar_arithmetic_v holds.
template <typename T>
213+
struct scalar_only : std::bool_constant<is_scalar_arithmetic_v<T>> {};
214+
215+
// non_scalar_only is the negation of scalar_only.
template <typename T>
216+
struct non_scalar_only : std::bool_constant<!is_scalar_arithmetic_v<T>> {};
217+
218+
// Return-type traits.
// default_ret_type: the result type is the argument type itself.
template <typename T> struct default_ret_type {
219+
using type = T;
220+
};
221+
222+
// scalar_ret_type: the result type is the element type of the argument type.
template <typename T> struct scalar_ret_type {
223+
using type = get_elem_type_t<T>;
224+
};
225+
226+
// SFINAE helper behind the builtin enablers. It derives from std::enable_if
// whose condition requires that
//   * ElemTypeChecker accepts the first type in Ts,
//   * ShapeChecker accepts the first type in Ts, and
//   * ExtraConditions accepts the whole pack Ts.
// When enabled, ::type is RetTypeTrait applied to the first type in Ts after
// simplify_if_swizzle_t.
template <template <typename> typename RetTypeTrait,
227+
template <typename> typename ElemTypeChecker,
228+
template <typename> typename ShapeChecker,
229+
template <typename...> typename ExtraConditions, typename... Ts>
230+
struct builtin_enable
231+
: std::enable_if<
232+
ElemTypeChecker<typename first_type<Ts...>::type>::value &&
233+
ShapeChecker<typename first_type<Ts...>::type>::value &&
234+
ExtraConditions<Ts...>::value,
235+
typename RetTypeTrait<
236+
simplify_if_swizzle_t<typename first_type<Ts...>::type>>::type> {
237+
};
238+
// Defines the alias template detail::NAME_t<Ts...> as
// builtin_enable<RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, SHAPE_CHECKER,
// EXTRA_CONDITIONS, Ts...>::type. The expansion opens and closes its own
// `namespace detail`, so it is meant to be used from the enclosing namespace
// rather than from inside detail.
#define BUILTIN_CREATE_ENABLER(NAME, RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, \
239+
SHAPE_CHECKER, EXTRA_CONDITIONS) \
240+
namespace detail { \
241+
template <typename... Ts> \
242+
using NAME##_t = \
243+
typename builtin_enable<RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, \
244+
SHAPE_CHECKER, EXTRA_CONDITIONS, Ts...>::type; \
245+
}
246+
} // namespace detail
247+
248+
// Enablers for builtins that accept every element type listed in
// any_elem_type and return the (swizzle-simplified) first argument type. All
// three require same_elem_type across the arguments; they differ only in the
// permitted shape: any shape, scalars only, or non-scalars only.
BUILTIN_CREATE_ENABLER(builtin_enable_generic, default_ret_type, any_elem_type,
249+
any_shape, same_elem_type)
250+
BUILTIN_CREATE_ENABLER(builtin_enable_generic_scalar, default_ret_type,
251+
any_elem_type, scalar_only, same_elem_type)
252+
BUILTIN_CREATE_ENABLER(builtin_enable_generic_non_scalar, default_ret_type,
253+
any_elem_type, non_scalar_only, same_elem_type)
254+
} // namespace _V1
255+
} // namespace sycl
256+
257+
// The headers below are specifically implemented without including all the
258+
// necessary headers to allow preprocessing them on their own and providing
259+
// human-friendly result. One can use a command like this to achieve that:
260+
// clang++ -[DU]__SYCL_DEVICE_ONLY__ -x c++ math_functions.inc \
261+
// -I <..>/llvm/sycl/include -E -o - \
262+
// | grep -v '^#' | clang-format > math_functions.{host|device}.ii
263+
264+
#include <sycl/detail/builtins/common_functions.inc>
265+
#include <sycl/detail/builtins/geometric_functions.inc>
266+
#include <sycl/detail/builtins/half_precision_math_functions.inc>
267+
#include <sycl/detail/builtins/integer_functions.inc>
268+
#include <sycl/detail/builtins/math_functions.inc>
269+
#include <sycl/detail/builtins/native_math_functions.inc>
270+
#include <sycl/detail/builtins/relational_functions.inc>
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//==------------------- common_functions.inc -------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Intentionally insufficient set of includes and no "#pragma once".
10+
11+
#include <sycl/detail/builtins/helper_macros.hpp>
12+
13+
namespace sycl {
14+
inline namespace _V1 {
15+
// Enablers for the common builtins: element type must satisfy fp_elem_type
// (float, double or half), all arguments must satisfy same_elem_type, and the
// result is the (swizzle-simplified) first argument type. The second enabler
// additionally rejects scalar first arguments.
BUILTIN_CREATE_ENABLER(builtin_enable_common, default_ret_type, fp_elem_type,
16+
any_shape, same_elem_type)
17+
BUILTIN_CREATE_ENABLER(builtin_enable_common_non_scalar, default_ret_type,
18+
fp_elem_type, non_scalar_only, same_elem_type)
19+
20+
// BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) declares the common builtin NAME
// constrained by builtin_enable_common_t. On device it expands to
// DEVICE_IMPL_TEMPLATE with SPIRV_IMPL; on host it expands to
// HOST_IMPL_TEMPLATE and SPIRV_IMPL is not used.
// NOTE(review): DEVICE_IMPL_TEMPLATE / HOST_IMPL_TEMPLATE are not defined in
// this file - presumably they come from the helper macro headers; confirm.
#ifdef __SYCL_DEVICE_ONLY__
21+
#define BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) \
22+
DEVICE_IMPL_TEMPLATE(NUM_ARGS, NAME, builtin_enable_common_t, SPIRV_IMPL)
23+
#else
24+
#define BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) \
25+
HOST_IMPL_TEMPLATE(NUM_ARGS, NAME, builtin_enable_common_t, common, \
26+
default_ret_type)
27+
#endif
28+
29+
// One-argument common builtins; the last macro argument names the SPIR-V
// function used by the device expansion.
BUILTIN_COMMON(ONE_ARG, degrees, __spirv_ocl_degrees)
30+
BUILTIN_COMMON(ONE_ARG, radians, __spirv_ocl_radians)
31+
BUILTIN_COMMON(ONE_ARG, sign, __spirv_ocl_sign)
32+
33+
// mix: three-argument common builtin, plus an overload for non-scalar x/y
// with a third argument of the element type of T0.
BUILTIN_COMMON(THREE_ARGS, mix, __spirv_ocl_mix)
34+
// The overload brace-initializes simplify_if_swizzle_t<T0> from each argument
// (including y of type T1 and the element-typed z) and forwards to the
// three-argument mix.
// NOTE(review): constructing from z presumably broadcasts it across all
// elements - confirm against the vec/marray constructors.
template <typename T0, typename T1>
35+
detail::builtin_enable_common_non_scalar_t<T0, T1>
36+
mix(T0 x, T1 y, detail::get_elem_type_t<T0> z) {
37+
return mix(detail::simplify_if_swizzle_t<T0>{x},
38+
detail::simplify_if_swizzle_t<T0>{y},
39+
detail::simplify_if_swizzle_t<T0>{z});
40+
}
41+
42+
// step: two-argument common builtin, plus an overload whose first argument is
// of the element type of T and whose second argument is a non-scalar T.
BUILTIN_COMMON(TWO_ARGS, step, __spirv_ocl_step)
43+
// Both arguments are brace-initialized into simplify_if_swizzle_t<T> and
// forwarded to the two-argument step.
// NOTE(review): constructing from x presumably broadcasts it - confirm.
template <typename T>
44+
detail::builtin_enable_common_non_scalar_t<T> step(detail::get_elem_type_t<T> x,
45+
T y) {
46+
return step(detail::simplify_if_swizzle_t<T>{x},
47+
detail::simplify_if_swizzle_t<T>{y});
48+
}
49+
50+
// smoothstep: three-argument common builtin, plus an overload whose first two
// arguments are of the element type of T and whose third is a non-scalar T.
BUILTIN_COMMON(THREE_ARGS, smoothstep, __spirv_ocl_smoothstep)
51+
// All three arguments are brace-initialized into simplify_if_swizzle_t<T> and
// forwarded to the three-argument smoothstep.
// NOTE(review): constructing from x and y presumably broadcasts them - confirm.
template <typename T>
52+
detail::builtin_enable_common_non_scalar_t<T>
53+
smoothstep(detail::get_elem_type_t<T> x, detail::get_elem_type_t<T> y, T z) {
54+
return smoothstep(detail::simplify_if_swizzle_t<T>{x},
55+
detail::simplify_if_swizzle_t<T>{y},
56+
detail::simplify_if_swizzle_t<T>{z});
57+
}
58+
59+
// max: two-argument common builtin (floating-point element types), plus an
// overload taking a non-scalar T and a second argument of its element type.
BUILTIN_COMMON(TWO_ARGS, max, __spirv_ocl_fmax_common)
60+
// Both arguments are brace-initialized into simplify_if_swizzle_t<T> and
// forwarded to the two-argument max.
// NOTE(review): constructing from y presumably broadcasts it - confirm.
template <typename T>
61+
detail::builtin_enable_common_non_scalar_t<T>
62+
max(T x, detail::get_elem_type_t<T> y) {
63+
return max(detail::simplify_if_swizzle_t<T>{x},
64+
detail::simplify_if_swizzle_t<T>{y});
65+
}
66+
67+
// min: two-argument common builtin (floating-point element types), plus an
// overload taking a non-scalar T and a second argument of its element type.
BUILTIN_COMMON(TWO_ARGS, min, __spirv_ocl_fmin_common)
68+
// Both arguments are brace-initialized into simplify_if_swizzle_t<T> and
// forwarded to the two-argument min.
// NOTE(review): constructing from y presumably broadcasts it - confirm.
template <typename T>
69+
detail::builtin_enable_common_non_scalar_t<T>
70+
min(T x, detail::get_elem_type_t<T> y) {
71+
return min(detail::simplify_if_swizzle_t<T>{x},
72+
detail::simplify_if_swizzle_t<T>{y});
73+
}
74+
75+
#undef BUILTIN_COMMON
76+
77+
// clamp uses builtin_enable_generic_t, i.e. it accepts integer as well as
// floating-point element types. On device the implementation is a generic
// lambda that picks the SPIR-V function by element type: signed integers use
// __spirv_ocl_s_clamp, unsigned integers __spirv_ocl_u_clamp, and everything
// else __spirv_ocl_fclamp. On host it expands to HOST_IMPL_TEMPLATE.
// NOTE(review): T0 is not declared in the visible code - presumably it is a
// template parameter introduced by DEVICE_IMPL_TEMPLATE; confirm.
#ifdef __SYCL_DEVICE_ONLY__
78+
DEVICE_IMPL_TEMPLATE(THREE_ARGS, clamp, builtin_enable_generic_t,
79+
[](auto... xs) {
80+
using ElemTy = detail::get_elem_type_t<T0>;
81+
if constexpr (std::is_integral_v<ElemTy>) {
82+
if constexpr (std::is_signed_v<ElemTy>) {
83+
return __spirv_ocl_s_clamp(xs...);
84+
} else {
85+
return __spirv_ocl_u_clamp(xs...);
86+
}
87+
} else {
88+
return __spirv_ocl_fclamp(xs...);
89+
}
90+
})
91+
#else
92+
HOST_IMPL_TEMPLATE(THREE_ARGS, clamp, builtin_enable_generic_t, common,
93+
default_ret_type)
94+
#endif
95+
// clamp overload for a non-scalar T with both bounds given as values of its
// element type. All three arguments are brace-initialized into
// simplify_if_swizzle_t<T> and forwarded to the three-argument clamp.
// NOTE(review): constructing from y and z presumably broadcasts them - confirm.
template <typename T>
96+
detail::builtin_enable_generic_non_scalar_t<T>
97+
clamp(T x, detail::get_elem_type_t<T> y, detail::get_elem_type_t<T> z) {
98+
return clamp(detail::simplify_if_swizzle_t<T>{x},
99+
detail::simplify_if_swizzle_t<T>{y},
100+
detail::simplify_if_swizzle_t<T>{z});
101+
}
102+
} // namespace _V1
103+
} // namespace sycl

0 commit comments

Comments
 (0)