Skip to content

Commit 40cc616

Browse files
cdzhanpytorchmergebot
authored andcommitted
Fix caching allocator of out-of-tree device is destructed before the … (pytorch#126677)
…destruction of tensors cached by autocast ## Root Cause For out-of-tree device extension it is loaded after torch (different .so), so the global variable `cached_casts` may be constructed before caching allocator and then destructed in reversed order when exit. ## Fix Lazily initialize `cached_casts` to correct the order. ## How to Reproduce && Test Modify the testcase `TestAutocastGPU.test_cast_cache_is_global` in test/test_autocast.py to run on your out-of-tree device. You will see following failure in the end of test. ```bash ---------------------------------------------------------------------- Ran 1 test in 4.812s OK free: 0x30080ff44000400 terminate called after throwing an instance of 'c10::Error' what(): invalid device pointer: 0x30080ff44000400 Exception raised from free at /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/framework/core/caching_allocator.cpp:1609 (most recent call first): frame #0: <unknown function> + 0x118fe1 (0x7ffaef4d3fe1 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #1: <unknown function> + 0x11b1c4 (0x7ffaef4d61c4 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #2: <unknown function> + 0x117677 (0x7ffaef4d2677 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#3: <unknown function> + 0x11a2bf (0x7ffaef4d52bf in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#4: <unknown function> + 0x11a186 (0x7ffaef4d5186 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#5: <unknown function> + 0x119fde (0x7ffaef4d4fde in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#6: <unknown function> + 0x119d2e (0x7ffaef4d4d2e in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#7: <unknown function> + 0x119be0 (0x7ffaef4d4be0 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#8: <unknown function> + 0x119977 (0x7ffaef4d4977 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#9: <unknown function> + 0x119313 (0x7ffaef4d4313 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#10: <unknown function> + 0x118b4c (0x7ffaef4d3b4c in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#11: c10::Error::Error(c10::SourceLocation, std::string) + 0x34 (0x7ffaef4d27c4 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#12: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x7f (0x7ffaef4d04ed in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#13: torch_mlu::MLUCachingAllocator::Native::NativeCachingAllocator::free(void*) + 0xe6 (0x7ff9a8eeb112 in /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/lib/libtorch_mlu.so) frame pytorch#14: torch_mlu::MLUCachingAllocator::Native::local_raw_delete(void*) + 0x3b (0x7ff9a8ed9480 in /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/lib/libtorch_mlu.so) frame pytorch#15: std::unique_ptr<void, void (*)(void*)>::~unique_ptr() + 0x50 (0x7ffb0a5ea322 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame pytorch#16: <unknown function> + 0x1269890 (0x7ffb0a5e4890 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame pytorch#17: <unknown function> + 0x1269928 (0x7ffb0a5e4928 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame pytorch#18: <unknown function> + 0x127572c (0x7ffb0a5f072c in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame pytorch#19: <unknown function> + 0x1275758 (0x7ffb0a5f0758 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame pytorch#20: <unknown function> + 0xb9bc7 (0x7ffaef474bc7 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#21: <unknown function> + 0xb97bc (0x7ffaef4747bc in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#22: <unknown function> + 0xdbc50 (0x7ffaef496c50 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#23: c10::TensorImpl::~TensorImpl() + 0x82 (0x7ffaef49157e in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#24: c10::TensorImpl::~TensorImpl() + 0x1c (0x7ffaef4915aa in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame pytorch#25: <unknown function> + 0x2f596d9 (0x7ffaf24fc6d9 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#26: <unknown function> + 0x2f589c2 (0x7ffaf24fb9c2 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#27: <unknown function> + 0x2f57b92 (0x7ffaf24fab92 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#28: <unknown function> + 0x2f5c228 (0x7ffaf24ff228 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#29: <unknown function> + 0x30f3f70 (0x7ffaf2696f70 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#30: <unknown function> + 0x30f3f90 (0x7ffaf2696f90 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#31: <unknown function> + 0x30f5004 (0x7ffaf2698004 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#32: <unknown function> + 0x30f5024 (0x7ffaf2698024 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#33: <unknown function> + 0x31207f0 (0x7ffaf26c37f0 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#34: <unknown function> + 0x3120814 (0x7ffaf26c3814 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#35: <unknown function> + 0x30f51e8 (0x7ffaf26981e8 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#36: <unknown function> + 0x30f5148 (0x7ffaf2698148 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#37: <unknown function> + 0x316ecea (0x7ffaf2711cea in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame pytorch#38: <unknown function> + 0x468a7 (0x7ffb0c9ed8a7 in /lib/x86_64-linux-gnu/libc.so.6) frame pytorch#39: on_exit + 0 (0x7ffb0c9eda60 in /lib/x86_64-linux-gnu/libc.so.6) <omitting python frames> frame pytorch#47: __libc_start_main + 0xf3 (0x7ffb0c9cb083 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped) ``` Pull Request resolved: pytorch#126677 Approved by: https://github.com/ezyang
1 parent 51c07f9 commit 40cc616

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

aten/src/ATen/autocast_mode.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ namespace {
3535
// directly against incoming TensorImpl*s.
3636
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
3737
using val_type = std::tuple<weakref_type, Tensor>;
38-
ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
38+
39+
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
40+
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
41+
return cached_casts;
42+
}
3943
std::mutex cached_casts_mutex;
4044

4145

@@ -82,7 +86,7 @@ thread_local bool cache_enabled = true;
8286

8387
void clear_cache() {
8488
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
85-
cached_casts.clear();
89+
get_cached_casts().clear();
8690
}
8791

8892
int increment_nesting() {
@@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
124128

125129
if (can_try_cache) {
126130
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
127-
auto it = cached_casts.find(arg.unsafeGetTensorImpl());
128-
if (it != cached_casts.end()) {
131+
auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
132+
if (it != get_cached_casts().end()) {
129133
return std::get<1>(it->second);
130134
} else {
131135
auto casted_arg = arg.to(to_type);
132-
cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
136+
get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
133137
return casted_arg;
134138
}
135139
} else {

0 commit comments

Comments
 (0)