@@ -1339,7 +1339,7 @@ sycl::event reduction_over_group_temps_strided_impl(
1339
1339
static_cast <py::ssize_t >(remaining_reduction_nelems)};
1340
1340
ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
1341
1341
/* shape */ iter_shape_and_strides,
1342
- /* s trides */ iter_shape_and_strides +
1342
+ /* strides */ iter_shape_and_strides +
1343
1343
2 * iter_nd};
1344
1344
1345
1345
InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
@@ -1424,8 +1424,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
1424
1424
py::ssize_t reduction_arg_offset,
1425
1425
const std::vector<sycl::event> &depends)
1426
1426
{
1427
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1428
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1427
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1428
+ iter_arg_offset + reduction_arg_offset;
1429
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
1429
1430
1430
1431
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1431
1432
@@ -1767,8 +1768,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
1767
1768
py::ssize_t reduction_arg_offset,
1768
1769
const std::vector<sycl::event> &depends)
1769
1770
{
1770
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
1771
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
1771
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1772
+ iter_arg_offset + reduction_arg_offset;
1773
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
1772
1774
1773
1775
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1774
1776
@@ -4258,8 +4260,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
4258
4260
py::ssize_t reduction_arg_offset,
4259
4261
const std::vector<sycl::event> &depends)
4260
4262
{
4261
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4262
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4263
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4264
+ iter_arg_offset + reduction_arg_offset;
4265
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
4263
4266
4264
4267
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
4265
4268
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
@@ -4635,8 +4638,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
4635
4638
py::ssize_t reduction_arg_offset,
4636
4639
const std::vector<sycl::event> &depends)
4637
4640
{
4638
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp);
4639
- resTy *res_tp = reinterpret_cast <resTy *>(res_cp);
4641
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
4642
+ iter_arg_offset + reduction_arg_offset;
4643
+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
4640
4644
4641
4645
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
4642
4646
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
0 commit comments