File tree: aten/src/ATen/native/cuda — 2 files changed, +5 -12 lines changed
Columns: original file line number | diff line number | diff line change
@@ -655,17 +655,12 @@ struct ReduceOp {
655655 }
656656
657657 __syncthreads ();
658- // Warp-level reduction for remaining threads
659- // For non-power-of-2 sizes, we start from the next power-of-2 divided by 2
660- // and use a boundary check to avoid out-of-bounds access
661- for (size_t offset = warpSize / 2 ; offset > 0 ; offset >>= 1 ) {
658+
659+ for (int offset = 1 ; offset < dim_x; offset <<= 1 ) {
662660 #pragma unroll
663661 for (int i = 0 ; i < output_vec_size; i++) {
664662 arg_t other = ops.warp_shfl_down (value[i], offset);
665- // Only combine if the source thread (threadIdx.x + offset) is within bounds
666- if (threadIdx .x + offset < dim_x) {
667- value[i] = ops.combine (value[i], other);
668- }
663+ value[i] = ops.combine (value[i], other);
669664 }
670665 }
671666 return value;
Columns: original file line number | diff line number | diff line change
@@ -466,13 +466,11 @@ struct ReduceJitOp {
466466
467467 __syncthreads();
468468
469- for (size_t offset = warpSize / 2 ; offset > 0 ; offset >>= 1) {
469+ for (int offset = 1 ; offset < dim_x ; offset <<= 1) {
470470 #pragma unroll
471471 for (int i = 0; i < output_vec_size; i++) {
472472 arg_t other = reducer::warp_shfl_down(value[i], offset);
473- if (threadIdx.x + offset < dim_x) {
474- value[i] = reducer::combine(value[i], other);
475- }
473+ value[i] = reducer::combine(value[i], other);
476474 }
477475 }
478476 return value;
You can’t perform that action at this time.
0 commit comments