@@ -30,6 +30,14 @@ namespace oneapi {
30
30
31
31
namespace detail {
32
32
33
+ template <class FunctorTy >
34
+ event withAuxHandler (std::shared_ptr<detail::queue_impl> Queue, bool IsHost,
35
+ FunctorTy Func) {
36
+ handler AuxHandler (Queue, IsHost);
37
+ Func (AuxHandler);
38
+ return AuxHandler.finalize ();
39
+ }
40
+
33
41
using cl::sycl::detail::bool_constant;
34
42
using cl::sycl::detail::enable_if_t ;
35
43
using cl::sycl::detail::queue_impl;
@@ -2434,6 +2442,7 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
2434
2442
2435
2443
bool Pow2WG = (WGSize & (WGSize - 1 )) == 0 ;
2436
2444
bool IsOneWG = NWorkGroups == 1 ;
2445
+ bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);
2437
2446
2438
2447
// Like reduCGFuncImpl, we also have to split out scalar and array reductions
2439
2448
IsScalarReduction ScalarPredicate;
@@ -2442,28 +2451,27 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
2442
2451
IsArrayReduction ArrayPredicate;
2443
2452
auto ArrayIs = filterSequence<Reductions...>(ArrayPredicate, ReduIndices);
2444
2453
2454
+ size_t LocalAccSize = WGSize + (HasUniformWG ? 0 : 1 );
2455
+ auto LocalAccsTuple =
2456
+ createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
2457
+ auto InAccsTuple =
2458
+ getReadAccsToPreviousPartialReds (CGH, ReduTuple, ReduIndices);
2459
+
2460
+ auto IdentitiesTuple = getReduIdentities (ReduTuple, ReduIndices);
2461
+ auto BOPsTuple = getReduBOPs (ReduTuple, ReduIndices);
2462
+ auto InitToIdentityProps =
2463
+ getInitToIdentityProperties (ReduTuple, ReduIndices);
2464
+
2445
2465
// Predicate/OutAccsTuple below have different type depending on us having
2446
2466
// just a single WG or multiple WGs. Use this lambda to avoid code
2447
2467
// duplication.
2448
2468
auto Rest = [&](auto Predicate, auto OutAccsTuple) {
2449
2469
auto AccReduIndices = filterSequence<Reductions...>(Predicate, ReduIndices);
2450
2470
associateReduAccsWithHandler (CGH, ReduTuple, AccReduIndices);
2451
-
2452
- size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1 );
2453
- auto LocalAccsTuple =
2454
- createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
2455
- auto InAccsTuple =
2456
- getReadAccsToPreviousPartialReds (CGH, ReduTuple, ReduIndices);
2457
-
2458
- auto IdentitiesTuple = getReduIdentities (ReduTuple, ReduIndices);
2459
- auto BOPsTuple = getReduBOPs (ReduTuple, ReduIndices);
2460
- auto InitToIdentityProps =
2461
- getInitToIdentityProperties (ReduTuple, ReduIndices);
2462
-
2463
2471
using Name = __sycl_reduction_kernel<reduction::aux_krn::Multi, KernelName,
2464
2472
decltype (OutAccsTuple)>;
2465
2473
// TODO: Opportunity to parallelize across number of elements
2466
- range<1 > GlobalRange = {Pow2WG ? NWorkItems : NWorkGroups * WGSize};
2474
+ range<1 > GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
2467
2475
nd_range<1 > Range{GlobalRange, range<1 >(WGSize)};
2468
2476
CGH.parallel_for <Name>(Range, [=](nd_item<1 > NDIt) {
2469
2477
size_t WGSize = NDIt.get_local_range ().size ();
@@ -2472,12 +2480,12 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
2472
2480
2473
2481
// Handle scalar and array reductions
2474
2482
reduAuxCGFuncImplScalar<Reductions...>(
2475
- Pow2WG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple ,
2476
- InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2483
+ HasUniformWG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
2484
+ LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2477
2485
InitToIdentityProps, ScalarIs);
2478
2486
reduAuxCGFuncImplArray<Reductions...>(
2479
- Pow2WG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple ,
2480
- InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2487
+ HasUniformWG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
2488
+ LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2481
2489
InitToIdentityProps, ArrayIs);
2482
2490
});
2483
2491
};
@@ -2504,7 +2512,7 @@ void reduSaveFinalResultToUserMemHelper(
2504
2512
if constexpr (!Reduction::is_usm) {
2505
2513
if (Redu.hasUserDiscardWriteAccessor ()) {
2506
2514
event CopyEvent =
2507
- handler:: withAuxHandler (Queue, IsHost, [&](handler &CopyHandler) {
2515
+ withAuxHandler (Queue, IsHost, [&](handler &CopyHandler) {
2508
2516
auto InAcc = Redu.getReadAccToPreviousPartialReds (CopyHandler);
2509
2517
auto OutAcc = Redu.getUserDiscardWriteAccessor ();
2510
2518
Redu.associateWithHandler (CopyHandler);
0 commit comments