Skip to content

Commit 482260d

Browse files
committed
Add tests
1 parent f7aa0c5 commit 482260d

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from typing import Any, Dict, Optional
2020

2121
import torch
22+
from fsspec.core import url_to_fs
2223
from torch import Tensor
2324
from torchmetrics import Metric
24-
from fsspec.core import uri_to_fs
2525

2626
import pytorch_lightning as pl
2727
from lightning_lite.plugins.environments.slurm import SLURMEnvironment
@@ -575,7 +575,7 @@ def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Op
575575
"""
576576

577577
# check directory existence
578-
fs, uri = uri_to_fs(dir_path)
578+
fs, uri = url_to_fs(dir_path)
579579
if not fs.exists(dir_path):
580580
return None
581581

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
import os
1515
from unittest import mock
1616

17+
import fsspec.registry
1718
import pytest
1819
import torch
20+
from fsspec.implementations.arrow import ArrowFSWrapper
1921

2022
from pytorch_lightning import Trainer
2123
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -102,6 +104,38 @@ def test_hpc_max_ckpt_version(tmpdir):
102104
)
103105

104106

107+
def test_max_ckpt_version_for_fsspec(tmpdir):
108+
"""Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version."""
109+
110+
class MockFileSystem(ArrowFSWrapper):
111+
"""A wrapper on top of the pyarrow.fs.HadoopFileSystem to connect it's interface with fsspec."""
112+
113+
protocol = "mock"
114+
115+
def __init__(self, **kwargs):
116+
from pyarrow.fs import FileSystem
117+
118+
fs = FileSystem.from_uri("mock://")
119+
super().__init__(fs=fs, **kwargs)
120+
121+
fsspec.registry.register_implementation("mock", MockFileSystem)
122+
123+
model = BoringModel()
124+
trainer = Trainer(default_root_dir="mock://" + tmpdir, max_steps=1)
125+
trainer.fit(model)
126+
trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
127+
trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
128+
trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
129+
trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")
130+
131+
assert trainer._checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
132+
assert trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir) == 33
133+
assert (
134+
trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir / "not" / "existing")
135+
is None
136+
)
137+
138+
105139
def test_loops_restore(tmpdir):
106140
"""Test that required loop state_dict is loaded correctly by checkpoint connector."""
107141
model = BoringModel()

0 commit comments

Comments
 (0)