File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
aten/src/ATen/native/mps/kernels Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -503,8 +503,8 @@ void avg_pool_3d_input_iter(
503503 padding,
504504 count_include_pad);
505505
506- T value_sum = 0 ;
507- auto divisor = has_divisor_override
506+ opmath_t <T> value_sum = 0 ;
507+ opmath_t <T> divisor = has_divisor_override
508508 ? divisor_override
509509 : (bounds0.count ) * (bounds1.count ) * (bounds2.count );
510510
@@ -517,11 +517,11 @@ void avg_pool_3d_input_iter(
517517 for (auto i2 = bounds2.start ; i2 < bounds2.end ; i2++) {
518518 auto offset2 = input_strides[2 ] * i2;
519519 auto input_value = input[offset0 + offset1 + offset2];
520- value_sum += input_value;
520+ value_sum += static_cast < opmath_t <T>>( input_value) ;
521521 }
522522 }
523523 }
524- *output = value_sum / static_cast <T>(divisor);
524+ *output = static_cast <T>(value_sum / divisor);
525525}
526526
527527// Iterates through all the input elements that this kernel needs to
You can’t perform that action at this time.
0 commit comments