
Commit a5f36a8

zpcore authored and pytorchmergebot committed
[DTensor] Fix deadlock after fast cache clear (pytorch#168069)
This is the necessary fix for meta-pytorch/autoparallel#256.

### Issue:

When we call `_clear_fast_path_sharding_prop_cache()` and then `get_thread_local_native_sharding_propagator_cache()`, the code gets stuck in a deadlock.

### Cause:

Assigning to a Python dict key that already exists destroys the old value before the new one is stored:

```C++
thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = old_capsule; // capsule #1 stored
...
clear_DTensor_sharding_propagator_cache(); // call to clean up the cache
...
get_thread_local_native_sharding_propagator_cache() {
  std::lock_guard<std::mutex> lock(
      native_sharding_propagator_cache_cleanup_mutex); // FIRST claims the lock!
  if (!native_sharding_propagator_cache_DO_NOT_USE.has_value()) {
    // We enter this branch again because the cache was cleared.
    ...
    // The assignment below destroys old_capsule FIRST, then stores new_capsule.
    // Destroying old_capsule runs its destructor, which tries to claim
    // `native_sharding_propagator_cache_cleanup_mutex` again!
    thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = new_capsule; // SECOND claims the lock before FIRST releases
  }
}
```

Pull Request resolved: pytorch#168069
Approved by: https://github.com/ezyang
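To make the failure mode concrete, here is a minimal standalone C++ sketch of the same pattern, independent of pybind11 and the DTensor sources; all names in it (`g_mutex`, `CleanupToken`, `g_registry`, `register_cleanup`) are illustrative, not from PyTorch. Overwriting a registry slot destroys the old value first, and if that value's cleanup needs the same non-recursive mutex the caller already holds, the second registration hangs:

```C++
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

std::mutex g_mutex;

// Stand-in for the py::capsule: runs a cleanup callback when it is destroyed.
struct CleanupToken {
  std::function<void()> on_destroy;
  ~CleanupToken() {
    if (on_destroy) {
      on_destroy();
    }
  }
};

// Stand-in for the per-thread Python dict entry holding the capsule.
std::map<std::string, std::unique_ptr<CleanupToken>> g_registry;

void register_cleanup() {
  std::lock_guard<std::mutex> lock(g_mutex);  // FIRST claims the lock
  auto token = std::make_unique<CleanupToken>();
  // The cleanup itself needs the same mutex, like the capsule destructor
  // that erases the thread's entry from the global cache map.
  token->on_destroy = [] { std::lock_guard<std::mutex> inner(g_mutex); };
  // Overwriting an existing entry destroys the old token *while the lock is
  // still held*; its destructor then blocks trying to take g_mutex (SECOND).
  g_registry["cleanup"] = std::move(token);
}

int main() {
  register_cleanup();  // fine: there is no previous token to destroy
  register_cleanup();  // hangs: the old token's destructor re-locks g_mutex
  std::cout << "never reached\n";
  return 0;
}
```

(Strictly speaking, re-locking a `std::mutex` that the same thread already holds is undefined behavior; on common implementations it simply blocks forever, which matches the observed deadlock.)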
1 parent 1c0bf2a commit a5f36a8

File tree

2 files changed, +49 -20 lines changed


test/distributed/tensor/test_op_strategy.py

Lines changed: 29 additions & 2 deletions
```diff
@@ -34,7 +34,11 @@
     register_op_strategy,
     replicate_op_strategy,
 )
-from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.debug import (
+    _clear_fast_path_sharding_prop_cache,
+    _clear_python_sharding_prop_cache,
+    CommDebugMode,
+)
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     create_local_tensor_test_class,
@@ -479,7 +483,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
             del propagator.op_to_schema_info[op_overload]
         else:
             propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
-        propagator.propagate_op_sharding.cache.cache_clear()
+        _clear_fast_path_sharding_prop_cache()
+        _clear_python_sharding_prop_cache()
 
 
 def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool:
@@ -645,6 +650,28 @@ def test_call_with_different_nontensor_args(self):
         self.assertEqual(out1.full_tensor(), out2.full_tensor())
 
 
+class TestStrategyOperation(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 2
+
+    @with_comms
+    def test_cache_clean(self):
+        mesh = self.build_device_mesh()
+        test_op = torch.ops.mylib.numpy_sin
+        x = torch.randn(2, device=self.device_type)
+        y = torch.randn(2, device=self.device_type)
+        x_dt = distribute_tensor(x, mesh, [Shard(0)])
+        y_dt = distribute_tensor(y, mesh, [Shard(0)])
+        with op_strategy_context(test_op.default, replicate_op_strategy):
+            self._test_op_on_dtensor(test_op, x_dt, y_dt)
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            f"Operator {test_op.default} does not have a sharding strategy registered",
+        ):
+            self._test_op_on_dtensor(test_op, x_dt, y_dt)
+
+
 DistTensorReplicateStrategyRegistrationTestWithLocalTensor = (
     create_local_tensor_test_class(
         DistTensorReplicateStrategyRegistrationTest,
```

torch/csrc/autograd/python_variable.cpp

Lines changed: 20 additions & 18 deletions
```diff
@@ -1200,25 +1200,27 @@ get_thread_local_native_sharding_propagator_cache() {
         py::reinterpret_borrow<py::dict>(PyThreadState_GetDict());
     // We need to clean up before Python detaches from the thread if
     // the thread is being destroyed.
-    thread_dict["__DTensor_fastpath_thread_cache_cleanup"] =
-        py::capsule(new std::thread::id(this_thread_id), [](void* p) {
-          auto* ptid = reinterpret_cast<std::thread::id*>(p);
-          {
-            std::lock_guard<std::mutex> inner_lock(
-                native_sharding_propagator_cache_cleanup_mutex);
-            auto it = all_thread_caches.find(*ptid);
-            if (it != all_thread_caches.end()) {
-              // We need to both:
-              // 1) free python objects, and
-              it->second->reset();
-              // 2) make sure we don't try to come back and mess with
-              //    a destroyed thread-local at module unload (e.g.,
-              //    process exit) time.
-              all_thread_caches.erase(it);
+    if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) {
+      thread_dict["__DTensor_fastpath_thread_cache_cleanup"] =
+          py::capsule(new std::thread::id(this_thread_id), [](void* p) {
+            auto* ptid = reinterpret_cast<std::thread::id*>(p);
+            {
+              std::lock_guard<std::mutex> inner_lock(
+                  native_sharding_propagator_cache_cleanup_mutex);
+              auto it = all_thread_caches.find(*ptid);
+              if (it != all_thread_caches.end()) {
+                // We need to both:
+                // 1) free python objects, and
+                it->second->reset();
+                // 2) make sure we don't try to come back and mess with
+                //    a destroyed thread-local at module unload (e.g.,
+                //    process exit) time.
+                all_thread_caches.erase(it);
+              }
             }
-          }
-          delete ptid;
-        });
+            delete ptid;
+          });
+    }
   }
   return native_sharding_propagator_cache_DO_NOT_USE.value();
 }
```
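For comparison, the same guard can be applied to the illustrative sketch shown after the commit message. This is a hedged sketch only, reusing `g_mutex`, `CleanupToken`, and `g_registry` from that example, with `register_cleanup_fixed` as a hypothetical name; it mirrors the `thread_dict.contains(...)` check added in the diff above, so no existing token is ever destroyed while the lock is held.

```C++
// Builds on the earlier illustrative sketch (g_mutex, CleanupToken,
// g_registry are defined there). Store the cleanup token at most once,
// mirroring the `thread_dict.contains(...)` guard in the fix.
void register_cleanup_fixed() {
  std::lock_guard<std::mutex> lock(g_mutex);
  if (g_registry.find("cleanup") == g_registry.end()) {
    auto token = std::make_unique<CleanupToken>();
    token->on_destroy = [] { std::lock_guard<std::mutex> inner(g_mutex); };
    // First (and only) store for this key: nothing is destroyed here, so the
    // destructor's lock acquisition can never run under the outer lock.
    g_registry["cleanup"] = std::move(token);
  }
}
```

Repeated calls are now idempotent: the token's destructor only runs when the entry is later erased, not during re-registration under the lock.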
