Skip to content

Commit 32d210e

Browse files
committed
Fix eager reduction warp shuffle order to start from offset=16
ghstack-source-id: 4e8b6c8 Pull Request resolved: pytorch/pytorch#164790
1 parent b9f03a3 commit 32d210e

File tree

3 files changed: +14 −1 lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,14 @@ struct ReduceOp {
     }

     __syncthreads();
-
+    // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
+    // matching Triton, etc.
+    // todo for AMD
+#ifdef USE_ROCM
     for (int offset = 1; offset < dim_x; offset <<= 1) {
+#else
+    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
+#endif
 #pragma unroll
       for (int i = 0; i < output_vec_size; i++) {
         arg_t other = ops.warp_shfl_down(value[i], offset);

aten/src/ATen/native/cuda/reduction_template.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,11 @@ struct ReduceJitOp {

     __syncthreads();

+#ifdef USE_ROCM
     for (int offset = 1; offset < dim_x; offset <<= 1) {
+#else
+    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
+#endif
 #pragma unroll
       for (int i = 0; i < output_vec_size; i++) {
         arg_t other = reducer::warp_shfl_down(value[i], offset);

test/test_decomp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
     (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
     (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
     (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
+    (torch.float16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
+    (torch.bfloat16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
     # see https://github.com/pytorch/pytorch/pull/96264
     (torch.float16, torch.ops.aten.mv.default): 1e-5,
     (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
@@ -295,6 +297,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
         rtol, atol = tol_table[(decomp.dtype, op)]
     else:
         rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
+
     test_case.assertEqual(
         orig,
         decomp,

0 commit comments

Comments
 (0)