Skip to content

Commit 8d8a38e

Browse files
zou3519 authored and facebook-github-bot committed
Error out on in-place (unary) ops on tensors that have internal overlap (#17927)
Summary: Pull Request resolved: pytorch/pytorch#17927 ghimport-source-id: 626d321e430b6b5c0ea3aa1eb9df8c1e2d058bf8 Stack: * #17926 Implement at::has_internal_overlap helper function * **#17927 Error out on in-place (unary) ops on tensors that have internal overlap** On the way to #17935. Works for CPU and CUDA on the following ops: - abs_, acos_, asin_, atan_, ceil_, cos_, erf_, erfc_, exp_, expm1_ - floor_, log_, log10_, log1p_, log2_, round_, rsqrt_, - sin_, sqrt_, tan_, tanh_, trunc_ This PR adds a check to see if the out/result tensor has internal overlap. If it does, then we error out because the result **may** be incorrect. This is overly conservative; there are some cases where if the result is the same as the input, the inplace operation is OK (such as floor_, round_, and trunc_). However, the current code isn't organized in such a way that this is easy to check, so enabling those will come in the future. Reviewed By: ezyang Differential Revision: D14438871 fbshipit-source-id: 15e12bf1fdb2ab7f74bb806e22bc74840bd6abd1
1 parent ad88371 commit 8d8a38e

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,30 @@
44
namespace at {
55

66
MemOverlap has_internal_overlap(const Tensor& tensor) {
7-
auto* t = tensor.unsafeGetTensorImpl();
7+
return has_internal_overlap(tensor.unsafeGetTensorImpl());
8+
}
89

9-
AT_ASSERT(tensor.layout() == kStrided);
10+
MemOverlap has_internal_overlap(TensorImpl* t) {
11+
AT_ASSERT(t->layout() == kStrided);
1012

1113
if (t->is_contiguous()) {
1214
return MemOverlap::NO;
1315
}
1416

1517
auto strides = t->strides();
16-
if (std::find_if(
17-
strides.begin(), strides.end(), [](int s) { return s == 0; })) {
18+
if (strides.end() != std::find_if(
19+
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
1820
return MemOverlap::YES;
1921
}
2022

2123
return MemOverlap::TOO_HARD;
2224
}
2325

2426
void assert_no_internal_overlap(const Tensor& t, std::string op) {
27+
assert_no_internal_overlap(t.unsafeGetTensorImpl(), op);
28+
}
29+
30+
void assert_no_internal_overlap(TensorImpl* t, std::string op) {
2531
if (has_internal_overlap(t) == MemOverlap::YES) {
2632
AT_ERROR(
2733
op, ": unsupported operation: more than one element of the written-to "

aten/src/ATen/MemoryOverlap.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ namespace at {
1313
// NB: Please update the python test for these if you renumber them.
1414
enum class MemOverlap { NO, YES, TOO_HARD };
1515

16-
MemOverlap has_internal_overlap(const Tensor& t);
16+
CAFFE2_API MemOverlap has_internal_overlap(const Tensor& t);
17+
CAFFE2_API MemOverlap has_internal_overlap(TensorImpl* t);
1718

18-
void assert_no_internal_overlap(const Tensor& t, std::string op);
19+
CAFFE2_API void assert_no_internal_overlap(const Tensor& t, std::string op);
20+
CAFFE2_API void assert_no_internal_overlap(TensorImpl* t, std::string op);
1921

2022
}

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ATen/CPUGenerator.h>
88
#include <ATen/CheckGenerator.h>
99
#include <ATen/Generator.h>
10+
#include <ATen/MemoryOverlap.h>
1011
#include <ATen/cpu/vml.h>
1112
#include <ATen/CPUApplyUtils.h>
1213
#include <ATen/native/DispatchStub.h>
@@ -183,6 +184,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
183184
result.data<scalar_t>(), self.data<scalar_t>(), self.numel()); \
184185
\
185186
} else { \
187+
assert_no_internal_overlap(result, #op); \
186188
static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t); \
187189
CPU_tensor_parallel_kernel_apply2<scalar_t, scalar_t>( \
188190
result, \
@@ -211,7 +213,6 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
211213
}); \
212214
} \
213215
REGISTER_DISPATCH(op##Impl, &op##_kernel)
214-
215216
} // anonymous namespace
216217

217218
REGISTER_DISPATCH(sigmoidImpl, &sigmoid_kernel)

aten/src/THC/generic/THCTensorMathPointwise.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define THC_GENERIC_FILE "THC/generic/THCTensorMathPointwise.cu"
33
#else
44

5+
#include <ATen/MemoryOverlap.h>
6+
57
#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_(NAME, CFUNC, REAL) \
68
struct Tensor_##NAME##_##REAL##_Op { \
79
__device__ __forceinline__ void operator()(scalar_t* out, scalar_t* in) const { \
@@ -15,6 +17,7 @@
1517
\
1618
void THCTensor_(NAME)(THCState* state, THCTensor* self_, THCTensor* src) { \
1719
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); \
20+
at::assert_no_internal_overlap(self_, #NAME); \
1821
if (self_ == src) { \
1922
if (!THC_pointwiseApply1<scalar_t>(state, self_, Tensor_##NAME##_##REAL##_Op())) { \
2023
THArgCheck(false, 2, CUTORCH_DIM_WARNING); \

0 commit comments

Comments (0)