Skip to content

Commit bfba152

Browse files
Taper the optimization in tree-reduction that causes problems with CUDA
The optimization should not use the full max-work-group-size, so that the runtime (RT) retains some of the shared local memory (SLM).
1 parent dcb566a commit bfba152

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

+12-9
Original file line number | Diff line number | Diff line change
@@ -1094,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
10941094
// max_max_wg prevents running out of resources on CPU
10951095
constexpr size_t max_max_wg = 2048;
10961096
size_t max_wg = std::min(
1097-
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
1097+
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
10981098

10991099
size_t reductions_per_wi(preferrered_reductions_per_wi);
11001100
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1444,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
14441444
// max_max_wg prevents running out of resources on CPU
14451445
constexpr size_t max_max_wg = 2048;
14461446
size_t max_wg = std::min(
1447-
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
1447+
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
14481448

14491449
size_t reductions_per_wi(preferrered_reductions_per_wi);
14501450
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -1788,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
17881788
// max_max_wg prevents running out of resources on CPU
17891789
constexpr size_t max_max_wg = 2048;
17901790
size_t max_wg = std::min(
1791-
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
1791+
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);
17921792

17931793
size_t reductions_per_wi(preferrered_reductions_per_wi);
17941794
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -3883,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(
38833883

38843884
constexpr size_t preferrered_reductions_per_wi = 4;
38853885
// max_max_wg prevents running out of resources on CPU
3886-
size_t max_wg = std::min(
3887-
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
3886+
size_t max_wg =
3887+
std::min(size_t(2048),
3888+
d.get_info<sycl::info::device::max_work_group_size>() / 2);
38883889

38893890
size_t reductions_per_wi(preferrered_reductions_per_wi);
38903891
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4279,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
42794280

42804281
constexpr size_t preferrered_reductions_per_wi = 8;
42814282
// max_max_wg prevents running out of resources on CPU
4282-
size_t max_wg = std::min(
4283-
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
4283+
size_t max_wg =
4284+
std::min(size_t(2048),
4285+
d.get_info<sycl::info::device::max_work_group_size>() / 2);
42844286

42854287
size_t reductions_per_wi(preferrered_reductions_per_wi);
42864288
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
@@ -4657,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
46574659

46584660
constexpr size_t preferrered_reductions_per_wi = 8;
46594661
// max_max_wg prevents running out of resources on CPU
4660-
size_t max_wg = std::min(
4661-
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
4662+
size_t max_wg =
4663+
std::min(size_t(2048),
4664+
d.get_info<sycl::info::device::max_work_group_size>() / 2);
46624665

46634666
size_t reductions_per_wi(preferrered_reductions_per_wi);
46644667
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {

0 commit comments

Comments
 (0)