Skip to content

Commit d4b66bd

Browse files
[SYCL] Raise bit_cast to SYCL namespace (#3524)
SYCL2020 has bit_cast provided in SYCL namespace now. Moved it and updated tests and internal references Signed-off-by: Chris Perkins <[email protected]>
1 parent 6df94f2 commit d4b66bd

File tree

8 files changed

+98
-61
lines changed

8 files changed

+98
-61
lines changed

sycl/include/CL/sycl/ONEAPI/atomic_ref.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,16 @@ struct bit_equal<T, typename detail::enable_if_t<std::is_integral<T>::value>> {
9494

9595
template <> struct bit_equal<float> {
9696
bool operator()(const float &lhs, const float &rhs) {
97-
auto LhsInt = detail::bit_cast<uint32_t>(lhs);
98-
auto RhsInt = detail::bit_cast<uint32_t>(rhs);
97+
auto LhsInt = sycl::bit_cast<uint32_t>(lhs);
98+
auto RhsInt = sycl::bit_cast<uint32_t>(rhs);
9999
return LhsInt == RhsInt;
100100
}
101101
};
102102

103103
template <> struct bit_equal<double> {
104104
bool operator()(const double &lhs, const double &rhs) {
105-
auto LhsInt = detail::bit_cast<uint64_t>(lhs);
106-
auto RhsInt = detail::bit_cast<uint64_t>(rhs);
105+
auto LhsInt = sycl::bit_cast<uint64_t>(lhs);
106+
auto RhsInt = sycl::bit_cast<uint64_t>(rhs);
107107
return LhsInt == RhsInt;
108108
}
109109
};

sycl/include/CL/sycl/ONEAPI/sub_group.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ T load(const multi_ptr<T, Space> src) {
5555
BlockT Ret =
5656
__spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast<PtrT>(src.get()));
5757

58-
return sycl::detail::bit_cast<T>(Ret);
58+
return sycl::bit_cast<T>(Ret);
5959
}
6060

6161
template <int N, typename T, access::address_space Space>
@@ -68,7 +68,7 @@ vec<T, N> load(const multi_ptr<T, Space> src) {
6868
VecT Ret =
6969
__spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast<PtrT>(src.get()));
7070

71-
return sycl::detail::bit_cast<typename vec<T, N>::vector_t>(Ret);
71+
return sycl::bit_cast<typename vec<T, N>::vector_t>(Ret);
7272
}
7373

7474
template <typename T, access::address_space Space>
@@ -77,7 +77,7 @@ void store(multi_ptr<T, Space> dst, const T &x) {
7777
using PtrT = sycl::detail::ConvertToOpenCLType_t<multi_ptr<BlockT, Space>>;
7878

7979
__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
80-
sycl::detail::bit_cast<BlockT>(x));
80+
sycl::bit_cast<BlockT>(x));
8181
}
8282

8383
template <int N, typename T, access::address_space Space>
@@ -88,7 +88,7 @@ void store(multi_ptr<T, Space> dst, const vec<T, N> &x) {
8888
sycl::detail::ConvertToOpenCLType_t<const multi_ptr<BlockT, Space>>;
8989

9090
__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
91-
sycl::detail::bit_cast<VecT>(x));
91+
sycl::bit_cast<VecT>(x));
9292
}
9393
#endif // __SYCL_DEVICE_ONLY__
9494

sycl/include/CL/sycl/atomic.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ class atomic {
237237
Ptr);
238238
cl_int TmpVal = __spirv_AtomicLoad(
239239
TmpPtr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order));
240-
cl_float ResVal = detail::bit_cast<cl_float>(TmpVal);
240+
cl_float ResVal = bit_cast<cl_float>(TmpVal);
241241
return ResVal;
242242
}
243243
#else

