diff --git a/libclc/generic/include/spirv/spirv.h b/libclc/generic/include/spirv/spirv.h index 7449d03409753..8c6878962ad88 100644 --- a/libclc/generic/include/spirv/spirv.h +++ b/libclc/generic/include/spirv/spirv.h @@ -37,13 +37,18 @@ #include /* 6.11.1 Work-Item Functions */ -#include #include -#include +#include +#include +#include #include +#include +#include #include -#include -#include +#include +#include +#include +#include #include /* 6.11.2.1 Floating-point macros */ diff --git a/libclc/generic/include/spirv/workitem/get_max_sub_group_size.h b/libclc/generic/include/spirv/workitem/get_max_sub_group_size.h new file mode 100644 index 0000000000000..3befd9abae240 --- /dev/null +++ b/libclc/generic/include/spirv/workitem/get_max_sub_group_size.h @@ -0,0 +1,9 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize(); diff --git a/libclc/generic/include/spirv/workitem/get_num_sub_groups.h b/libclc/generic/include/spirv/workitem/get_num_sub_groups.h new file mode 100644 index 0000000000000..c6341dd6c63f4 --- /dev/null +++ b/libclc/generic/include/spirv/workitem/get_num_sub_groups.h @@ -0,0 +1,9 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups(); diff --git a/libclc/generic/include/spirv/workitem/get_sub_group_id.h b/libclc/generic/include/spirv/workitem/get_sub_group_id.h new file mode 100644 index 0000000000000..47f4e0c5afa7a --- /dev/null +++ b/libclc/generic/include/spirv/workitem/get_sub_group_id.h @@ -0,0 +1,9 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId(); diff --git a/libclc/generic/include/spirv/workitem/get_sub_group_local_id.h b/libclc/generic/include/spirv/workitem/get_sub_group_local_id.h new file mode 100644 index 0000000000000..f69319bcfad5a --- /dev/null +++ b/libclc/generic/include/spirv/workitem/get_sub_group_local_id.h @@ -0,0 +1,9 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId(); diff --git a/libclc/generic/include/spirv/workitem/get_sub_group_size.h b/libclc/generic/include/spirv/workitem/get_sub_group_size.h new file mode 100644 index 0000000000000..59066301ce6a1 --- /dev/null +++ b/libclc/generic/include/spirv/workitem/get_sub_group_size.h @@ -0,0 +1,9 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize(); diff --git a/libclc/ptx-nvidiacl/libspirv/SOURCES b/libclc/ptx-nvidiacl/libspirv/SOURCES index 7c710eafe79f6..16527d745568a 100644 --- a/libclc/ptx-nvidiacl/libspirv/SOURCES +++ b/libclc/ptx-nvidiacl/libspirv/SOURCES @@ -75,6 +75,11 @@ workitem/get_global_size.cl workitem/get_group_id.cl workitem/get_local_id.cl workitem/get_local_size.cl +workitem/get_max_sub_group_size.cl workitem/get_num_groups.cl +workitem/get_num_sub_groups.cl +workitem/get_sub_group_id.cl +workitem/get_sub_group_local_id.cl +workitem/get_sub_group_size.cl images/image_helpers.ll images/image.cl diff --git a/libclc/ptx-nvidiacl/libspirv/workitem/get_max_sub_group_size.cl b/libclc/ptx-nvidiacl/libspirv/workitem/get_max_sub_group_size.cl new file mode 100644 index 0000000000000..960863d04c5de --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/workitem/get_max_sub_group_size.cl @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize() { + return 32; + // FIXME: warpsize is defined by NVVM IR but doesn't compile if used here + // return __nvvm_read_ptx_sreg_warpsize(); +} diff --git a/libclc/ptx-nvidiacl/libspirv/workitem/get_num_sub_groups.cl b/libclc/ptx-nvidiacl/libspirv/workitem/get_num_sub_groups.cl new file mode 100644 index 0000000000000..5cc447c17671e --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/workitem/get_num_sub_groups.cl @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups() { + // sreg.nwarpid returns number of warp identifiers, not number of warps + // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + size_t size_x = __spirv_WorkgroupSize_x(); + size_t size_y = __spirv_WorkgroupSize_y(); + size_t size_z = __spirv_WorkgroupSize_z(); + uint sg_size = __spirv_SubgroupMaxSize(); + uint linear_size = size_z * size_y * size_x; + return (linear_size + sg_size - 1) / sg_size; +} diff --git a/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_id.cl b/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_id.cl new file mode 100644 index 0000000000000..a22200eca59ee --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_id.cl @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId() { + // sreg.warpid is volatile and doesn't represent virtual warp index + // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + size_t id_x = __spirv_LocalInvocationId_x(); + size_t id_y = __spirv_LocalInvocationId_y(); + size_t id_z = __spirv_LocalInvocationId_z(); + size_t size_x = __spirv_WorkgroupSize_x(); + size_t size_y = __spirv_WorkgroupSize_y(); + size_t size_z = __spirv_WorkgroupSize_z(); + uint sg_size = __spirv_SubgroupMaxSize(); + return (id_z * size_y * size_x + id_y * size_x + id_x) / sg_size; +} diff --git a/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_local_id.cl b/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_local_id.cl new file mode 100644 index 0000000000000..3eafe2d071299 --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_local_id.cl @@ -0,0 +1,13 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId() { + return __nvvm_read_ptx_sreg_laneid(); +} diff --git a/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_size.cl b/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_size.cl new file mode 100644 index 0000000000000..99f7c67fc02dc --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_size.cl @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize() { + if (__spirv_SubgroupId() != __spirv_NumSubgroups() - 1) { + return __spirv_SubgroupMaxSize(); + } + size_t size_x = __spirv_WorkgroupSize_x(); + size_t size_y = __spirv_WorkgroupSize_y(); + size_t size_z = __spirv_WorkgroupSize_z(); + uint linear_size = size_z * size_y * size_x; + uint uniform_groups = __spirv_NumSubgroups() - 1; + uint uniform_size = __spirv_SubgroupMaxSize() * uniform_groups; + return linear_size - uniform_size; +} diff --git a/sycl/include/CL/__spirv/spirv_vars.hpp b/sycl/include/CL/__spirv/spirv_vars.hpp index e25f0de9fee65..53315cd9f72eb 100644 --- a/sycl/include/CL/__spirv/spirv_vars.hpp +++ b/sycl/include/CL/__spirv/spirv_vars.hpp @@ -45,6 +45,12 @@ SYCL_EXTERNAL size_t __spirv_LocalInvocationId_x(); SYCL_EXTERNAL size_t __spirv_LocalInvocationId_y(); SYCL_EXTERNAL size_t __spirv_LocalInvocationId_z(); +SYCL_EXTERNAL uint32_t __spirv_SubgroupSize(); +SYCL_EXTERNAL uint32_t __spirv_SubgroupMaxSize(); +SYCL_EXTERNAL uint32_t __spirv_NumSubgroups(); +SYCL_EXTERNAL uint32_t __spirv_SubgroupId(); +SYCL_EXTERNAL uint32_t __spirv_SubgroupLocalInvocationId(); + #else // __SYCL_NVPTX__ typedef size_t size_t_vec __attribute__((ext_vector_type(3))); @@ -56,6 +62,12 @@ __SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInLocalInvocationId; __SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInWorkgroupId; __SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInGlobalOffset; +__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize; +__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize; +__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups; +__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId; +__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId; + SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() { return __spirv_BuiltInGlobalInvocationId.x; } @@ -126,14 +138,23 @@ SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_z() { return __spirv_BuiltInLocalInvocationId.z; } -#endif // __SYCL_NVPTX__ +SYCL_EXTERNAL inline uint32_t __spirv_SubgroupSize() { + return __spirv_BuiltInSubgroupSize; +} +SYCL_EXTERNAL inline uint32_t __spirv_SubgroupMaxSize() { + return __spirv_BuiltInSubgroupMaxSize; +} +SYCL_EXTERNAL inline uint32_t __spirv_NumSubgroups() { + return __spirv_BuiltInNumSubgroups; +} +SYCL_EXTERNAL inline uint32_t __spirv_SubgroupId() { + return __spirv_BuiltInSubgroupId; +} +SYCL_EXTERNAL inline uint32_t __spirv_SubgroupLocalInvocationId() { + return __spirv_BuiltInSubgroupLocalInvocationId; +} -__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize; -__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize; -__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups; -__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumEnqueuedSubgroups; -__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId; -__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId; +#endif // __SYCL_NVPTX__ #undef __SPIRV_VAR_QUALIFIERS diff --git a/sycl/include/CL/sycl/ONEAPI/sub_group.hpp b/sycl/include/CL/sycl/ONEAPI/sub_group.hpp index 68b61dbbb1ae0..bdf63e15d6552 100644 --- a/sycl/include/CL/sycl/ONEAPI/sub_group.hpp +++ b/sycl/include/CL/sycl/ONEAPI/sub_group.hpp @@ -109,7 +109,7 @@ struct sub_group { id_type get_local_id() const { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_BuiltInSubgroupLocalInvocationId; + return __spirv_SubgroupLocalInvocationId(); #else throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); @@ -127,7 +127,7 @@ struct sub_group { range_type get_local_range() const { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_BuiltInSubgroupSize; + return __spirv_SubgroupSize(); #else throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); @@ -136,7 +136,7 @@ struct sub_group { range_type get_max_local_range() const { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_BuiltInSubgroupMaxSize; + return __spirv_SubgroupMaxSize(); #else throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); @@ -145,7 +145,7 @@ struct sub_group { id_type get_group_id() const { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_BuiltInSubgroupId; + return __spirv_SubgroupId(); #else throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); @@ -163,7 +163,7 @@ struct sub_group { range_type get_group_range() const { #ifdef __SYCL_DEVICE_ONLY__ - return __spirv_BuiltInNumSubgroups; + return __spirv_NumSubgroups(); #else throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); diff --git a/sycl/test/sub_group/common.cpp b/sycl/test/sub_group/common.cpp index 41623ae2c228b..3bbdc9832fe82 100644 --- a/sycl/test/sub_group/common.cpp +++ b/sycl/test/sub_group/common.cpp @@ -1,6 +1,3 @@ -// UNSUPPORTED: cuda -// CUDA compilation and runtime do not yet support sub-groups. -// // RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out // RUN: env SYCL_DEVICE_TYPE=HOST %t.out // RUN: %CPU_RUN_PLACEHOLDER %t.out @@ -70,7 +67,7 @@ void check(queue &Queue, unsigned int G, unsigned int L) { } int main() { queue Queue; - if (!core_sg_supported(Queue.get_device())) { + if (Queue.get_device().is_host()) { std::cout << "Skipping test\n"; return 0; }