diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index d08f4bc9ae614..68050ed0d9448 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -510,6 +510,9 @@ class MarkDeviceFunction : public RecursiveASTVisitor { FunctionDecl *FD = WorkList.back().first; FunctionDecl *ParentFD = WorkList.back().second; + if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction(FD)) { + KernelBody = FD; + } if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction(FD)) { assert(!KernelBody && "inconsistent call graph - only one kernel body " "function can be called"); diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index 890fa1041293d..d86285cf7cd8f 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -728,23 +728,58 @@ class __SYCL_EXPORT handler { void parallel_for_lambda_impl(range NumWorkItems, KernelType KernelFunc) { throwIfActionIsCreated(); - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; + + // If 1D kernel argument is an integral type, convert it to sycl::item<1> using TransformedArgType = typename std::conditional::value && Dims == 1, item, LambdaArgType>::type; + using NameT = + typename detail::get_kernel_name_t::name; + constexpr size_t GoodLocalSizeX = 32; + std::string KName = typeid(NameT *).name(); + bool DisableRounding = + KName.find("SYCL_OPT_PFWGS_DISABLE") != std::string::npos; + if (!DisableRounding && NumWorkItems[0] % GoodLocalSizeX != 0) { + // Not a multiple + size_t NewValX = + ((NumWorkItems[0] + GoodLocalSizeX - 1) / GoodLocalSizeX) * + GoodLocalSizeX; + if (getenv("SYCL_OPT_PFWGS_TRACE") != nullptr) + std::cerr << "***** Adjusted size from " << NumWorkItems[0] << " to " + << NewValX << " *****\n"; + auto Wrapper = [=](TransformedArgType Arg) { + if (Arg[0] >= NumWorkItems[0]) + return; + Arg.set_allowed_range(NumWorkItems); + KernelFunc(Arg); + }; + + using NameWT = NameT *; + range AdjustedRange = NumWorkItems; + AdjustedRange.set_range(NewValX); #ifdef __SYCL_DEVICE_ONLY__ - (void)NumWorkItems; - kernel_parallel_for(KernelFunc); + kernel_parallel_for(Wrapper); #else - detail::checkValueRange(NumWorkItems); - MNDRDesc.set(std::move(NumWorkItems)); - StoreLambda( - std::move(KernelFunc)); - MCGType = detail::CG::KERNEL; + detail::checkValueRange(AdjustedRange); + MNDRDesc.set(std::move(AdjustedRange)); + StoreLambda( + std::move(Wrapper)); + MCGType = detail::CG::KERNEL; #endif + } else { +#ifdef __SYCL_DEVICE_ONLY__ + (void)NumWorkItems; + kernel_parallel_for(KernelFunc); +#else + detail::checkValueRange(NumWorkItems); + MNDRDesc.set(std::move(NumWorkItems)); + StoreLambda( + std::move(KernelFunc)); + MCGType = detail::CG::KERNEL; +#endif + } } /// Defines and invokes a SYCL kernel function for the specified range. diff --git a/sycl/include/CL/sycl/id.hpp b/sycl/include/CL/sycl/id.hpp index 16d176b8b698d..79f35aca19c40 100644 --- a/sycl/include/CL/sycl/id.hpp +++ b/sycl/include/CL/sycl/id.hpp @@ -94,6 +94,8 @@ template class id : public detail::array { return result; } + void set_allowed_range(range rnwi) { (void)rnwi[0]; } + #ifndef __SYCL_DISABLE_ID_TO_INT_CONV__ /* Template operator is not allowed because it disables further type * conversion. For example, the next code will not work in case of template diff --git a/sycl/include/CL/sycl/item.hpp b/sycl/include/CL/sycl/item.hpp index 9d9a879815294..a5409794dfdcf 100644 --- a/sycl/include/CL/sycl/item.hpp +++ b/sycl/include/CL/sycl/item.hpp @@ -104,6 +104,8 @@ template class item { bool operator!=(const item &rhs) const { return rhs.MImpl != MImpl; } + void set_allowed_range(const range rnwi) { MImpl.MExtent = rnwi; } + protected: template item(detail::enable_if_t> &extent, diff --git a/sycl/include/CL/sycl/range.hpp b/sycl/include/CL/sycl/range.hpp index 2745a05667a31..301e5a1979095 100644 --- a/sycl/include/CL/sycl/range.hpp +++ b/sycl/include/CL/sycl/range.hpp @@ -62,6 +62,11 @@ template class range : public detail::array { return size; } + // Adjust only the first dim of the range + void set_range(const size_t dim0) { + this->common_array[0] = dim0; + } + range(const range &rhs) = default; range(range &&rhs) = default; range &operator=(const range &rhs) = default;