sycl/include/CL/sycl/bit_cast.hpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//==---------------- bit_cast.hpp - SYCL bit_cast --------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <type_traits>
12+
13+
#if __cpp_lib_bit_cast
14+
#include <bit>
15+
#endif
16+
17+
__SYCL_INLINE_NAMESPACE(cl) {
18+
namespace sycl {
19+
20+
// forward decl
21+
namespace detail {
22+
inline void memcpy(void *Dst, const void *Src, std::size_t Size);
23+
}
24+
25+
// sycl::bit_cast ( no longer sycl::detail::bit_cast )
26+
template <typename To, typename From>
27+
#if __cpp_lib_bit_cast || __has_builtin(__builtin_bit_cast)
28+
constexpr
29+
#endif
30+
To
31+
bit_cast(const From &from) noexcept {
32+
static_assert(sizeof(To) == sizeof(From),
33+
"Sizes of To and From must be equal");
34+
static_assert(std::is_trivially_copyable<From>::value,
35+
"From must be trivially copyable");
36+
static_assert(std::is_trivially_copyable<To>::value,
37+
"To must be trivially copyable");
38+
#if __cpp_lib_bit_cast
39+
return std::bit_cast<To>(from);
40+
#else // __cpp_lib_bit_cast
41+
42+
#if __has_builtin(__builtin_bit_cast)
43+
return __builtin_bit_cast(To, from);
44+
#else // __has_builtin(__builtin_bit_cast)
45+
static_assert(std::is_trivially_default_constructible<To>::value,
46+
"To must be trivially default constructible");
47+
To to;
48+
sycl::detail::memcpy(&to, &from, sizeof(To));
49+
return to;
50+
#endif // __has_builtin(__builtin_bit_cast)
51+
52+
#endif // __cpp_lib_bit_cast
53+
}
54+
55+
namespace detail {
56+
template <typename To, typename From>
57+
#if __cpp_lib_bit_cast || __has_builtin(__builtin_bit_cast)
58+
constexpr
59+
#endif
60+
To
61+
bit_cast(const From &from) noexcept {
62+
return sycl::bit_cast<To>(from);
63+
}
64+
} // namespace detail
65+
66+
} // namespace sycl
67+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/detail/helpers.hpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
#include <CL/sycl/detail/pi.hpp>
1717
#include <CL/sycl/detail/type_traits.hpp>
1818

19-
#if __cpp_lib_bit_cast
20-
#include <bit>
21-
#endif
2219
#include <memory>
2320
#include <stdexcept>
2421
#include <type_traits>
@@ -45,34 +42,6 @@ inline void memcpy(void *Dst, const void *Src, size_t Size) {
4542
}
4643
}
4744

48-
template <typename To, typename From>
49-
#if __cpp_lib_bit_cast || __has_builtin(__builtin_bit_cast)
50-
constexpr
51-
#endif
52-
To
53-
bit_cast(const From &from) noexcept {
54-
static_assert(sizeof(To) == sizeof(From),
55-
"Sizes of To and From must be equal");
56-
static_assert(std::is_trivially_copyable<From>::value,
57-
"From must be trivially copyable");
58-
static_assert(std::is_trivially_copyable<To>::value,
59-
"To must be trivially copyable");
60-
#if __cpp_lib_bit_cast
61-
return std::bit_cast<To>(from);
62-
#else // __cpp_lib_bit_cast
63-
64-
#if __has_builtin(__builtin_bit_cast)
65-
return __builtin_bit_cast(To, from);
66-
#else // __has_builtin(__builtin_bit_cast)
67-
static_assert(std::is_trivially_default_constructible<To>::value,
68-
"To must be trivially default constructible");
69-
To to;
70-
sycl::detail::memcpy(&to, &from, sizeof(To));
71-
return to;
72-
#endif // __has_builtin(__builtin_bit_cast)
73-
74-
#endif // __cpp_lib_bit_cast
75-
}
7645

7746
class context_impl;
7847
// The function returns list of events that can be passed to OpenCL API as
@@ -272,5 +241,6 @@ getSPIRVMemorySemanticsMask(const access::fence_space AccessSpace,
272241
}
273242

274243
} // namespace detail
244+
275245
} // namespace sycl
276246
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,11 @@ EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
140140
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
141141
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
142142
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
143-
auto BroadcastX = detail::bit_cast<BroadcastT>(x);
143+
auto BroadcastX = bit_cast<BroadcastT>(x);
144144
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
145145
BroadcastT Result =
146146
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
147-
return detail::bit_cast<T>(Result);
147+
return bit_cast<T>(Result);
148148
}
149149
template <typename Group, typename T, typename IdT>
150150
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
@@ -190,11 +190,11 @@ EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
190190
for (int i = 0; i < Dimensions; ++i) {
191191
VecId[i] = local_id[Dimensions - i - 1];
192192
}
193-
auto BroadcastX = detail::bit_cast<BroadcastT>(x);
193+
auto BroadcastX = bit_cast<BroadcastT>(x);
194194
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
195195
BroadcastT Result =
196196
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
197-
return detail::bit_cast<T>(Result);
197+
return bit_cast<T>(Result);
198198
}
199199
template <typename Group, typename T, int Dimensions>
200200
EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
@@ -284,11 +284,11 @@ AtomicCompareExchange(multi_ptr<T, AddressSpace> MPtr,
284284
auto *PtrInt =
285285
reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
286286
MPtr.get());
287-
I DesiredInt = detail::bit_cast<I>(Desired);
288-
I ExpectedInt = detail::bit_cast<I>(Expected);
287+
I DesiredInt = bit_cast<I>(Desired);
288+
I ExpectedInt = bit_cast<I>(Expected);
289289
I ResultInt = __spirv_AtomicCompareExchange(
290290
PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
291-
return detail::bit_cast<T>(ResultInt);
291+
return bit_cast<T>(ResultInt);
292292
}
293293

