Skip to content

[SYCL] Raise bit_cast to SYCL namespace #3524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sycl/include/CL/sycl/ONEAPI/atomic_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ struct bit_equal<T, typename detail::enable_if_t<std::is_integral<T>::value>> {

template <> struct bit_equal<float> {
bool operator()(const float &lhs, const float &rhs) {
auto LhsInt = detail::bit_cast<uint32_t>(lhs);
auto RhsInt = detail::bit_cast<uint32_t>(rhs);
auto LhsInt = bit_cast<uint32_t>(lhs);
auto RhsInt = bit_cast<uint32_t>(rhs);
return LhsInt == RhsInt;
}
};

template <> struct bit_equal<double> {
bool operator()(const double &lhs, const double &rhs) {
auto LhsInt = detail::bit_cast<uint64_t>(lhs);
auto RhsInt = detail::bit_cast<uint64_t>(rhs);
auto LhsInt = bit_cast<uint64_t>(lhs);
auto RhsInt = bit_cast<uint64_t>(rhs);
return LhsInt == RhsInt;
}
};
Expand Down
8 changes: 4 additions & 4 deletions sycl/include/CL/sycl/ONEAPI/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ T load(const multi_ptr<T, Space> src) {
BlockT Ret =
__spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast<PtrT>(src.get()));

return sycl::detail::bit_cast<T>(Ret);
return sycl::bit_cast<T>(Ret);
}

template <int N, typename T, access::address_space Space>
Expand All @@ -68,7 +68,7 @@ vec<T, N> load(const multi_ptr<T, Space> src) {
VecT Ret =
__spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast<PtrT>(src.get()));

return sycl::detail::bit_cast<typename vec<T, N>::vector_t>(Ret);
return sycl::bit_cast<typename vec<T, N>::vector_t>(Ret);
}

template <typename T, access::address_space Space>
Expand All @@ -77,7 +77,7 @@ void store(multi_ptr<T, Space> dst, const T &x) {
using PtrT = sycl::detail::ConvertToOpenCLType_t<multi_ptr<BlockT, Space>>;

__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
sycl::detail::bit_cast<BlockT>(x));
sycl::bit_cast<BlockT>(x));
}

template <int N, typename T, access::address_space Space>
Expand All @@ -88,7 +88,7 @@ void store(multi_ptr<T, Space> dst, const vec<T, N> &x) {
sycl::detail::ConvertToOpenCLType_t<const multi_ptr<BlockT, Space>>;

__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
sycl::detail::bit_cast<VecT>(x));
sycl::bit_cast<VecT>(x));
}
#endif // __SYCL_DEVICE_ONLY__

Expand Down
2 changes: 1 addition & 1 deletion sycl/include/CL/sycl/atomic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class atomic {
Ptr);
cl_int TmpVal = __spirv_AtomicLoad(
TmpPtr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order));
cl_float ResVal = detail::bit_cast<cl_float>(TmpVal);
cl_float ResVal = bit_cast<cl_float>(TmpVal);
return ResVal;
}
#else
Expand Down
56 changes: 56 additions & 0 deletions sycl/include/CL/sycl/bit_cast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//==---------------- bit_cast.hpp - SYCL bit_cast --------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <type_traits>

#if __cpp_lib_bit_cast
#include <bit>
#endif

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {

// Forward declaration of the SYCL-internal memcpy helper. bit_cast below uses
// it only in its fallback path, i.e. when neither std::bit_cast nor the
// __builtin_bit_cast builtin is available.
namespace detail {
inline void memcpy(void *Dst, const void *Src, std::size_t Size);
} // namespace detail

/// Reinterprets the object representation of \p from as a value of type \c To.
///
/// This is sycl::bit_cast (it is no longer sycl::detail::bit_cast). \c To and
/// \c From must have the same size and both must be trivially copyable; these
/// requirements are checked at compile time.
///
/// The function is declared constexpr only when std::bit_cast or the
/// __builtin_bit_cast builtin is available. Otherwise the bytes are copied
/// with sycl::detail::memcpy, which additionally requires \c To to be
/// trivially default constructible.
template <typename To, typename From>
#if __cpp_lib_bit_cast || __has_builtin(__builtin_bit_cast)
constexpr
#endif
To bit_cast(const From &from) noexcept {
  static_assert(sizeof(To) == sizeof(From),
                "Sizes of To and From must be equal");
  static_assert(std::is_trivially_copyable<From>::value,
                "From must be trivially copyable");
  static_assert(std::is_trivially_copyable<To>::value,
                "To must be trivially copyable");

#if __cpp_lib_bit_cast
  // Preferred: the standard library implementation.
  return std::bit_cast<To>(from);
#elif __has_builtin(__builtin_bit_cast)
  // Next best: the compiler builtin.
  return __builtin_bit_cast(To, from);
#else
  // Last resort: copy the bytes into a default-initialized To.
  static_assert(std::is_trivially_default_constructible<To>::value,
                "To must be trivially default constructible");
  To Result;
  sycl::detail::memcpy(&Result, &from, sizeof(To));
  return Result;
#endif
}

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
32 changes: 1 addition & 31 deletions sycl/include/CL/sycl/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
#include <CL/sycl/detail/pi.hpp>
#include <CL/sycl/detail/type_traits.hpp>

#if __cpp_lib_bit_cast
#include <bit>
#endif
#include <memory>
#include <stdexcept>
#include <type_traits>
Expand All @@ -45,34 +42,6 @@ inline void memcpy(void *Dst, const void *Src, size_t Size) {
}
}

