Skip to content

[NFCI][SYCL] Support multi_ptr in convertToOpenCLType #12693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,14 @@ template <> struct ConvertToOpenCLTypeImpl<Boolean<1>> {
// Or should it be "int"?
using type = Boolean<1>;
};
// The specialization below is compiled only when _HAS_STD_BYTE is either
// undefined or defined to a non-zero value.
// NOTE(review): presumably _HAS_STD_BYTE is the MSVC STL switch that removes
// std::byte when set to 0 - confirm.
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
// Maps std::byte to uint8_t for the purposes of OpenCL type conversion.
// TODO: It seems we only use this to convert a pointer's element type. As such,
// although it doesn't look very clean, it should be ok having this case handled
// explicitly until further refactoring of this area.
template <> struct ConvertToOpenCLTypeImpl<std::byte> {
using type = uint8_t;
};
#endif
#endif

template <typename T> struct ConvertToOpenCLTypeImpl<T *> {
Expand Down Expand Up @@ -700,8 +708,30 @@ convertDataToType(FROM t) {
// Now fuse the above into a simpler helper that's easy to use.
// TODO: That should probably be moved outside of "type_traits".
//
// Converts `x` to its OpenCL-level representation. The conversion is selected
// at compile time from the type of `x` with any reference stripped:
//   * multi_ptr   - unwrapped with get_decorated() and converted recursively,
//                   which lands in the raw-pointer case below;
//   * raw pointer - remove_decoration_t is applied to the pointee type, the
//                   result is converted with ConvertToOpenCLType_t (a `const`
//                   on the pointee is preserved), and the pointer is
//                   reinterpret_cast to the resulting pointer type. In device
//                   compilation the pointee is re-decorated through
//                   DecoratedType using the address space reported by
//                   deduce_AS for the original pointer type;
//   * otherwise   - converted through convertDataToType, forwarding `x`.
//
// The dispatch must be the only thing in the body: an unconditional `return`
// ahead of the `if constexpr` chain would make every branch dead and would
// give the deduced `auto` return type conflicting candidates.
template <typename T> auto convertToOpenCLType(T &&x) {
  using no_ref = std::remove_reference_t<T>;
  if constexpr (is_multi_ptr_v<no_ref>) {
    return convertToOpenCLType(x.get_decorated());
  } else if constexpr (std::is_pointer_v<no_ref>) {
    // TODO: Below ignores volatile, but we didn't have a need for it yet.
    using elem_type = remove_decoration_t<std::remove_pointer_t<no_ref>>;
    // Convert the unqualified element type, then put `const` back if the
    // original element type had it.
    using converted_elem_type_no_cv =
        ConvertToOpenCLType_t<std::remove_const_t<elem_type>>;
    using converted_elem_type =
        std::conditional_t<std::is_const_v<elem_type>,
                           const converted_elem_type_no_cv,
                           converted_elem_type_no_cv>;
#ifdef __SYCL_DEVICE_ONLY__
    using result_type =
        typename DecoratedType<converted_elem_type,
                               deduce_AS<no_ref>::value>::type *;
#else
    using result_type = converted_elem_type *;
#endif
    return reinterpret_cast<result_type>(x);
  } else {
    using OpenCLType = ConvertToOpenCLType_t<no_ref>;
    return convertDataToType<T, OpenCLType>(std::forward<T>(x));
  }
}

template <typename To, typename From> auto convertFromOpenCLTypeFor(From &&x) {
Expand Down
28 changes: 8 additions & 20 deletions sycl/include/sycl/group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
global_ptr<dataT> src,
size_t numElements,
size_t srcStride) const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, srcStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, srcStride, 0);
return device_event(E);
}

Expand All @@ -337,12 +334,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
size_t numElements,
size_t destStride)
const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, destStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, destStride, 0);
return device_event(E);
}

Expand All @@ -359,12 +353,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
async_work_group_copy(decorated_local_ptr<DestDataT> dest,
decorated_global_ptr<SrcDataT> src, size_t numElements,
size_t srcStride) const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, srcStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, srcStride, 0);
return device_event(E);
}

