
Commit 2350723

Merge branch 'master' into bugfix/7930_parent_module_w_param
2 parents: 20d5481 + 22d8266

File tree

13 files changed: +319 -117 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ docs/source/api
 docs/source/*.md
 docs/source/generated
 docs/source/*/generated
+docs/source/notebooks

 # Byte-compiled / optimized / DLL files
 __pycache__/

CHANGELOG.md

Lines changed: 10 additions & 2 deletions
@@ -186,6 +186,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))


+- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))
+
+
 ### Removed

 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
@@ -211,8 +214,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

-
-- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
+- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))


 - Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
@@ -230,6 +232,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))


+- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
+
+
+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
+
+
 ## [1.3.5] - 2021-06-08

 ### Added

docs/source/conf.py

Lines changed: 4 additions & 0 deletions
@@ -122,6 +122,10 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
 nbsphinx_allow_errors = True
 nbsphinx_requirejs_path = ''

+# myst-parser, forcing to parse all html pages with mathjax
+# https://github.com/executablebooks/MyST-Parser/issues/394
+myst_update_mathjax = False
+
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #

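For context, a hedged sketch of how this setting sits in a Sphinx conf.py; the `myst_parser` and `nbsphinx` extension entries below are assumptions for illustration, not part of the diff. The point of the change is that `myst_update_mathjax = False` stops myst-parser from replacing the MathJax configuration, so math keeps rendering on pages that MyST does not own.

# docs/source/conf.py (illustrative fragment, not the full file)
extensions = [
    "myst_parser",  # assumed: Markdown pages via MyST
    "nbsphinx",     # assumed: notebook pages, matching the nbsphinx_* options above
]

# myst-parser, forcing to parse all html pages with mathjax
# https://github.com/executablebooks/MyST-Parser/issues/394
myst_update_mathjax = False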
pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 23 additions & 24 deletions
@@ -21,9 +21,14 @@

 import pytorch_lightning
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import (
+    _OMEGACONF_AVAILABLE,
+    DeviceType,
+    rank_zero_deprecation,
+    rank_zero_info,
+    rank_zero_warn,
+)
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
-from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

@@ -45,7 +50,7 @@ def hpc_resume_path(self) -> Optional[str]:
         dir_path_hpc = str(self.trainer.weights_save_path)
         max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_version is not None:
-            return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
+            return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

     def resume_start(self) -> None:
         """
@@ -129,6 +134,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         # hook: give user access to checkpoint if needed.
         model.on_load_checkpoint(checkpoint)

+        # call hpc specific hook
+        if self.hpc_resume_path is not None:
+            model.on_hpc_load(self._loaded_checkpoint)
+
         # restore model state_dict
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

@@ -248,6 +257,7 @@ def restore_lr_schedulers(self) -> None:
     # ----------------------------------
     # PRIVATE OPS
     # ----------------------------------
+
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
@@ -365,29 +375,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

         return checkpoint

-    def hpc_load(self, checkpoint_path: str, on_gpu: bool):
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
-        All restored states are listed in return value description of `dump_checkpoint`.
+    def hpc_load(self, checkpoint_path: str) -> None:
         """
+        Attempts to restore the full training and model state from a HPC checkpoint file.

-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-        # acquire the model
-        model = self.trainer.lightning_module
-
-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
-
-        if self.trainer.root_gpu is not None:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore training state
-        self.restore_training_state(checkpoint)
-
-        # call hpc specific hook
-        model.on_hpc_load(checkpoint)
+        .. deprecated::v1.4
+            Will be removed in v1.6. Use :meth:`restore` instead.
+        """
+        rank_zero_deprecation(
+            "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6."
+            " Use `CheckpointConnector.restore()` instead."
+        )
+        self.restore(checkpoint_path)

     def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.

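A hedged migration sketch mirroring the new deprecation test further down: `hpc_load()` now only emits a deprecation warning and delegates to `restore()`, so call sites can switch directly. The `BoringModel` import assumes a source checkout of the repository, and the scratch directory is a hypothetical choice for the example.

import tempfile

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # repo test helper; assumes a source checkout

ckpt_dir = tempfile.mkdtemp()  # hypothetical scratch directory
model = BoringModel()
trainer = Trainer(default_root_dir=ckpt_dir, max_steps=1)
trainer.fit(model)

trainer.checkpoint_connector.hpc_save(ckpt_dir, trainer.logger)
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_dir)

# deprecated in v1.4, removal planned for v1.6:
# trainer.checkpoint_connector.hpc_load(checkpoint_path)

# preferred replacement; restores model and training state from the same checkpoint file:
trainer.checkpoint_connector.restore(checkpoint_path)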
pytorch_lightning/trainer/predict_loop.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def _get_num_dataloaders(self, dataloaders: List[DataLoader]) -> int:

     def _build_kwargs(self, batch, batch_idx, dataloader_idx):
         step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
-        if self.num_dataloaders:
+        if self.num_dataloaders > 1:
             step_kwargs['dataloader_idx'] = dataloader_idx
         return step_kwargs

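A standalone sketch of the fixed behaviour (the helper name and the toy batch are illustrative, not taken from the loop itself): `dataloader_idx` is only injected into the predict step kwargs when more than one dataloader is used, whereas the old truthiness check also fired for a single dataloader.

from collections import OrderedDict

def build_predict_kwargs(batch, batch_idx, dataloader_idx, num_dataloaders):
    # mirrors the patched _build_kwargs: only pass dataloader_idx for multiple loaders
    step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
    if num_dataloaders > 1:  # old code: `if self.num_dataloaders:` (truthy even for 1)
        step_kwargs['dataloader_idx'] = dataloader_idx
    return step_kwargs

print(build_predict_kwargs(batch=[1.0], batch_idx=0, dataloader_idx=0, num_dataloaders=1))
# -> OrderedDict([('batch', [1.0]), ('batch_idx', 0)])  (no dataloader_idx for one loader)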
pytorch_lightning/utilities/seed.py

Lines changed: 5 additions & 1 deletion
@@ -84,8 +84,9 @@ def reset_seed() -> None:
     If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
     """
     seed = os.environ.get("PL_GLOBAL_SEED", None)
+    workers = os.environ.get("PL_SEED_WORKERS", False)
     if seed is not None:
-        seed_everything(int(seed))
+        seed_everything(int(seed), workers=bool(workers))


 def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None:  # pragma: no cover
@@ -100,6 +101,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None:  # pragma: no cover
     process_seed = torch.initial_seed()
     # back out the base seed so we can use all the bits
     base_seed = process_seed - worker_id
+    log.debug(
+        f'Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}'
+    )
     ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
     # use 128 bits (4 x 32-bit words)
     np.random.seed(ss.generate_state(4))

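A minimal sketch of the round trip this patch completes, assuming a pytorch-lightning release where `seed_everything` accepts the `workers=` argument: seeding exports `PL_GLOBAL_SEED` and `PL_SEED_WORKERS`, and the patched `reset_seed()` now reads both back so spawned DDP processes re-apply the worker seeding flag as well.

import os

from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.seed import reset_seed

seed_everything(42, workers=True)  # exports PL_GLOBAL_SEED and PL_SEED_WORKERS
print(os.environ.get("PL_GLOBAL_SEED"), os.environ.get("PL_SEED_WORKERS"))

# In a spawned process, reset_seed() now restores both the seed and the workers flag,
# so pl_worker_init_function is wired up for dataloader workers there too.
reset_seed()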
tests/deprecated_api/test_remove_1-4.py

Lines changed: 13 additions & 0 deletions
@@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx):

     with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
         trainer.fit(TestModel())
+
+
+def test_v1_4_0_deprecated_hpc_load(tmpdir):
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+    )
+    trainer.fit(model)
+    trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger)
+    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir))
+    with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"):
+        trainer.checkpoint_connector.hpc_load(checkpoint_path)

tests/helpers/boring_model.py

Lines changed: 10 additions & 3 deletions
@@ -161,20 +161,24 @@ def __init__(self, data_dir: str = './'):
         self.checkpoint_state: Optional[str] = None

     def prepare_data(self):
-        self.random_full = RandomDataset(32, 192)
+        self.random_full = RandomDataset(32, 64 * 4)

     def setup(self, stage: Optional[str] = None):
         if stage == "fit" or stage is None:
             self.random_train = Subset(self.random_full, indices=range(64))
             self.dims = self.random_train[0].shape

         if stage in ("fit", "validate") or stage is None:
-            self.random_val = Subset(self.random_full, indices=range(64, 128))
+            self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))

         if stage == "test" or stage is None:
-            self.random_test = Subset(self.random_full, indices=range(128, 192))
+            self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
             self.dims = getattr(self, "dims", self.random_test[0].shape)

+        if stage == "predict" or stage is None:
+            self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
+            self.dims = getattr(self, "dims", self.random_predict[0].shape)
+
     def train_dataloader(self):
         return DataLoader(self.random_train)

@@ -183,3 +187,6 @@ def val_dataloader(self):

     def test_dataloader(self):
         return DataLoader(self.random_test)
+
+    def predict_dataloader(self):
+        return DataLoader(self.random_predict)

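Hedged usage sketch for the new predict split: the imports assume a source checkout of the repository, and the example relies on `Trainer.predict` with the default `predict_step` (which forwards the batch through the model). With the diff above, `trainer.predict` can be driven directly by `BoringDataModule.predict_dataloader()`.

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, BoringDataModule  # repo test helpers

model = BoringModel()
dm = BoringDataModule()

trainer = Trainer(fast_dev_run=True)
# runs predict over the new 64-sample predict subset (indices 64 * 3 .. 64 * 4)
predictions = trainer.predict(model, datamodule=dm)
print(len(predictions))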
tests/helpers/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def run_model_test(
     trainer.checkpoint_connector.hpc_save(save_dir, logger)
     # test HPC loading
     checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
-    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
+    trainer.checkpoint_connector.restore(checkpoint_path)


 @torch.no_grad()

tests/models/data/horovod/train_default_model.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None:
     trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
     # test HPC loading
     checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
-    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
+    trainer.checkpoint_connector.restore(checkpoint_path)

     if on_gpu:
         trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