294294
template <typename T, access::address_space AddressSpace>
@@ -312,7 +312,7 @@ AtomicLoad(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
312312
auto SPIRVOrder = getMemorySemanticsMask(Order);
313313
auto SPIRVScope = getScope(Scope);
314314
I ResultInt = __spirv_AtomicLoad(PtrInt, SPIRVScope, SPIRVOrder);
315-
return detail::bit_cast<T>(ResultInt);
315+
return bit_cast<T>(ResultInt);
316316
}
317317

318318
template <typename T, access::address_space AddressSpace>
@@ -335,7 +335,7 @@ AtomicStore(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
335335
MPtr.get());
336336
auto SPIRVOrder = getMemorySemanticsMask(Order);
337337
auto SPIRVScope = getScope(Scope);
338-
I ValueInt = detail::bit_cast<I>(Value);
338+
I ValueInt = bit_cast<I>(Value);
339339
__spirv_AtomicStore(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
340340
}
341341

@@ -359,10 +359,10 @@ AtomicExchange(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
359359
MPtr.get());
360360
auto SPIRVOrder = getMemorySemanticsMask(Order);
361361
auto SPIRVScope = getScope(Scope);
362-
I ValueInt = detail::bit_cast<I>(Value);
362+
I ValueInt = bit_cast<I>(Value);
363363
I ResultInt =
364364
__spirv_AtomicExchange(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
365-
return detail::bit_cast<T>(ResultInt);
365+
return bit_cast<T>(ResultInt);
366366
}
367367

368368
template <typename T, access::address_space AddressSpace>
@@ -600,57 +600,57 @@ using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
600600
template <typename T>
601601
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
602602
using ShuffleT = ConvertToNativeShuffleType_t<T>;
603-
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
603+
auto ShuffleX = bit_cast<ShuffleT>(x);
604604
#ifndef __NVPTX__
605605
ShuffleT Result = __spirv_SubgroupShuffleINTEL(
606606
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
607607
#else
608608
ShuffleT Result =
609609
__nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
610610
#endif
611-
return detail::bit_cast<T>(Result);
611+
return bit_cast<T>(Result);
612612
}
613613

614614
template <typename T>
615615
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
616616
using ShuffleT = ConvertToNativeShuffleType_t<T>;
617-
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
617+
auto ShuffleX = bit_cast<ShuffleT>(x);
618618
#ifndef __NVPTX__
619619
ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
620620
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
621621
#else
622622
ShuffleT Result =
623623
__nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
624624
#endif
625-
return detail::bit_cast<T>(Result);
625+
return bit_cast<T>(Result);
626626
}
627627

628628
template <typename T>
629629
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
630630
using ShuffleT = ConvertToNativeShuffleType_t<T>;
631-
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
631+
auto ShuffleX = bit_cast<ShuffleT>(x);
632632
#ifndef __NVPTX__
633633
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(
634634
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
635635
#else
636636
ShuffleT Result =
637637
__nvvm_shfl_sync_down_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
638638
#endif
639-
return detail::bit_cast<T>(Result);
639+
return bit_cast<T>(Result);
640640
}
641641

642642
template <typename T>
643643
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
644644
using ShuffleT = ConvertToNativeShuffleType_t<T>;
645-
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
645+
auto ShuffleX = bit_cast<ShuffleT>(x);
646646
#ifndef __NVPTX__
647647
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(
648648
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
649649
#else
650650
ShuffleT Result =
651651
__nvvm_shfl_sync_up_i32(membermask(), ShuffleX, local_id.get(0), 0);
652652
#endif
653-
return detail::bit_cast<T>(Result);
653+
return bit_cast<T>(Result);
654654
}
655655

656656
// Generic shuffles may require multiple calls to SubgroupShuffle

sycl/include/CL/sycl/stl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
// 4.5 C++ Standard library classes required for the interface
1212

13+
#include <CL/sycl/bit_cast.hpp>
1314
#include <CL/sycl/detail/defines.hpp>
1415

1516
#include <exception>

sycl/test/bit_cast/bit_cast.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ To doBitCast(const From &ValueToConvert) {
2121
Queue.submit([&](sycl::handler &cgh) {
2222
auto acc = Buf.template get_access<sycl_write>(cgh);
2323
cgh.single_task<class BitCastKernel<To, From>>([=]() {
24-
// TODO: change to sycl::bit_cast in the future
25-
acc[0] = sycl::detail::bit_cast<To>(ValueToConvert);
24+
acc[0] = sycl::bit_cast<To>(ValueToConvert);
2625
});
2726
});
2827
}

0 commit comments

Comments
 (0)