diff --git a/sycl/include/sycl/item.hpp b/sycl/include/sycl/item.hpp index 01fca65cdf8b5..100250567ab03 100644 --- a/sycl/include/sycl/item.hpp +++ b/sycl/include/sycl/item.hpp @@ -25,9 +25,12 @@ template class RoundedRangeKernel; template class RoundedRangeKernelWithKH; + +namespace reduction { +template +item getDelinearizedItem(range Range, id Id); +} // namespace reduction } // namespace detail -template class id; -template class range; /// Identifies an instance of the function object executing at each point /// in a range. @@ -130,6 +133,10 @@ template class item { friend class detail::RoundedRangeKernelWithKH; void set_allowed_range(const range rnwi) { MImpl.MExtent = rnwi; } + template + friend item + detail::reduction::getDelinearizedItem(range Range, id Id); + detail::ItemBase MImpl; }; diff --git a/sycl/include/sycl/reduction.hpp b/sycl/include/sycl/reduction.hpp index e2bf27cef89e3..803e65d36868f 100644 --- a/sycl/include/sycl/reduction.hpp +++ b/sycl/include/sycl/reduction.hpp @@ -2369,8 +2369,18 @@ void reduction_parallel_for(handler &CGH, range Range, size_t Start = GroupStart + NDId.get_local_id(0); size_t End = GroupEnd; size_t Stride = NDId.get_local_range(0); + auto GetDelinearized = [&](size_t I) { + auto Id = getDelinearizedId(Range, I); + if constexpr (std::is_invocable_v, + decltype(Reducers)...>) + return Id; + else + // SYCL doesn't provide parallel_for accepting offset in presence of + // reductions, so use with_offset==false. + return reduction::getDelinearizedItem(Range, Id); + }; for (size_t I = Start; I < End; I += Stride) - KernelFunc(getDelinearizedId(Range, I), Reducers...); + KernelFunc(GetDelinearized(I), Reducers...); }; if constexpr (NumArgs == 2) { using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>; diff --git a/sycl/include/sycl/reduction_forward.hpp b/sycl/include/sycl/reduction_forward.hpp index 17bd1dfb65bca..1fde34f3fbca9 100644 --- a/sycl/include/sycl/reduction_forward.hpp +++ b/sycl/include/sycl/reduction_forward.hpp @@ -42,6 +42,11 @@ enum class strategy : int { // are limited to those below. inline void finalizeHandler(handler &CGH); template void withAuxHandler(handler &CGH, FunctorTy Func); + +template +item getDelinearizedItem(range Range, id Id) { + return {Range, Id}; +} } // namespace reduction template