
Commit 4165870

awaelchli and carmocca authored
Remove deadlock detection / process reconciliation logic (#16204)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent d2f3035 commit 4165870

File tree

9 files changed (+5, -145 lines)


src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the deprecated `pytorch_lightning.profiler` module ([#16359](https://github.com/Lightning-AI/lightning/pull/16359))
 
+- Removed deadlock detection / process reconciliation (`PL_RECONCILE_PROCESS=1`) ([#16204](https://github.com/Lightning-AI/lightning/pull/16204))
+
+
 - Removed the deprecated `LightningCLI` arguments ([#16380](https://github.com/Lightning-AI/lightning/pull/16380))
   * save_config_filename
   * save_config_overwrite
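For context on what is being deleted here: reconciliation was opt-in via the `PL_RECONCILE_PROCESS` environment variable, or implied when rank 0 launched the other ranks itself. A minimal sketch of how the feature would have been enabled before this commit; the `Trainer` arguments are illustrative, not taken from the diff:

import os

# Assumption: this variable had to be set before fit() so that
# `_should_run_deadlock_detection()` (see the ddp.py diff below) returned True.
os.environ["PL_RECONCILE_PROCESS"] = "1"

from pytorch_lightning import Trainer

trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp")  # illustrative arguments

After this commit the variable is ignored: if one rank dies, its siblings block on the next collective, and termination is left to the user or the cluster scheduler, as the removed docstring already recommended for externally launched processes.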

src/pytorch_lightning/strategies/bagua.py

Lines changed: 0 additions & 4 deletions
@@ -178,10 +178,6 @@ def _set_node_environment_variables(self) -> None:
         os.environ["LOCAL_RANK"] = str(self.local_rank)
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
-        if self._should_run_deadlock_detection():
-            self._share_information_to_prevent_deadlock()
-
         assert self.accelerator is not None
         self.accelerator.setup(trainer)

src/pytorch_lightning/strategies/ddp.py

Lines changed: 1 addition & 84 deletions
@@ -12,13 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
-import shutil
-import signal
-import tempfile
-import time
 from datetime import timedelta
-from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -52,8 +46,7 @@
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
 if _FAIRSCALE_AVAILABLE:
@@ -101,9 +94,6 @@ def __init__(
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self._model_averaging_period = model_averaging_period
         self._model_averager: Optional[ModelAverager] = None
-        self._pids: List[int] = []
-        self._sync_dir: Optional[str] = None
-        self._rank_0_will_call_children_scripts: bool = False
         self._process_group_backend: Optional[str] = process_group_backend
         self._timeout: Optional[timedelta] = timeout
 
@@ -145,18 +135,12 @@ def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None
         if not self.cluster_environment.creates_processes_externally:
             self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
-            self._rank_0_will_call_children_scripts = True
 
     def setup_environment(self) -> None:
         self.setup_distributed()
         super().setup_environment()
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts))
-        if self._should_run_deadlock_detection():
-            self._share_information_to_prevent_deadlock()
-
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
@@ -391,73 +375,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             description=f"{cls.__class__.__name__}",
         )
 
-    def _should_run_deadlock_detection(self) -> bool:
-        """Determines whether the plugin will perform process reconciliation in case of errors.
-
-        If the environment variable `PL_RECONCILE_PROCESS` is set, run detection regardless of the cluster environment.
-        By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler /
-        parent process to perform the process termination, external to Lightning.
-        """
-        return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts
-
-    def _share_information_to_prevent_deadlock(self) -> None:
-        self._share_pids()
-
-        # there should be a unique sync_dir per node.
-        if self.local_rank == 0:
-            # create a temporary directory used to synchronize processes on deadlock.
-            self._sync_dir = tempfile.mkdtemp()
-
-        sync_dirs = []
-        global_node_rank_zero = 0
-        for _ in range(self.num_nodes):
-            sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
-            global_node_rank_zero += self.world_size // self.num_nodes
-
-        self._sync_dir = sync_dirs[self.node_rank]
-
-    def _share_pids(self) -> None:
-        """Make all DDP processes aware of all processes' pids."""
-        self.barrier()
-        pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
-        pids = pids.cpu().numpy().tolist()
-        self._pids = pids if isinstance(pids, list) else [pids]
-
-    def reconciliate_processes(self, trace: str) -> None:
-        if self.world_size < 2:
-            return
-
-        if not self._should_run_deadlock_detection():
-            return
-
-        sync_dir = self._sync_dir
-
-        if not sync_dir:
-            rank_zero_warn("Error handling mechanism for deadlock detection is uninitialized. Skipping check.")
-            return
-
-        # The cluster may be configured to periodically purge the `/tmp`
-        # directory, in which case `sync_dir` may not exist anymore at this
-        # point. Idempotently create it to ensure its existence.
-        Path(sync_dir).mkdir(parents=True, exist_ok=True)
-
-        # save a file locally.
-        torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))
-
-        # sleep for a short time
-        time.sleep(3)
-
-        # return if all processes wrote a file in the `sync_dir`.
-        # todo (tchaton) Add support for non-shared file-system which will fail.
-        if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
-            return
-
-        for pid in self._pids:
-            if pid != os.getpid():
-                os.kill(pid, signal.SIGKILL)
-        shutil.rmtree(sync_dir)
-        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
-
     def teardown(self) -> None:
         log.detail(f"{self.__class__.__name__}: tearing down strategy")
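The deleted `reconciliate_processes` shows the whole mechanism: each rank that survives an exception drops a marker file into a shared per-node temp directory, waits a grace period, and SIGKILLs its recorded sibling PIDs if any marker is missing. A condensed, self-contained sketch of that pattern; the function name and the `procs_per_node` parameter are mine, not Lightning API:

import os
import signal
import time
from pathlib import Path

def reconcile(sync_dir: str, global_rank: int, procs_per_node: int, pids: list[int]) -> None:
    # Recreate the sync dir in case the cluster purged /tmp (mirrors the removed code).
    Path(sync_dir).mkdir(parents=True, exist_ok=True)
    # Mark this rank as alive after it caught the exception.
    (Path(sync_dir) / f"{global_rank}.pl").touch()
    time.sleep(3)  # grace period for sibling ranks to write their markers
    if len(os.listdir(sync_dir)) == procs_per_node:
        return  # every local rank reported in: no deadlock, let the exception propagate
    # Some rank is stuck on a collective and will never report: kill the survivors.
    for pid in pids:
        if pid != os.getpid():
            os.kill(pid, signal.SIGKILL)

The removed implementation wrote the marker with `torch.save` rather than `touch`, and raised `DeadlockDetectedException` after killing; its own to-do note about non-shared filesystems points at one reason the approach stayed opt-in until its removal here.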

src/pytorch_lightning/strategies/fully_sharded_native.py

Lines changed: 0 additions & 4 deletions
@@ -145,7 +145,6 @@ def __init__(
         self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision
-        self._rank_0_will_call_children_scripts: bool = False
         if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
             raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
         activation_checkpointing = activation_checkpointing or []
@@ -215,7 +214,6 @@ def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None
         if not self.cluster_environment.creates_processes_externally:
             self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
-            self._rank_0_will_call_children_scripts = True
 
     def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
         """Wraps the model into a
@@ -248,8 +246,6 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
     def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
-        # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
 
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             assert self.model is not None

src/pytorch_lightning/strategies/parallel.py

Lines changed: 0 additions & 3 deletions
@@ -82,9 +82,6 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
             rank=self.global_rank,
         )
 
-    def reconciliate_processes(self, trace: str) -> None:
-        """Function to re-conciliate processes on failure."""
-
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """Perform an all_gather on all processes."""
         return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
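This hunk deletes the empty `reconciliate_processes` hook from the `ParallelStrategy` base class. Within this commit it had one real implementation (`DDPStrategy`, above) and one caller (`trainer/call.py`, below), both removed as well, so no dangling override or call site remains.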

src/pytorch_lightning/strategies/sharded.py

Lines changed: 0 additions & 5 deletions
@@ -60,11 +60,6 @@ def connect(self, model: "pl.LightningModule") -> None:
         return super().connect(model)
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
-        if self._should_run_deadlock_detection():
-            self._share_information_to_prevent_deadlock()
-
         assert self.accelerator is not None
         self.accelerator.setup(trainer)

src/pytorch_lightning/trainer/call.py

Lines changed: 0 additions & 5 deletions
@@ -11,11 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import traceback
 from typing import Any, Callable
 
 import pytorch_lightning as pl
-from lightning_fabric.utilities.distributed import _distributed_available
 from pytorch_lightning.trainer.states import TrainerStatus
 from pytorch_lightning.utilities.exceptions import _TunerExitException
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -54,9 +52,6 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
             logger.finalize("failed")
     except BaseException as exception:
         trainer.state.status = TrainerStatus.INTERRUPTED
-        if _distributed_available() and trainer.world_size > 1:
-            # try syncing remaining processes, kill otherwise
-            trainer.strategy.reconciliate_processes(traceback.format_exc())
         trainer._call_callback_hooks("on_exception", exception)
         for logger in trainer.loggers:
             logger.finalize("failed")
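With the reconciliation branch gone, the failure path in `_call_and_handle_interrupt` is the same for single- and multi-process runs. A simplified sketch of the remaining flow, assuming the parts outside this hunk (keyboard-interrupt handling, teardown, the final re-raise) are unchanged:

from pytorch_lightning.trainer.states import TrainerStatus

def _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs):
    # Sketch only; KeyboardInterrupt and _TunerExitException handling omitted.
    try:
        return trainer_fn(*args, **kwargs)
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        # No distributed special case anymore: just notify callbacks and loggers.
        trainer._call_callback_hooks("on_exception", exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
        raise  # assumption: the surrounding code re-raises as before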

src/pytorch_lightning/utilities/exceptions.py

Lines changed: 0 additions & 4 deletions
@@ -15,10 +15,6 @@
 from lightning_fabric.utilities.exceptions import MisconfigurationException  # noqa: F401
 
 
-class DeadlockDetectedException(Exception):
-    """Exception used when a deadlock has been detected and processes are being killed."""
-
-
 class ExitGracefullyException(SystemExit):
     """Exception used when a ``signal.SIGTERM`` is sent to the process.
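`DeadlockDetectedException` was importable from `pytorch_lightning.utilities.exceptions`, so downstream code that imports or catches it will break on this version. A hedged compatibility shim for code that must span releases before and after this commit; this is my suggestion, not part of the diff:

try:
    from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
except ImportError:
    # Removed in #16204; define a placeholder so old except-clauses still parse.
    class DeadlockDetectedException(Exception):
        """Never raised on versions that include this commit."""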

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 1 addition & 36 deletions
@@ -61,7 +61,7 @@
     SingleDeviceStrategy,
 )
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
-from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
 from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
 from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -1803,41 +1803,6 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
     trainer.predict()
 
 
-class CustomException(Exception):
-    pass
-
-
-@RunIf(min_cuda_gpus=2, standalone=True)
-def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
-    """Test that DDP kills the remaining processes when only one rank throws an exception."""
-
-    class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            if batch_idx == 1 and self.trainer.is_global_zero:
-                # rank 0: raises an exception
-                # rank 1: continues training but will hang on the next barrier in the training loop
-                raise CustomException
-            return super().training_step(batch, batch_idx)
-
-    model = TestModel()
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=5,
-        num_sanity_val_steps=0,
-        accelerator="gpu",
-        devices=2,
-        strategy="ddp",
-        enable_progress_bar=False,
-        enable_model_summary=False,
-    )
-
-    # simulate random failure in training_step on rank 0
-    with pytest.raises(DeadlockDetectedException, match="CustomException"):
-        trainer.fit(model)
-
-
 @RunIf(min_cuda_gpus=1)
 def test_multiple_trainer_constant_memory_allocated(tmpdir):
     """This test ensures calling the trainer several times resets the memory back to 0."""
