
Commit ad88371

zou3519 authored and facebook-github-bot committed

Implement at::has_internal_overlap helper function (#17926)

Summary: Pull Request resolved: pytorch/pytorch#17926
ghimport-source-id: 9f7572b5d43e474492363fa17dcb86a6c27ca13c

Stack:
* **#17926 Implement at::has_internal_overlap helper function**
* #17927 Error out on in-place (unary) ops on tensors that have internal overlap

On the way to #17935.

Checks whether a tensor's sizes and strides indicate that multiple elements share the same memory location. This problem is hard in general, so at::has_internal_overlap avoids solving it and instead applies two heuristics:

* if a tensor is contiguous, it cannot have internal overlap;
* if a tensor has any zero strides, it does have internal overlap;
* otherwise, return MemOverlap::TOO_HARD to indicate that there might be overlap, but we don't know.

Reviewed By: ezyang
Differential Revision: D14438858
fbshipit-source-id: 607ab31771315921ab6165b2a1f072ac3e75925a

1 parent ea84420 commit ad88371
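The two heuristics above can be sketched in a few lines of standalone Python. This is an illustration, not the ATen implementation: the `sizes`/`strides` lists and the row-major contiguity check are simplified assumptions standing in for `TensorImpl`.

```python
from enum import Enum

class MemOverlap(Enum):
    NO = 0
    YES = 1
    TOO_HARD = 2

def is_contiguous(sizes, strides):
    # Row-major contiguity: each dimension's stride should equal the
    # product of the sizes of all later dimensions (size-1 dims exempt).
    expected = 1
    for size, stride in reversed(list(zip(sizes, strides))):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True

def has_internal_overlap(sizes, strides):
    # Heuristic 1: a contiguous tensor cannot alias itself.
    if is_contiguous(sizes, strides):
        return MemOverlap.NO
    # Heuristic 2: a zero stride maps every index along that
    # dimension to the same memory location.
    if any(s == 0 for s in strides):
        return MemOverlap.YES
    # Otherwise: deciding overlap exactly is expensive, so give up.
    return MemOverlap.TOO_HARD

# A 3x4 row-major tensor: contiguous, so no overlap.
print(has_internal_overlap([3, 4], [4, 1]))   # MemOverlap.NO
# A row broadcast to shape [3, 4] via stride 0: definite overlap.
print(has_internal_overlap([3, 4], [0, 1]))   # MemOverlap.YES
# A transposed view: not contiguous, no zero strides -> unknown.
print(has_internal_overlap([4, 3], [1, 4]))   # MemOverlap.TOO_HARD
```

This mirrors why zero strides arise in practice: broadcasting views (e.g. `Tensor::expand`) materialize repeated elements by setting a stride to 0 rather than copying data.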

File tree

4 files changed: +62 −0 lines changed

aten/src/ATen/MemoryOverlap.cpp (new file, 33 additions)

```cpp
#include <ATen/MemoryOverlap.h>
#include <c10/core/Layout.h>

#include <algorithm>

namespace at {

MemOverlap has_internal_overlap(const Tensor& tensor) {
  auto* t = tensor.unsafeGetTensorImpl();

  AT_ASSERT(tensor.layout() == kStrided);

  if (t->is_contiguous()) {
    return MemOverlap::NO;
  }

  auto strides = t->strides();
  if (std::find_if(
          strides.begin(), strides.end(),
          [](int64_t s) { return s == 0; }) != strides.end()) {
    return MemOverlap::YES;
  }

  return MemOverlap::TOO_HARD;
}

void assert_no_internal_overlap(const Tensor& t, std::string op) {
  if (has_internal_overlap(t) == MemOverlap::YES) {
    AT_ERROR(
        op, ": unsupported operation: more than one element of the written-to "
        "tensor refers to a single memory location. Please clone() the tensor "
        "before calling ", op);
  }
}

} // namespace at
```

aten/src/ATen/MemoryOverlap.h (new file, 20 additions)

```cpp
#pragma once

#include <ATen/ATen.h>

namespace at {

// MemOverlap: Whether or not there is memory overlap
//
// NO: Absolutely no memory overlap
// YES: Absolutely yes memory overlap
// TOO_HARD: There might be memory overlap, but it was too expensive to compute.
//
// NB: Please update the python test for these if you renumber them.
enum class MemOverlap { NO, YES, TOO_HARD };

MemOverlap has_internal_overlap(const Tensor& t);

void assert_no_internal_overlap(const Tensor& t, std::string op);

} // namespace at
```

aten/src/ATen/native/Memory.cpp (6 additions, 0 deletions)

```diff
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/MemoryOverlap.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <c10/util/Exception.h>
@@ -16,5 +17,10 @@ Tensor pin_memory(const Tensor& self) {
   return tensor;
 }

+// Exposes at::has_internal_overlap as an operator for testing purposes
+int64_t _debug_has_internal_overlap(const Tensor& self) {
+  return static_cast<int64_t>(at::has_internal_overlap(self));
+}
+
 }
 }
```

aten/src/ATen/native/native_functions.yaml (3 additions, 0 deletions)

```diff
@@ -61,6 +61,9 @@
   dispatch:
     CUDA: _cudnn_init_dropout_state

+- func: _debug_has_internal_overlap(Tensor self) -> int
+  variants: function
+
 - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
```
