Skip to content

Commit a0bfab1

Browse files
[SYCL] Fix reductions regressions after #6343 (#6460)
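This fix has three parts: (1) withAuxHandler is no longer a private static member of handler; it becomes a free function template in ext::oneapi::detail that handler declares as a friend, so the reduction code can call it. (2) The Pow2WG flag passed to the auxiliary reduction kernels is replaced by HasUniformWG, which is true only when the work-group size is a power of two and NWorkGroups * WGSize == NWorkItems. (3) The Out accessor must be the last one created, after the other temporaries (local accessors and read accessors to previous partial reductions).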
1 parent a364383 commit a0bfab1

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

sycl/include/sycl/ext/oneapi/reduction.hpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ namespace oneapi {
3030

3131
namespace detail {
3232

33+
/// Runs \p Func against a freshly constructed auxiliary handler and returns
/// the event obtained by finalizing that handler.
///
/// \param Queue  queue implementation forwarded to the handler constructor.
/// \param IsHost host flag forwarded to the handler constructor.
/// \param Func   callable invoked exactly once with the auxiliary handler
///               (passed by lvalue) before the handler is finalized.
/// \return the event returned by AuxHandler.finalize().
template <class FunctorTy>
34+
event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue, bool IsHost,
35+
FunctorTy Func) {
36+
handler AuxHandler(Queue, IsHost);
37+
Func(AuxHandler);
38+
return AuxHandler.finalize();
39+
}
40+
3341
using cl::sycl::detail::bool_constant;
3442
using cl::sycl::detail::enable_if_t;
3543
using cl::sycl::detail::queue_impl;
@@ -2434,6 +2442,7 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
24342442

24352443
bool Pow2WG = (WGSize & (WGSize - 1)) == 0;
24362444
bool IsOneWG = NWorkGroups == 1;
2445+
bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);
24372446

24382447
// Like reduCGFuncImpl, we also have to split out scalar and array reductions
24392448
IsScalarReduction ScalarPredicate;
@@ -2442,28 +2451,27 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
24422451
IsArrayReduction ArrayPredicate;
24432452
auto ArrayIs = filterSequence<Reductions...>(ArrayPredicate, ReduIndices);
24442453

2454+
size_t LocalAccSize = WGSize + (HasUniformWG ? 0 : 1);
2455+
auto LocalAccsTuple =
2456+
createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
2457+
auto InAccsTuple =
2458+
getReadAccsToPreviousPartialReds(CGH, ReduTuple, ReduIndices);
2459+
2460+
auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices);
2461+
auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices);
2462+
auto InitToIdentityProps =
2463+
getInitToIdentityProperties(ReduTuple, ReduIndices);
2464+
24452465
// Predicate/OutAccsTuple below have different type depending on us having
24462466
// just a single WG or multiple WGs. Use this lambda to avoid code
24472467
// duplication.
24482468
auto Rest = [&](auto Predicate, auto OutAccsTuple) {
24492469
auto AccReduIndices = filterSequence<Reductions...>(Predicate, ReduIndices);
24502470
associateReduAccsWithHandler(CGH, ReduTuple, AccReduIndices);
2451-
2452-
size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1);
2453-
auto LocalAccsTuple =
2454-
createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
2455-
auto InAccsTuple =
2456-
getReadAccsToPreviousPartialReds(CGH, ReduTuple, ReduIndices);
2457-
2458-
auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices);
2459-
auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices);
2460-
auto InitToIdentityProps =
2461-
getInitToIdentityProperties(ReduTuple, ReduIndices);
2462-
24632471
using Name = __sycl_reduction_kernel<reduction::aux_krn::Multi, KernelName,
24642472
decltype(OutAccsTuple)>;
24652473
// TODO: Opportunity to parallelize across number of elements
2466-
range<1> GlobalRange = {Pow2WG ? NWorkItems : NWorkGroups * WGSize};
2474+
range<1> GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
24672475
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
24682476
CGH.parallel_for<Name>(Range, [=](nd_item<1> NDIt) {
24692477
size_t WGSize = NDIt.get_local_range().size();
@@ -2472,12 +2480,12 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
24722480

24732481
// Handle scalar and array reductions
24742482
reduAuxCGFuncImplScalar<Reductions...>(
2475-
Pow2WG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple,
2476-
InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2483+
HasUniformWG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
2484+
LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
24772485
InitToIdentityProps, ScalarIs);
24782486
reduAuxCGFuncImplArray<Reductions...>(
2479-
Pow2WG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple,
2480-
InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2487+
HasUniformWG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
2488+
LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
24812489
InitToIdentityProps, ArrayIs);
24822490
});
24832491
};
@@ -2504,7 +2512,7 @@ void reduSaveFinalResultToUserMemHelper(
25042512
if constexpr (!Reduction::is_usm) {
25052513
if (Redu.hasUserDiscardWriteAccessor()) {
25062514
event CopyEvent =
2507-
handler::withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) {
2515+
withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) {
25082516
auto InAcc = Redu.getReadAccToPreviousPartialReds(CopyHandler);
25092517
auto OutAcc = Redu.getUserDiscardWriteAccessor();
25102518
Redu.associateWithHandler(CopyHandler);

sycl/include/sycl/handler.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>);
315315

316316
template <typename FirstT, typename... RestT> struct AreAllButLastReductions;
317317

318+
template <class FunctorTy>
319+
event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue, bool IsHost,
320+
FunctorTy Func);
318321
} // namespace detail
319322
} // namespace oneapi
320323
} // namespace ext
@@ -476,12 +479,9 @@ class __SYCL_EXPORT handler {
476479
}
477480

478481
template <class FunctorTy>
479-
static event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue,
480-
bool IsHost, FunctorTy Func) {
481-
handler AuxHandler(Queue, IsHost);
482-
Func(AuxHandler);
483-
return AuxHandler.finalize();
484-
}
482+
friend event
483+
ext::oneapi::detail::withAuxHandler(std::shared_ptr<detail::queue_impl> Queue,
484+
bool IsHost, FunctorTy Func);
485485
/// }@
486486

487487
/// Saves buffers created by handling reduction feature in handler.

0 commit comments

Comments
 (0)