Skip to content

Commit 296d79c

Browse files
kurtamohlermarkc-614
authored andcommitted
[MPS] Update avg_pool3d kernel to use opmath_t (pytorch#161071)
Pull Request resolved: pytorch#161071 Approved by: https://github.com/Skylion007, https://github.com/malfet ghstack dependencies: pytorch#161011
1 parent 58ef970 commit 296d79c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

aten/src/ATen/native/mps/kernels/Pooling.metal

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ void avg_pool_3d_input_iter(
503503
padding,
504504
count_include_pad);
505505

506-
T value_sum = 0;
507-
auto divisor = has_divisor_override
506+
opmath_t<T> value_sum = 0;
507+
opmath_t<T> divisor = has_divisor_override
508508
? divisor_override
509509
: (bounds0.count) * (bounds1.count) * (bounds2.count);
510510

@@ -517,11 +517,11 @@ void avg_pool_3d_input_iter(
517517
for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
518518
auto offset2 = input_strides[2] * i2;
519519
auto input_value = input[offset0 + offset1 + offset2];
520-
value_sum += input_value;
520+
value_sum += static_cast<opmath_t<T>>(input_value);
521521
}
522522
}
523523
}
524-
*output = value_sum / static_cast<T>(divisor);
524+
*output = static_cast<T>(value_sum / divisor);
525525
}
526526

527527
// Iterates through all the input elements that this kernel needs to

0 commit comments

Comments
 (0)