
Commit b56b002

malfet authored and pytorchmergebot committed
Fix NULL dereference in binary CPU ops (#115183)
Targeted fix for #113037. A more fundamental fix, in which these functions are not called at all for empty tensors, is coming later.

Pull Request resolved: #115183

Approved by: https://github.com/drisspg, https://github.com/atalman, https://github.com/huydhn
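Concretely, the crash can be reproduced from Python. A minimal sketch, mirroring the new regression test in test/test_numpy_interop.py; on a build without this fix, the div call below segfaults:

    import numpy as np
    import torch

    x = torch.rand((), dtype=torch.float16)                   # 0-dim tensor: a CPU "scalar" operand
    y = torch.tensor(np.random.rand(0), dtype=torch.float16)  # empty tensor whose data pointer is NULL

    # Segfaulted before this fix (issue #113037); now returns an empty result.
    print(torch.div(x, y, rounding_mode='floor').shape)  # torch.Size([0])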
1 parent 892a14a commit b56b002

File tree

3 files changed: +24 −5 lines changed

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
test/test_foreach.py
test/test_numpy_interop.py

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 4 additions & 4 deletions
@@ -130,7 +130,7 @@ void mul_kernel(TensorIteratorBase& iter) {
       using comp_t = c10::complex<float>;
       return comp_t{a} * comp_t{b};
     });
-  } else if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "mul_cpu_reduced_float", [&]() {
       using opmath_t = at::opmath_type<scalar_t>;
       opmath_t b = iter.original_scalar_value<opmath_t>(2);
@@ -162,7 +162,7 @@ void mul_kernel(TensorIteratorBase& iter) {
 
 void div_true_kernel(TensorIteratorBase& iter) {
   const auto dtype = iter.common_dtype();
-  if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "div_cpu_reduced_float", [&]() {
       using opmath_t = at::opmath_type<scalar_t>;
       opmath_t b = iter.original_scalar_value<opmath_t>(2);
@@ -208,7 +208,7 @@ void div_trunc_kernel(TensorIteratorBase& iter) {
         return a / b;
       });
     });
-  } else if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+  } else if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(
         dtype, "div_trunc_cpu_reduced_float", [&]() {
           using opmath_t = at::opmath_type<scalar_t>;
@@ -283,7 +283,7 @@ void div_floor_kernel(TensorIteratorBase& iter) {
     });
   } else {
     // See NOTE: [Floor Division in Python]
-    if (iter.is_scalar(2) && at::isReducedFloatingType(dtype)) {
+    if (iter.is_scalar(2) && iter.data_ptr(2) != nullptr && at::isReducedFloatingType(dtype)) {
       AT_DISPATCH_REDUCED_FLOATING_TYPES(
           dtype, "div_floor_cpu_reduced_float", [&]() {
             using opmath_t = at::opmath_type<scalar_t>;
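All four call sites follow the same pattern: the reduced-float fast path reads the scalar operand via iter.original_scalar_value<opmath_t>(2), which goes through operand 2's data pointer. An empty tensor with all-zero strides still satisfies iter.is_scalar(2), but its storage can hold zero bytes, leaving that pointer NULL, so the read crashed. With the added iter.data_ptr(2) != nullptr check, such inputs fall through to the generic loop, which simply iterates zero times. The NULL pointer is visible from Python too; a small sketch, assuming a CPU build, where data_ptr() exposes the raw pointer the fast path would have read:

    import torch

    # An empty tensor can own no storage bytes, so its data pointer stays NULL.
    y = torch.empty_strided((0, 1), (0, 0), dtype=torch.bfloat16)
    print(y.numel())     # 0
    print(y.data_ptr())  # 0 (i.e. nullptr) -- the value the new guard checks for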

test/test_foreach.py

Lines changed: 5 additions & 1 deletion
@@ -349,13 +349,17 @@ def _pointwise_test(
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
         # TODO: enable empty list case
-        for tensors in [[torch.randn([0], device=device, dtype=dtype)]]:
+        for tensors in [[torch.randn([0], device=device, dtype=dtype)],
+                        [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)]]:
             res = torch._foreach_add(tensors, 1)
             self.assertEqual(res, tensors)
 
             torch._foreach_add_(tensors, 1)
             self.assertEqual(res, tensors)
 
+            # Regression test for https://github.com/pytorch/pytorch/issues/113156
+            torch._foreach_mul_(tensors, 1)
+
     @ops(
         filter(lambda op: op.supports_out, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
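The second loop entry is the interesting one: torch.empty_strided((0, 1), (0, 0), ...) builds an empty 2-D tensor with all-zero strides, the shape that tripped the NULL dereference through the _foreach_ APIs (issue #113156). A standalone sketch of what the added lines exercise, using float32 for concreteness (the test itself runs over many dtypes):

    import torch

    # Empty 2-D tensor with all-zero strides, as in the new test case.
    t = torch.empty_strided((0, 1), (0, 0), dtype=torch.float32)
    torch._foreach_mul_([t], 1)  # crashed with a NULL dereference before this fix
    print(t.shape)               # torch.Size([0, 1]); the op is a no-op on an empty tensor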

test/test_numpy_interop.py

Lines changed: 15 additions & 0 deletions
@@ -482,6 +482,21 @@ def test_numpy_scalar_cmp(self, device, dtype):
         else:
             self.assertTrue(t == a)
 
+    @onlyCPU
+    def test_empty_tensors_interop(self, device):
+        x = torch.rand((), dtype=torch.float16)
+        y = torch.tensor(np.random.rand(0), dtype=torch.float16)
+        # Same can be achieved by running
+        # y = torch.empty_strided((0,), (0,), dtype=torch.float16)
+
+        # Regression test for https://github.com/pytorch/pytorch/issues/115068
+        self.assertEqual(torch.true_divide(x, y).shape, y.shape)
+        # Regression test for https://github.com/pytorch/pytorch/issues/115066
+        self.assertEqual(torch.mul(x, y).shape, y.shape)
+        # Regression test for https://github.com/pytorch/pytorch/issues/113037
+        self.assertEqual(torch.div(x, y, rounding_mode='floor').shape, y.shape)
+
 instantiate_device_type_tests(TestNumPyInterop, globals())
 
 if __name__ == '__main__':
