Skip to content

Commit 75204a7

Browse files
pytorchmergebot authored and Chao1Han committed
Revert "[ATen] Fix CUDA reduction warp shuffle order (pytorch#164790)"
This reverts commit 8e1f409. Reverted pytorch#164790 on behalf of https://github.com/jeffdaily because it broke the CUDA and ROCm CI ([comment](pytorch#164790 (comment)))
1 parent 476910f commit 75204a7

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -655,17 +655,12 @@ struct ReduceOp {
655655
}
656656

657657
__syncthreads();
658-
// Warp-level reduction for remaining threads
659-
// For non-power-of-2 sizes, we start from the next power-of-2 divided by 2
660-
// and use a boundary check to avoid out-of-bounds access
661-
for (size_t offset = warpSize / 2; offset > 0; offset >>= 1) {
658+
659+
for (int offset = 1; offset < dim_x; offset <<= 1) {
662660
#pragma unroll
663661
for (int i = 0; i < output_vec_size; i++) {
664662
arg_t other = ops.warp_shfl_down(value[i], offset);
665-
// Only combine if the source thread (threadIdx.x + offset) is within bounds
666-
if (threadIdx.x + offset < dim_x) {
667-
value[i] = ops.combine(value[i], other);
668-
}
663+
value[i] = ops.combine(value[i], other);
669664
}
670665
}
671666
return value;

aten/src/ATen/native/cuda/reduction_template.cuh

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -466,13 +466,11 @@ struct ReduceJitOp {
466466
467467
__syncthreads();
468468
469-
for (size_t offset = warpSize / 2; offset > 0; offset >>= 1) {
469+
for (int offset = 1; offset < dim_x; offset <<= 1) {
470470
#pragma unroll
471471
for (int i = 0; i < output_vec_size; i++) {
472472
arg_t other = reducer::warp_shfl_down(value[i], offset);
473-
if (threadIdx.x + offset < dim_x) {
474-
value[i] = reducer::combine(value[i], other);
475-
}
473+
value[i] = reducer::combine(value[i], other);
476474
}
477475
}
478476
return value;

0 commit comments

Comments
 (0)