Make HfFileSystem picklable while keeping its per-instance caches.

* hf_file_system.py: initialise `self.dircache` in `__init__`; add
  `HfFileSystem.__reduce__`, which pickles the instance as a call to the new
  module-level `make_instance(cls, args, kwargs, instance_cache_attributes_dict)`
  with `storage_args`, `storage_options` and the two cache attributes
  (`dircache`, `_repo_and_revision_exists_cache`). `make_instance` rebuilds the
  filesystem via `cls(*args, **kwargs)` and then sets each cached attribute
  on it.
* test_hf_file_system.py: add `test_pickle`, which pickles a filesystem, clears
  the instance cache, unpickles under simulated connection failure and checks
  that the instance and its caches are restored.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Leading whitespace of every line was reconstructed from Python
syntax and the hunk line counts; in particular, whether the `assert` lines of
`test_pickle` sit inside or after the `with offline(...)` block could not be
determined from the collapsed text -- confirm against the upstream commit.

diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py
index b341d5ae59..079060c10e 100644
--- a/src/huggingface_hub/hf_file_system.py
+++ b/src/huggingface_hub/hf_file_system.py
@@ -115,6 +115,8 @@ def __init__(
         self._repo_and_revision_exists_cache: Dict[
             Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
         ] = {}
+        # Maps parent directory path to path infos
+        self.dircache: Dict[str, List[Dict[str, Any]]] = {}
 
     def _repo_and_revision_exist(
         self, repo_type: str, repo_id: str, revision: Optional[str]
@@ -927,6 +929,18 @@ def start_transaction(self):
         # See https://github.com/huggingface/huggingface_hub/issues/1733
         raise NotImplementedError("Transactional commits are not supported.")
 
+    def __reduce__(self):
+        # re-populate the instance cache at HfFileSystem._cache and re-populate the cache attributes of every instance
+        return make_instance, (
+            type(self),
+            self.storage_args,
+            self.storage_options,
+            {
+                "dircache": self.dircache,
+                "_repo_and_revision_exists_cache": self._repo_and_revision_exists_cache,
+            },
+        )
+
 
 class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
     def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
@@ -1127,3 +1141,10 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
 
 def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
     return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
+
+
+def make_instance(cls, args, kwargs, instance_cache_attributes_dict):
+    fs = cls(*args, **kwargs)
+    for attr, cached_value in instance_cache_attributes_dict.items():
+        setattr(fs, attr, cached_value)
+    return fs
diff --git a/tests/test_hf_file_system.py b/tests/test_hf_file_system.py
index 5ba382cb95..6acf0c7e61 100644
--- a/tests/test_hf_file_system.py
+++ b/tests/test_hf_file_system.py
@@ -2,6 +2,7 @@
 import datetime
 import io
 import os
+import pickle
 import tempfile
 import unittest
 from pathlib import Path
@@ -20,7 +21,7 @@
 )
 
 from .testing_constants import ENDPOINT_STAGING, TOKEN
-from .testing_utils import repo_name, with_production_testing
+from .testing_utils import OfflineSimulationMode, offline, repo_name, with_production_testing
 
 
 class HfFileSystemTests(unittest.TestCase):
@@ -486,6 +487,20 @@ def test_get_file_on_folder(self):
             self.hffs.get_file(self.hf_path + "/data", temp_dir + "/data")
             assert (Path(temp_dir) / "data").exists()
 
+    def test_pickle(self):
+        # Test that pickling re-populates the HfFileSystem cache and keeps the instance cache attributes
+        fs = HfFileSystem()
+        fs.isfile(self.text_file)
+        pickled = pickle.dumps(fs)
+        HfFileSystem.clear_instance_cache()
+        with offline(mode=OfflineSimulationMode.CONNECTION_FAILS):
+            fs = pickle.loads(pickled)
+        assert isinstance(fs, HfFileSystem)
+        assert fs in HfFileSystem._cache.values()
+        assert self.hf_path + "/data" in fs.dircache
+        assert list(fs._repo_and_revision_exists_cache)[0][1] == self.repo_id
+        assert fs.isfile(self.text_file)
+
 
 @pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"])
 @pytest.mark.parametrize(