Skip to content

Commit ee479d8

Browse files
committed
Fix bugs exposed in CI
- implement fallback for CUDA BE; - fix assert function name on Windows.
1 parent 8923e56 commit ee479d8

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

sycl/include/CL/sycl/ONEAPI/sub_group.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ struct sub_group {
241241
std::is_same<typename detail::remove_AS<T>::type, T>::value, T>
242242
load(T *src) const {
243243

244+
#ifdef __NVPTX__
245+
return src[get_local_id()[0]];
246+
#else // __NVPTX__
244247
auto l = __spirv_GenericCastToPtrExplicit_ToLocal<T>(
245248
src, __spv::StorageClass::Workgroup);
246249
if (l)
@@ -258,6 +261,7 @@ struct sub_group {
258261

259262
// Fallback for other address spaces to be mapped to global
260263
return load(__spirv_PtrCastToGeneric<T>(src));
264+
#endif // __NVPTX__
261265
}
262266
#else //__SYCL_DEVICE_ONLY__
263267
template <typename T> T load(T *src) const {
@@ -375,6 +379,9 @@ struct sub_group {
375379
std::is_same<typename detail::remove_AS<T>::type, T>::value>
376380
store(T *dst, const typename detail::remove_AS<T>::type &x) const {
377381

382+
#ifdef __NVPTX__
383+
dst[get_local_id()[0]] = x;
384+
#else // __NVPTX__
378385
auto l = __spirv_GenericCastToPtrExplicit_ToLocal<T>(
379386
dst, __spv::StorageClass::Workgroup);
380387
if (l) {
@@ -396,6 +403,7 @@ struct sub_group {
396403

397404
// Fallback for other address spaces to be mapped to global
398405
store(__spirv_PtrCastToGeneric<T>(dst), x);
406+
#endif // __NVPTX__
399407
}
400408
#else //__SYCL_DEVICE_ONLY__
401409
template <typename T> void store(T *dst, const T &x) const {

sycl/test/extensions/sub_group_as.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ int main(int argc, char *argv[]) {
6161
// CHECK: call spir_func i8 addrspace(1)* @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvN5__spv12StorageClass4FlagE(i8 addrspace(4)*
6262
// CHECK: call spir_func i32 @_Z30__spirv_SubgroupBlockReadINTELIjET_PU3AS1Kj(i32 addrspace(1)*
6363
// CHECK: call spir_func i8* @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePKvN5__spv12StorageClass4FlagE(i8 addrspace(4)*
64-
// CHECK: call spir_func void @__assert_fail
64+
// CHECK: call spir_func void {{.*}}assert
6565
// CHECK: call spir_func i8 addrspace(4)* @_Z24__spirv_PtrCastToGenericPKv(i8 addrspace(4)*
6666
// CHECK: call spir_func i32 @_Z30__spirv_SubgroupBlockReadINTELIjET_PU3AS1Kj(i32 addrspace(1)*
6767
// Global address space
@@ -78,7 +78,7 @@ int main(int argc, char *argv[]) {
7878
// CHECK: call spir_func i8 addrspace(1)* @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvN5__spv12StorageClass4FlagE(i8 addrspace(4)*
7979
// CHECK: call spir_func void @_Z31__spirv_SubgroupBlockWriteINTELIjEvPU3AS1jT_(i32 addrspace(1)*
8080
// CHECK: call spir_func i8* @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePKvN5__spv12StorageClass4FlagE(i8 addrspace(4)*
81-
// CHECK: call spir_func void @__assert_fail
81+
// CHECK: call spir_func void {{.*}}assert
8282
// CHECK: call spir_func i8 addrspace(4)* @_Z24__spirv_PtrCastToGenericPKv(i8 addrspace(4)*
8383
// CHECK: call spir_func void @_Z31__spirv_SubgroupBlockWriteINTELIjEvPU3AS1jT_(i32 addrspace(1)*
8484
sg.store(&global[i], x + y + z);

0 commit comments

Comments
 (0)