Skip to content

Commit ee8a57d

Browse files
leoleoasd, pre-commit-ci[bot], Borda, awaelchli, otaj
authored
Fix usage of fs.listdir in CheckpointConnector (#15413)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: otaj <[email protected]>
1 parent 62d040c commit ee8a57d

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5050

5151
- Fixed an attribute error in `ColossalAIStrategy` at import time when `torch.distributed` is not available ([#15535](https://github.com/Lightning-AI/lightning/pull/15535))
5252

53+
- Fixed an issue when calling `fs.listdir` with file URI instead of path in `CheckpointConnector` ([#15413](https://github.com/Lightning-AI/lightning/pull/15413))
54+
5355
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
5456

5557

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from typing import Any, Dict, Optional
2020

2121
import torch
22+
from fsspec.core import url_to_fs
23+
from fsspec.implementations.local import LocalFileSystem
2224
from torch import Tensor
2325
from torchmetrics import Metric
2426

@@ -59,13 +61,16 @@ def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH
5961
@property
6062
def _hpc_resume_path(self) -> Optional[str]:
6163
dir_path_hpc = self.trainer.default_root_dir
62-
fs = get_filesystem(dir_path_hpc)
63-
if not fs.isdir(dir_path_hpc):
64-
return None
6564
dir_path_hpc = str(dir_path_hpc)
65+
fs, path = url_to_fs(dir_path_hpc)
66+
if not fs.isdir(path):
67+
return None
6668
max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
6769
if max_version is not None:
68-
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
70+
if isinstance(fs, LocalFileSystem):
71+
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
72+
else:
73+
return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
6974

7075
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
7176
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@@ -565,12 +570,12 @@ def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Op
565570
"""
566571

567572
# check directory existence
568-
fs = get_filesystem(dir_path)
573+
fs, uri = url_to_fs(str(dir_path))
569574
if not fs.exists(dir_path):
570575
return None
571576

572577
# check corresponding file existence
573-
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
578+
files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
574579
files = [x for x in files if name_key in x]
575580
if len(files) == 0:
576581
return None

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,30 @@ def test_hpc_max_ckpt_version(tmpdir):
102102
)
103103

104104

105+
def test_ckpt_for_fsspec():
106+
"""Test that the CheckpointConnector is able to write to fsspec file systems."""
107+
108+
model = BoringModel()
109+
# hardcoding dir since `tmpdir` can be windows path
110+
trainer = Trainer(
111+
default_root_dir="memory://test_ckpt_for_fsspec", limit_train_batches=1, limit_val_batches=1, max_epochs=1
112+
)
113+
trainer.fit(model)
114+
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt.ckpt")
115+
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_0.ckpt")
116+
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_3.ckpt")
117+
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt")
118+
119+
assert trainer._checkpoint_connector._hpc_resume_path == "memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt"
120+
assert (
121+
trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://test_ckpt_for_fsspec")
122+
== 33
123+
)
124+
assert (
125+
trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://not_existing") is None
126+
)
127+
128+
105129
def test_loops_restore(tmpdir):
106130
"""Test that required loop state_dict is loaded correctly by checkpoint connector."""
107131
model = BoringModel()

0 commit comments

Comments (0)