|
14 | 14 | import os
|
15 | 15 | from unittest import mock
|
16 | 16 |
|
| 17 | +import fsspec.registry |
17 | 18 | import pytest
|
18 | 19 | import torch
|
| 20 | +from fsspec.implementations.arrow import ArrowFSWrapper |
19 | 21 |
|
20 | 22 | from pytorch_lightning import Trainer
|
21 | 23 | from pytorch_lightning.callbacks import ModelCheckpoint
|
@@ -102,6 +104,38 @@ def test_hpc_max_ckpt_version(tmpdir):
|
102 | 104 | )
|
103 | 105 |
|
104 | 106 |
|
def test_max_ckpt_version_for_fsspec(tmpdir):
    """Test that the CheckpointConnector finds the hpc checkpoint file with the highest version when the
    Trainer's ``default_root_dir`` is an fsspec URI using a custom (``mock://``) protocol."""

    class MockFileSystem(ArrowFSWrapper):
        """An fsspec wrapper around the pyarrow filesystem resolved from the ``mock://`` URI."""

        protocol = "mock"

        def __init__(self, **kwargs):
            from pyarrow.fs import FileSystem

            # ``FileSystem.from_uri`` returns a ``(filesystem, path)`` tuple; only the filesystem
            # object must be handed to the wrapper.
            fs, _ = FileSystem.from_uri("mock://")
            super().__init__(fs=fs, **kwargs)

    # ``MockFileSystem`` is a fresh class object on every invocation of this test, so allow it to
    # replace a previous registration instead of failing when the test runs twice in one process.
    fsspec.registry.register_implementation("mock", MockFileSystem, clobber=True)

    model = BoringModel()
    # ``tmpdir`` is a path object, not a ``str``: format it explicitly rather than using ``str + path``.
    trainer = Trainer(default_root_dir=f"mock://{tmpdir}", max_steps=1)
    trainer.fit(model)
    trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

    assert trainer._checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
    assert trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir) == 33
    assert (
        trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir / "not" / "existing")
        is None
    )
| 137 | + |
| 138 | + |
105 | 139 | def test_loops_restore(tmpdir):
|
106 | 140 | """Test that required loop state_dict is loaded correctly by checkpoint connector."""
|
107 | 141 | model = BoringModel()
|
|
0 commit comments