template <typename To, typename From>
#if __cpp_lib_bit_cast || __has_builtin(__builtin_bit_cast)
constexpr
#endif
To
bit_cast(const From &from) noexcept {
static_assert(sizeof(To) == sizeof(From),
"Sizes of To and From must be equal");
static_assert(std::is_trivially_copyable<From>::value,
"From must be trivially copyable");
static_assert(std::is_trivially_copyable<To>::value,
"To must be trivially copyable");
#if __cpp_lib_bit_cast
return std::bit_cast<To>(from);
#else // __cpp_lib_bit_cast

#if __has_builtin(__builtin_bit_cast)
return __builtin_bit_cast(To, from);
#else // __has_builtin(__builtin_bit_cast)
static_assert(std::is_trivially_default_constructible<To>::value,
"To must be trivially default constructible");
To to;
sycl::detail::memcpy(&to, &from, sizeof(To));
return to;
#endif // __has_builtin(__builtin_bit_cast)

#endif // __cpp_lib_bit_cast
}

class context_impl;
// The function returns list of events that can be passed to OpenCL API as
Expand Down Expand Up @@ -272,5 +241,6 @@ getSPIRVMemorySemanticsMask(const access::fence_space AccessSpace,
}

} // namespace detail

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
38 changes: 19 additions & 19 deletions sycl/include/CL/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
auto BroadcastX = detail::bit_cast<BroadcastT>(x);
auto BroadcastX = bit_cast<BroadcastT>(x);
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
BroadcastT Result =
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
return detail::bit_cast<T>(Result);
return bit_cast<T>(Result);
}
template <typename Group, typename T, typename IdT>
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
Expand Down Expand Up @@ -190,11 +190,11 @@ EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
for (int i = 0; i < Dimensions; ++i) {
VecId[i] = local_id[Dimensions - i - 1];
}
auto BroadcastX = detail::bit_cast<BroadcastT>(x);
auto BroadcastX = bit_cast<BroadcastT>(x);
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
BroadcastT Result =
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
return detail::bit_cast<T>(Result);
return bit_cast<T>(Result);
}
template <typename Group, typename T, int Dimensions>
EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
Expand Down Expand Up @@ -284,11 +284,11 @@ AtomicCompareExchange(multi_ptr<T, AddressSpace> MPtr,
auto *PtrInt =
reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
MPtr.get());
I DesiredInt = detail::bit_cast<I>(Desired);
I ExpectedInt = detail::bit_cast<I>(Expected);
I DesiredInt = bit_cast<I>(Desired);
I ExpectedInt = bit_cast<I>(Expected);
I ResultInt = __spirv_AtomicCompareExchange(
PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
return detail::bit_cast<T>(ResultInt);
return bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace>
Expand All @@ -312,7 +312,7 @@ AtomicLoad(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
auto SPIRVOrder = getMemorySemanticsMask(Order);
auto SPIRVScope = getScope(Scope);
I ResultInt = __spirv_AtomicLoad(PtrInt, SPIRVScope, SPIRVOrder);
return detail::bit_cast<T>(ResultInt);
return bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace>
Expand All @@ -335,7 +335,7 @@ AtomicStore(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
MPtr.get());
auto SPIRVOrder = getMemorySemanticsMask(Order);
auto SPIRVScope = getScope(Scope);
I ValueInt = detail::bit_cast<I>(Value);
I ValueInt = bit_cast<I>(Value);
__spirv_AtomicStore(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
}

Expand All @@ -359,10 +359,10 @@ AtomicExchange(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
MPtr.get());
auto SPIRVOrder = getMemorySemanticsMask(Order);
auto SPIRVScope = getScope(Scope);
I ValueInt = detail::bit_cast<I>(Value);
I ValueInt = bit_cast<I>(Value);
I ResultInt =
__spirv_AtomicExchange(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
return detail::bit_cast<T>(ResultInt);
return bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace>
Expand Down Expand Up @@ -600,57 +600,57 @@ using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleINTEL(
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
ShuffleT Result =
__nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
return detail::bit_cast<T>(Result);
return bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
ShuffleT Result =
__nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
return detail::bit_cast<T>(Result);
return bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
ShuffleT Result =
__nvvm_shfl_sync_down_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
return detail::bit_cast<T>(Result);
return bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
using ShuffleT = ConvertToNativeShuffleType_t<T>;
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
ShuffleT Result =
__nvvm_shfl_sync_up_i32(membermask(), ShuffleX, local_id.get(0), 0);
#endif
return detail::bit_cast<T>(Result);
return bit_cast<T>(Result);
}

// Generic shuffles may require multiple calls to SubgroupShuffle
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/stl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <CL/sycl/detail/defines.hpp>

#include "bit_cast.hpp"
#include <exception>
#include <functional>
#include <memory>
Expand Down
3 changes: 1 addition & 2 deletions sycl/test/bit_cast/bit_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ To doBitCast(const From &ValueToConvert) {
Queue.submit([&](sycl::handler &cgh) {
auto acc = Buf.template get_access<sycl_write>(cgh);
cgh.single_task<class BitCastKernel<To, From>>([=]() {
// TODO: change to sycl::bit_cast in the future
acc[0] = sycl::detail::bit_cast<To>(ValueToConvert);
acc[0] = sycl::bit_cast<To>(ValueToConvert);
});
});
}
Expand Down