Expand All @@ -381,12 +372,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
async_work_group_copy(decorated_global_ptr<DestDataT> dest,
decorated_local_ptr<SrcDataT> src, size_t numElements,
size_t destStride) const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, destStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, destStride, 0);
return device_event(E);
}

Expand Down
38 changes: 22 additions & 16 deletions sycl/include/sycl/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ namespace sub_group {
// Element type used for block pointers and block reads/writes in this
// namespace; forwards directly to select_cl_scalar_integral_unsigned_t<T>.
// NOTE(review): presumably this yields an unsigned integer type matching T's
// size, or void when there is none (AcceptableForGlobalLoadStore compares it
// against void) - confirm against select_cl_scalar_integral_unsigned_t.
template <typename T>
using SelectBlockT = select_cl_scalar_integral_unsigned_t<T>;

template <typename MultiPtrTy> auto convertToBlockPtr(MultiPtrTy MultiPtr) {
static_assert(is_multi_ptr_v<MultiPtrTy>);
auto DecoratedPtr = convertToOpenCLType(MultiPtr);
using DecoratedPtrTy = decltype(DecoratedPtr);
using ElemTy = remove_decoration_t<std::remove_pointer_t<DecoratedPtrTy>>;

using TargetElemTy = SelectBlockT<ElemTy>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ElemTy has a cv-qualifier, it looks like SelectBlockT is going to drop that qualifier.
Maybe my understanding is wrong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but I don't think it's a regression caused by this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.
NIT: In that case it would be nice to have a TODO, maybe; otherwise I guess it will be forgotten.

// TODO: Handle cv qualifiers.
#ifdef __SYCL_DEVICE_ONLY__
using ResultTy =
typename DecoratedType<TargetElemTy,
deduce_AS<DecoratedPtrTy>::value>::type *;
#else
using ResultTy = TargetElemTy *;
#endif
return reinterpret_cast<ResultTy>(DecoratedPtr);
}

template <typename T, access::address_space Space>
using AcceptableForGlobalLoadStore =
std::bool_constant<!std::is_same_v<void, SelectBlockT<T>> &&
Expand All @@ -57,11 +75,7 @@ template <typename T, access::address_space Space,
access::decorated DecorateAddress>
T load(const multi_ptr<T, Space, DecorateAddress> src) {
using BlockT = SelectBlockT<T>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
const multi_ptr<BlockT, Space, DecorateAddress>>;

BlockT Ret =
__spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast<PtrT>(src.get()));
BlockT Ret = __spirv_SubgroupBlockReadINTEL<BlockT>(convertToBlockPtr(src));

return sycl::bit_cast<T>(Ret);
}
Expand All @@ -71,11 +85,7 @@ template <int N, typename T, access::address_space Space,
vec<T, N> load(const multi_ptr<T, Space, DecorateAddress> src) {
using BlockT = SelectBlockT<T>;
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
const multi_ptr<BlockT, Space, DecorateAddress>>;

VecT Ret =
__spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast<PtrT>(src.get()));
VecT Ret = __spirv_SubgroupBlockReadINTEL<VecT>(convertToBlockPtr(src));

return sycl::bit_cast<typename vec<T, N>::vector_t>(Ret);
}
Expand All @@ -84,10 +94,8 @@ template <typename T, access::address_space Space,
access::decorated DecorateAddress>
void store(multi_ptr<T, Space, DecorateAddress> dst, const T &x) {
using BlockT = SelectBlockT<T>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
multi_ptr<BlockT, Space, DecorateAddress>>;

__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
__spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst),
sycl::bit_cast<BlockT>(x));
}

Expand All @@ -96,10 +104,8 @@ template <int N, typename T, access::address_space Space,
void store(multi_ptr<T, Space, DecorateAddress> dst, const vec<T, N> &x) {
using BlockT = SelectBlockT<T>;
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
const multi_ptr<BlockT, Space, DecorateAddress>>;

__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
__spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst),
sycl::bit_cast<VecT>(x));
}
#endif // __SYCL_DEVICE_ONLY__
Expand Down