diff --git a/sycl/include/sycl/half_type.hpp b/sycl/include/sycl/half_type.hpp
index f3a04684c5f58..951146f2cdfbb 100644
--- a/sycl/include/sycl/half_type.hpp
+++ b/sycl/include/sycl/half_type.hpp
@@ -228,9 +228,9 @@ class half;
 // Several aliases are defined below:
 // - StorageT: actual representation of half data type. It is used by scalar
-// half values and by 'sycl::vec' class. On device side, it points to some
-// native half data type, while on host some custom data type is used to
-// emulate operations of 16-bit floating-point values
+// half values. On device side, it points to some native half data type, while
+// on host some custom data type is used to emulate operations of 16-bit
+// floating-point values
 //
 // - BIsRepresentationT: data type which is used by built-in functions. It is
 // distinguished from StorageT, because on host, we can still operate on the
@@ -238,32 +238,38 @@ class half;
 // type (too many changes required for BIs implementation without any
 // foreseeable profits)
 //
-// - VecNStorageT - representation of N-element vector of halfs. Follows the
-// same logic as StorageT
+// - VecElemT: representation of each element in the vector. On device it is
+// the same as StorageT to carry a native vector representation, while on
+// host it stores the sycl::half implementation directly.
+//
+// - VecNStorageT: representation of N-element vector of halfs. Follows the
+// same logic as VecElemT.
 #ifdef __SYCL_DEVICE_ONLY__
 using StorageT = _Float16;
 using BIsRepresentationT = _Float16;
+using VecElemT = _Float16;
-using Vec2StorageT = StorageT __attribute__((ext_vector_type(2)));
-using Vec3StorageT = StorageT __attribute__((ext_vector_type(3)));
-using Vec4StorageT = StorageT __attribute__((ext_vector_type(4)));
-using Vec8StorageT = StorageT __attribute__((ext_vector_type(8)));
-using Vec16StorageT = StorageT __attribute__((ext_vector_type(16)));
+using Vec2StorageT = VecElemT __attribute__((ext_vector_type(2)));
+using Vec3StorageT = VecElemT __attribute__((ext_vector_type(3)));
+using Vec4StorageT = VecElemT __attribute__((ext_vector_type(4)));
+using Vec8StorageT = VecElemT __attribute__((ext_vector_type(8)));
+using Vec16StorageT = VecElemT __attribute__((ext_vector_type(16)));
 #else // SYCL_DEVICE_ONLY
 using StorageT = detail::host_half_impl::half;
 // No need to extract underlying data type for built-in functions operating on
 // host
 using BIsRepresentationT = half;
+using VecElemT = half;
 // On the host side we cannot use OpenCL cl_half# types as an underlying type
 // for vec because they are actually defined as an integer type under the
 // hood. As a result half values will be converted to the integer and passed
 // as a kernel argument which is expected to be floating point number.
-using Vec2StorageT = std::array<StorageT, 2>;
-using Vec3StorageT = std::array<StorageT, 3>;
-using Vec4StorageT = std::array<StorageT, 4>;
-using Vec8StorageT = std::array<StorageT, 8>;
-using Vec16StorageT = std::array<StorageT, 16>;
+using Vec2StorageT = std::array<VecElemT, 2>;
+using Vec3StorageT = std::array<VecElemT, 3>;
+using Vec4StorageT = std::array<VecElemT, 4>;
+using Vec8StorageT = std::array<VecElemT, 8>;
+using Vec16StorageT = std::array<VecElemT, 16>;
 #endif // SYCL_DEVICE_ONLY
 #ifndef __SYCL_DEVICE_ONLY__
diff --git a/sycl/include/sycl/vector.hpp b/sycl/include/sycl/vector.hpp
index f1bf7fcfcc24d..7c4c509062b37 100644
--- a/sycl/include/sycl/vector.hpp
+++ b/sycl/include/sycl/vector.hpp
@@ -300,9 +300,9 @@ struct VecStorage<
 // Single element half
 template <> struct VecStorage {
- using DataType = sycl::detail::half_impl::StorageT;
+ using DataType = sycl::detail::half_impl::VecElemT;
 #ifdef __SYCL_DEVICE_ONLY__
- using VectorDataType = sycl::detail::half_impl::StorageT;
+ using VectorDataType = sycl::detail::half_impl::VecElemT;
 #endif // __SYCL_DEVICE_ONLY__
 };
@@ -365,10 +365,12 @@ template class vec {
 // in the class, so vec should be equal to float16 in memory.
 using DataType = typename detail::VecStorage::DataType;
+#ifdef __SYCL_DEVICE_ONLY__
+ static constexpr bool IsHostHalf = false;
+#else
 static constexpr bool IsHostHalf =
- std::is_same_v &&
- std::is_same_v;
+ std::is_same_v;
+#endif
 static constexpr bool IsBfloat16 =
 std::is_same_v;