from PIL import Image

import io
import os
import os.path
import zipfile
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union


IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: Union[str, io.BytesIO]) -> Image.Image:
    """Load an image with PIL and convert it to RGB.

    Args:
        path: Either a filesystem path, which is opened in binary mode, or an
            already-open in-memory buffer holding the encoded image.

    Returns:
        The decoded image converted to mode ``'RGB'``.

    Note:
        The underlying file object is always closed before returning, including
        a buffer passed in by the caller.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    f = open(path, 'rb') if isinstance(path, str) else path
    # ``with`` guarantees the handle is released even when decoding fails.
    with f:
        img = Image.open(f)
        return img.convert('RGB')


class ZipFolder(DatasetFolder):
    """Image dataset whose samples are read from a single ZIP archive.

    Classes are derived from member paths: a member contributes a class when
    its parent directory itself has a parent inside the archive, i.e. the
    layout is ``<top>/<class>/<file>``. Samples are ``(member_name, class_index)``
    tuples and are decoded through :meth:`zip_loader`.

    Args:
        root: Path to the archive; must end with ``.zip``.
        transform: Optional callable applied to the loaded sample.
        target_transform: Optional callable applied to the target.
        is_valid_file: Optional predicate on the member name. When omitted,
            members are filtered by ``IMG_EXTENSIONS``.
        memory: If ``True`` the whole archive is read into memory up front and
            served from there; otherwise the archive is opened from disk.

    Raises:
        TypeError: If ``root`` does not end with ``.zip``.
    """

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 is_valid_file: Optional[Callable[[str], bool]] = None, memory: bool = True) -> None:
        if not root.endswith('.zip'):
            # Single formatted message; passing ``root`` as a second exception
            # argument would render the error as a tuple.
            raise TypeError(f"Need ZIP file for data source: {root}")
        # NOTE(review): the open ZipFile handle is stored on the instance; whether
        # that is safe with multi-process data loading should be confirmed.
        if memory:
            with open(root, 'rb') as z:
                data = z.read()
            self.root_zip = zipfile.ZipFile(io.BytesIO(data), 'r')
        else:
            self.root_zip = zipfile.ZipFile(root, 'r')
        # root_zip must exist before the base initialiser runs, because
        # _find_classes and make_dataset below read from it.
        try:
            super().__init__(root, self.zip_loader, IMG_EXTENSIONS if is_valid_file is None else None,
                             transform=transform, target_transform=target_transform,
                             is_valid_file=is_valid_file)
        except Exception:
            # Do not leak the archive handle when dataset construction fails.
            self.root_zip.close()
            raise
        self.imgs = self.samples

    @staticmethod
    def initialize_from_folder(root: str, zip_path: Optional[str] = None) -> str:
        """Pack the directory tree under ``root`` into an uncompressed ZIP archive.

        Member names are relative to the parent of ``root``, so the folder's own
        name is the first path component inside the archive.

        Args:
            root: Directory to pack.
            zip_path: Destination archive. Defaults to ``<root>_store.zip`` next
                to ``root``.

        Returns:
            The path of the archive that was written.
        """
        root = os.path.normpath(root)
        folder_dir, folder_base = os.path.split(root)
        if zip_path is None:
            zip_path = os.path.join(folder_dir, f'{folder_base}_store.zip')
        archive_abs = os.path.abspath(zip_path)
        # An empty folder_dir means root is relative to the current directory.
        rel_start = folder_dir or os.curdir
        with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_STORED) as zf:
            for walk_root, walk_dirs, walk_files in os.walk(root):
                # Sorting in place makes os.walk descend in a deterministic order.
                walk_dirs.sort()
                for _file in sorted(walk_files):
                    org_path = os.path.join(walk_root, _file)
                    if os.path.abspath(org_path) == archive_abs:
                        # The archive being written may live inside root; never pack it.
                        continue
                    zf.write(org_path, os.path.relpath(org_path, rel_start))
        return zip_path

    def make_dataset(
        self,
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Build the list of ``(member_name, class_index)`` samples from the archive.

        Args:
            directory: Unused; members are listed from ``self.root_zip``.
            class_to_idx: Mapping from class name to class index.
            extensions: Allowed file extensions. Mutually exclusive with
                ``is_valid_file``.
            is_valid_file: Predicate on the member name. Mutually exclusive with
                ``extensions``.

        Returns:
            Samples in archive order. Directory entries and members whose parent
            directory is not a key of ``class_to_idx`` are skipped.

        Raises:
            ValueError: If both or neither of ``extensions`` and
                ``is_valid_file`` are given.
        """
        instances = []
        both_none = extensions is None and is_valid_file is None
        both_something = extensions is not None and is_valid_file is not None
        if both_none or both_something:
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
        if extensions is not None:
            def is_valid_file(x: str) -> bool:
                return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
        is_valid_file = cast(Callable[[str], bool], is_valid_file)
        for filepath in self.root_zip.namelist():
            if filepath.endswith('/'):
                # Directory entry: nothing to load.
                continue
            if not is_valid_file(filepath):
                continue
            target_class = os.path.basename(os.path.dirname(filepath))
            if target_class not in class_to_idx:
                # E.g. a file directly under the top-level folder; it belongs to
                # no class, so skip it instead of failing with KeyError.
                continue
            instances.append((filepath, class_to_idx[target_class]))
        return instances

    def zip_loader(self, path: str) -> Any:
        """Read archive member ``path`` and decode it with ``default_loader``."""
        return default_loader(io.BytesIO(self.root_zip.read(path)))

    def _find_classes(self, *args, **kwargs) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders inside the archive.

        Positional and keyword arguments are accepted for signature
        compatibility and ignored; classes are derived from
        ``self.root_zip.namelist()``.

        Returns:
            tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
            of directory names that have a parent directory inside the archive,
            and ``class_to_idx`` maps each name to its position in that list.
        """
        found = set()
        for filepath in self.root_zip.namelist():
            parent, target_class = os.path.split(os.path.dirname(filepath))
            if parent:
                found.add(target_class)
        classes = sorted(found)
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx