Commit 3ab730e

Authored by ananthsub and Borda

Swap torch.load for fsspec load in ddp spawn backend (#3787)

* Update ddp_spawn_backend.py
* Update ddp_cpu_spawn_backend.py
* log

Co-authored-by: Jirka Borovec <[email protected]>

1 parent 192fc01 commit 3ab730e

File tree

3 files changed: +6 −4 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - `row_log_interval` and `log_save_interval` are now based on training loop's `global_step` instead of epoch-internal batch index ([#3667](https://github.com/PyTorchLightning/pytorch-lightning/pull/3667))

+- Swap `torch.load` for `fsspec` load in DDP spawn backend ([#3787](https://github.com/PyTorchLightning/pytorch-lightning/pull/3787))
+
 ### Deprecated

```

pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -22,7 +22,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.base_backend import Accelerator
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.distributed import find_free_network_port
 from pytorch_lightning.distributed.dist import LightningDistributed
@@ -195,7 +195,7 @@ def __recover_child_process_weights(self, model, best_path, last_path):

         # load last weights
         if last_path is not None and not self.trainer.testing:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
             model.load_state_dict(ckpt)

         self.trainer.model = model
```

pytorch_lightning/accelerators/ddp_spawn_backend.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -22,7 +22,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.base_backend import Accelerator
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.distributed.dist import LightningDistributed
@@ -210,7 +210,7 @@ def __recover_child_process_weights(self, model, best_path, last_path):

         # load last weights
         if last_path is not None and not self.trainer.testing:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
             model.load_state_dict(ckpt)

         self.trainer.model = model
```
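The point of the swap is that a bare `torch.load` only understands local filesystem paths, while an fsspec-backed loader also resolves remote URLs (e.g. `s3://`, `gs://`), so checkpoints written by `atomic_save` to cloud storage can be recovered in the spawned parent process. The following is a minimal sketch of what such a loader looks like; the function name `fsspec_load` is illustrative and is not the actual implementation of `pytorch_lightning.utilities.cloud_io.load`.

```python
import io

import fsspec
import torch


def fsspec_load(path_or_url, map_location=None):
    """Load a torch checkpoint from a local path or a remote URL.

    fsspec picks the right filesystem backend from the URL scheme
    (plain paths fall back to the local filesystem), which is why an
    fsspec-based loader is a drop-in replacement for torch.load here.
    """
    with fsspec.open(path_or_url, "rb") as f:
        # Read into memory first, then hand a seekable buffer to torch.load,
        # since some fsspec file objects are not seekable.
        return torch.load(io.BytesIO(f.read()), map_location=map_location)
```

With a signature compatible with `torch.load(path, map_location=...)`, call sites like `__recover_child_process_weights` only need the import swapped, exactly as the diff above does.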
