Skip to content

Commit 6856cce

Browse files
authored
Remove rank_zero_only on DataModule prepare_data (#7945)
Signed-off-by: Max Ehrlich <[email protected]>
1 parent 96433d0 commit 6856cce

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
209209

210210
### Fixed
211211

212+
- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
213+
212214
- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
213215

214216
- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))

pytorch_lightning/core/datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
2323
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
24-
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
24+
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
2525

2626

2727
class LightningDataModule(CheckpointHooks, DataHooks):
@@ -381,7 +381,7 @@ def test_dataloader():
381381
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
382382
obj = super().__new__(cls)
383383
# track `DataHooks` calls and run `prepare_data` only on rank zero
384-
obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
384+
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
385385
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
386386
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
387387
return obj

tests/core/test_datamodules.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
3535
def test_can_prepare_data(local_rank, node_rank):
3636

37+
model = BoringModel()
3738
dm = BoringDataModule()
3839
trainer = Trainer()
3940
trainer.datamodule = dm
@@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
4344
# local rank = 0 (True)
4445
trainer.prepare_data_per_node = True
4546

47+
dm.random_full = None
48+
dm._has_prepared_data = False
4649
local_rank.return_value = 0
4750
assert trainer.local_rank == 0
4851
assert trainer.data_connector.can_prepare_data()
4952

53+
trainer.data_connector.prepare_data(model)
54+
assert dm.random_full is not None
55+
5056
# local rank = 1 (False)
57+
dm.random_full = None
58+
dm._has_prepared_data = False
5159
local_rank.return_value = 1
5260
assert trainer.local_rank == 1
5361
assert not trainer.data_connector.can_prepare_data()
5462

63+
trainer.data_connector.prepare_data(model)
64+
assert dm.random_full is None
65+
5566
# prepare_data_per_node = False (prepare across all nodes)
5667
# global rank = 0 (True)
68+
dm.random_full = None
69+
dm._has_prepared_data = False
5770
trainer.prepare_data_per_node = False
5871
node_rank.return_value = 0
5972
local_rank.return_value = 0
6073
assert trainer.data_connector.can_prepare_data()
6174

75+
trainer.data_connector.prepare_data(model)
76+
assert dm.random_full is not None
77+
6278
# global rank = 1 (False)
79+
dm.random_full = None
80+
dm._has_prepared_data = False
6381
node_rank.return_value = 1
6482
local_rank.return_value = 0
6583
assert not trainer.data_connector.can_prepare_data()
84+
85+
trainer.data_connector.prepare_data(model)
86+
assert dm.random_full is None
87+
6688
node_rank.return_value = 0
6789
local_rank.return_value = 1
6890
assert not trainer.data_connector.can_prepare_data()
6991

92+
trainer.data_connector.prepare_data(model)
93+
assert dm.random_full is None
94+
7095
# 2 dm
7196
# prepar per node = True
7297
# local rank = 0 (True)

0 commit comments

Comments
 (0)