Skip to content

Commit 5df262f

Browse files
[SYCL] Change internal host vec representation (#13596)
This commit changes the internal host representation of sycl::vec<sycl::half, N> to store sycl::half directly instead of the wrapper implementation previously used inside it. This avoids a strict-aliasing violation in operator[] for the host implementation. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent e425d2a commit 5df262f

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

sycl/include/sycl/half_type.hpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -228,42 +228,48 @@ class half;
228228

229229
// Several aliases are defined below:
230230
// - StorageT: actual representation of half data type. It is used by scalar
231-
// half values and by 'sycl::vec' class. On device side, it points to some
232-
// native half data type, while on host some custom data type is used to
233-
// emulate operations of 16-bit floating-point values
231+
// half values. On device side, it points to some native half data type, while
232+
// on host some custom data type is used to emulate operations of 16-bit
233+
// floating-point values
234234
//
235235
// - BIsRepresentationT: data type which is used by built-in functions. It is
236236
// distinguished from StorageT, because on host, we can still operate on the
237237
// wrapper itself and there is no sense in direct usage of underlying data
238238
// type (too many changes required for BIs implementation without any
239239
// foreseeable profits)
240240
//
241-
// - VecNStorageT - representation of N-element vector of halfs. Follows the
242-
// same logic as StorageT
241+
// - VecElemT: representation of each element in the vector. On device it is
242+
// the same as StorageT to carry a native vector representation, while on
243+
// host it stores the sycl::half implementation directly.
244+
//
245+
// - VecNStorageT: representation of N-element vector of halfs. Follows the
246+
// same logic as VecElemT.
243247
#ifdef __SYCL_DEVICE_ONLY__
244248
using StorageT = _Float16;
245249
using BIsRepresentationT = _Float16;
250+
using VecElemT = _Float16;
246251

247-
using Vec2StorageT = StorageT __attribute__((ext_vector_type(2)));
248-
using Vec3StorageT = StorageT __attribute__((ext_vector_type(3)));
249-
using Vec4StorageT = StorageT __attribute__((ext_vector_type(4)));
250-
using Vec8StorageT = StorageT __attribute__((ext_vector_type(8)));
251-
using Vec16StorageT = StorageT __attribute__((ext_vector_type(16)));
252+
using Vec2StorageT = VecElemT __attribute__((ext_vector_type(2)));
253+
using Vec3StorageT = VecElemT __attribute__((ext_vector_type(3)));
254+
using Vec4StorageT = VecElemT __attribute__((ext_vector_type(4)));
255+
using Vec8StorageT = VecElemT __attribute__((ext_vector_type(8)));
256+
using Vec16StorageT = VecElemT __attribute__((ext_vector_type(16)));
252257
#else // SYCL_DEVICE_ONLY
253258
using StorageT = detail::host_half_impl::half;
254259
// No need to extract underlying data type for built-in functions operating on
255260
// host
256261
using BIsRepresentationT = half;
262+
using VecElemT = half;
257263

258264
// On the host side we cannot use OpenCL cl_half# types as an underlying type
259265
// for vec because they are actually defined as an integer type under the
260266
// hood. As a result, half values will be converted to an integer and passed
261267
// as a kernel argument which is expected to be a floating-point number.
262-
using Vec2StorageT = std::array<StorageT, 2>;
263-
using Vec3StorageT = std::array<StorageT, 3>;
264-
using Vec4StorageT = std::array<StorageT, 4>;
265-
using Vec8StorageT = std::array<StorageT, 8>;
266-
using Vec16StorageT = std::array<StorageT, 16>;
268+
using Vec2StorageT = std::array<VecElemT, 2>;
269+
using Vec3StorageT = std::array<VecElemT, 3>;
270+
using Vec4StorageT = std::array<VecElemT, 4>;
271+
using Vec8StorageT = std::array<VecElemT, 8>;
272+
using Vec16StorageT = std::array<VecElemT, 16>;
267273
#endif // SYCL_DEVICE_ONLY
268274

269275
#ifndef __SYCL_DEVICE_ONLY__

sycl/include/sycl/vector.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,9 @@ struct VecStorage<
300300

301301
// Single element half
302302
template <> struct VecStorage<half, 1, void> {
303-
using DataType = sycl::detail::half_impl::StorageT;
303+
using DataType = sycl::detail::half_impl::VecElemT;
304304
#ifdef __SYCL_DEVICE_ONLY__
305-
using VectorDataType = sycl::detail::half_impl::StorageT;
305+
using VectorDataType = sycl::detail::half_impl::VecElemT;
306306
#endif // __SYCL_DEVICE_ONLY__
307307
};
308308

@@ -365,10 +365,12 @@ template <typename Type, int NumElements> class vec {
365365
// in the class, so vec<float, 16> should be equal to float16 in memory.
366366
using DataType = typename detail::VecStorage<DataT, NumElements>::DataType;
367367

368+
#ifdef __SYCL_DEVICE_ONLY__
369+
static constexpr bool IsHostHalf = false;
370+
#else
368371
static constexpr bool IsHostHalf =
369-
std::is_same_v<DataT, sycl::detail::half_impl::half> &&
370-
std::is_same_v<sycl::detail::half_impl::StorageT,
371-
sycl::detail::host_half_impl::half>;
372+
std::is_same_v<DataT, sycl::detail::half_impl::half>;
373+
#endif
372374

373375
static constexpr bool IsBfloat16 =
374376
std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;

0 commit comments

Comments
 (0)