|
| 1 | +import io |
| 2 | +import os |
| 3 | +import os.path |
| 4 | +import pathlib |
| 5 | +from typing import Callable, Optional, Collection |
| 6 | +from typing import Union, Tuple, List, Dict, Any |
| 7 | + |
| 8 | +import torch |
| 9 | +from torch.utils.data import IterDataPipe |
| 10 | +from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter |
| 11 | + |
| 12 | +from torchvision.prototype.datasets.decoder import pil |
| 13 | + |
| 14 | + |
# Names exported by ``from <this module> import *``.
__all__ = ["from_data_folder", "from_image_folder"]

# Pseudo-infinite buffer size until a true infinite buffer is supported.
# Used as the buffer size of the default shuffler in `from_data_folder`.
INFINITE = 1_000_000_000
| 19 | + |
| 20 | + |
| 21 | +def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool: |
| 22 | + rel_path = pathlib.Path(path).relative_to(root) |
| 23 | + return rel_path.is_dir() or rel_path.parent != pathlib.Path(".") |
| 24 | + |
| 25 | + |
| 26 | +def _collate_and_decode_data( |
| 27 | + data: Tuple[str, io.IOBase], |
| 28 | + *, |
| 29 | + root: pathlib.Path, |
| 30 | + categories: List[str], |
| 31 | + decoder, |
| 32 | +) -> Dict[str, Any]: |
| 33 | + path, buffer = data |
| 34 | + data = decoder(buffer) if decoder else buffer |
| 35 | + category = pathlib.Path(path).relative_to(root).parts[0] |
| 36 | + label = torch.tensor(categories.index(category)) |
| 37 | + return dict( |
| 38 | + path=path, |
| 39 | + data=data, |
| 40 | + label=label, |
| 41 | + category=category, |
| 42 | + ) |
| 43 | + |
| 44 | + |
def from_data_folder(
    root: Union[str, pathlib.Path],
    *,
    shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE),
    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
    valid_extensions: Optional[Collection[str]] = None,
    recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
    """Build a datapipe over a folder whose sub-folders denote categories.

    The sorted names of the directories directly inside ``root`` form the
    categories. Files located directly in ``root`` are filtered out.

    Args:
        root: Root folder of the data. ``~`` is expanded and the path resolved.
        shuffler: Optional callable that wraps the datapipe of file paths
            before the files are loaded. Skipped if falsy. Defaults to a
            ``Shuffler`` with a buffer size of ``INFINITE``.
        decoder: Optional callable applied to each file buffer. If falsy, the
            buffer itself is stored in the sample.
        valid_extensions: Optional extensions (without leading dot) used to
            build the ``"*.{ext}"`` masks passed to ``FileLister``. If
            ``None``, an empty string is passed as mask.
        recursive: Forwarded to ``FileLister``.

    Returns:
        Tuple of the datapipe, which yields dictionaries with the keys
        ``"path"``, ``"data"``, ``"label"`` and ``"category"``, and the list
        of categories.
    """
    root_dir = pathlib.Path(root).expanduser().resolve()

    categories = [entry.name for entry in os.scandir(root_dir) if entry.is_dir()]
    categories.sort()

    masks: Union[List[str], str]
    if valid_extensions is None:
        masks = ""
    else:
        masks = [f"*.{ext}" for ext in valid_extensions]

    paths: IterDataPipe = FileLister(str(root_dir), recursive=recursive, masks=masks)
    paths = Filter(paths, _is_not_top_level_file, fn_kwargs={"root": root_dir})
    if shuffler:
        # Shuffle the paths rather than the loaded files.
        paths = shuffler(paths)

    loaded = FileLoader(paths)
    samples = Mapper(
        loaded,
        _collate_and_decode_data,
        fn_kwargs={"root": root_dir, "categories": categories, "decoder": decoder},
    )
    return samples, categories
| 65 | + |
| 66 | + |
| 67 | +def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]: |
| 68 | + sample["image"] = sample.pop("data") |
| 69 | + return sample |
| 70 | + |
| 71 | + |
def from_image_folder(
    root: Union[str, pathlib.Path],
    *,
    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
    valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
    **kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
    """Build a datapipe over an image folder whose sub-folders denote categories.

    Thin wrapper around ``from_data_folder`` that accepts each extension in
    its lower- and upper-case spelling and stores the file content under the
    ``"image"`` key instead of ``"data"``.

    Args:
        root: Root folder of the images.
        decoder: Optional callable applied to each file buffer. Defaults to
            ``pil``.
        valid_extensions: Extensions (without leading dot) of the files to
            include.
        **kwargs: Forwarded to ``from_data_folder``.

    Returns:
        Tuple of the datapipe and the list of categories.
    """
    # Accept both the all-lowercase and the all-uppercase spelling of each extension.
    cased_extensions: List[str] = []
    for extension in valid_extensions:
        cased_extensions.extend((extension.lower(), extension.upper()))

    data_dp, categories = from_data_folder(
        root,
        decoder=decoder,
        valid_extensions=cased_extensions,
        **kwargs,
    )
    image_dp = Mapper(data_dp, _data_to_image_key)
    return image_dp, categories
0 commit comments