Skip to content

Commit fafd254

Browse files
awaelchli, justusschock, and Borda
authored
Fix device parser logic to avoid creating CUDA context (#14319)
* let environment disable forking
* add helper function and error messages
* tests
* changelog

Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 0102d0d commit fafd254

File tree

6 files changed

+39
-2
lines changed

6 files changed

+39
-2
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121
- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))
2222

2323

24+
- Added an environment variable `PL_DISABLE_FORK` that can be used to disable all forking in the Trainer ([#14319](https://github.com/Lightning-AI/lightning/issues/14319))
25+
26+
2427

2528
### Changed
2629

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def __init__(self, strategy: Strategy, start_method: Literal["spawn", "fork", "f
6666
f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
6767
f" {', '.join(mp.get_all_start_methods())}"
6868
)
69+
if start_method in ("fork", "forkserver") and _is_forking_disabled():
70+
raise ValueError(
71+
"Forking is disabled in this environment by `PL_DISABLE_FORK=1`. Choose a different start method."
72+
)
6973

7074
@property
7175
def is_interactive_compatible(self) -> bool:
@@ -281,3 +285,8 @@ def restore(self) -> None:
281285
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
282286
torch.backends.cudnn.benchmark = self.cudnn_benchmark
283287
_set_rng_states(self.rng_states)
288+
289+
290+
def _is_forking_disabled() -> bool:
291+
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
292+
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
TPUSpawnStrategy,
7474
)
7575
from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
76+
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
7677
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
7778
from pytorch_lightning.utilities import (
7879
_StrategyType,
@@ -642,6 +643,10 @@ def _check_strategy_and_fallback(self) -> None:
642643
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
643644
f" platform. We recommend `Trainer(strategy='ddp_spawn')` instead."
644645
)
646+
if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled():
647+
raise ValueError(
648+
"Forking is disabled in this environment by `PL_DISABLE_FORK=1`. Choose a different strategy."
649+
)
645650
if strategy_flag:
646651
self._strategy_flag = strategy_flag
647652

src/pytorch_lightning/utilities/device_parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch.cuda
1919

2020
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
21+
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
2122
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2223
from pytorch_lightning.utilities.types import _DEVICE
2324

@@ -323,7 +324,7 @@ def num_cuda_devices() -> int:
323324
Unlike :func:`torch.cuda.device_count`, this function will do its best not to create a CUDA context for fork
324325
support, if the platform allows it.
325326
"""
326-
if "fork" not in torch.multiprocessing.get_all_start_methods():
327+
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
327328
return torch.cuda.device_count()
328329
with multiprocessing.get_context("fork").Pool(1) as pool:
329330
return pool.apply(torch.cuda.device_count)
@@ -335,7 +336,7 @@ def is_cuda_available() -> bool:
335336
Unlike :func:`torch.cuda.is_available`, this function will do its best not to create a CUDA context for fork
336337
support, if the platform allows it.
337338
"""
338-
if "fork" not in torch.multiprocessing.get_all_start_methods():
339+
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
339340
return torch.cuda.is_available()
340341
with multiprocessing.get_context("fork").Pool(1) as pool:
341342
return pool.apply(torch.cuda.is_available)

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
1415
from unittest import mock
1516
from unittest.mock import ANY, Mock
1617

1718
import pytest
1819
import torch
1920

2021
from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
22+
from tests_pytorch.helpers.runif import RunIf
2123

2224

2325
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -26,6 +28,14 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
2628
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")
2729

2830

31+
@RunIf(skip_windows=True)
32+
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
33+
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
34+
def test_multiprocessing_launcher_disabled_forking(start_method):
35+
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
36+
_MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
37+
38+
2939
@pytest.mark.parametrize("start_method", ["spawn", "fork"])
3040
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
3141
def test_multiprocessing_launcher_start_method(mp_mock, start_method):

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,12 @@ def test_accelerator_specific_checkpoint_io(*_):
810810
def test_ddp_fork_on_unsupported_platform(_, strategy):
811811
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
812812
Trainer(strategy=strategy)
813+
814+
815+
@RunIf(skip_windows=True)
816+
@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES)
817+
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
818+
def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy):
819+
"""Test there is an error when forking is disabled via the environment variable and the user requests fork."""
820+
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
821+
Trainer(devices=2, strategy=strategy)

0 commit comments

Comments
 (0)