Skip to content

Commit 021df7a

Browse files
pmeier and fmassa
authored
add prototype image folder dataset (#4441)
* add prototype image folder dataset * remove decoder datapipe * [PROPOSAL] add RandomPicker * refactor data loading * fix mypy * remove per-category datapipes * fix mypy Co-authored-by: Francisco Massa <[email protected]>
1 parent 972ca65 commit 021df7a

File tree

4 files changed

+96
-0
lines changed

4 files changed

+96
-0
lines changed

torchvision/prototype/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import datasets
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import decoder
2+
from ._folder import from_data_folder, from_image_folder
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import io
2+
import os
3+
import os.path
4+
import pathlib
5+
from typing import Callable, Optional, Collection
6+
from typing import Union, Tuple, List, Dict, Any
7+
8+
import torch
9+
from torch.utils.data import IterDataPipe
10+
from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
11+
12+
from torchvision.prototype.datasets.decoder import pil
13+
14+
15+
__all__ = ["from_data_folder", "from_image_folder"]
16+
17+
# pseudo-infinite buffer size until a true infinite buffer is supported
18+
# Passed as ``buffer_size`` to ``Shuffler`` by the default ``shuffler`` of ``from_data_folder``.
INFINITE = 1_000_000_000
19+
20+
21+
def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
    """Return ``True`` unless ``path`` is a non-directory entry placed directly in ``root``.

    Args:
        path: Path of the entry to test. It has to be located inside ``root``.
        root: Folder the entry is tested against.

    Returns:
        ``True`` if ``path`` is a directory or lives in a sub-folder of ``root``,
        ``False`` if it sits directly inside ``root`` and is not a directory.

    Raises:
        ValueError: If ``path`` is not located inside ``root`` (raised by
            ``pathlib.PurePath.relative_to``).
    """
    full_path = pathlib.Path(path)
    rel_path = full_path.relative_to(root)
    # The directory check has to use the full path: a path relative to ``root``
    # would be resolved against the current working directory instead.
    return full_path.is_dir() or rel_path.parent != pathlib.Path(".")
24+
25+
26+
def _collate_and_decode_data(
    data: Tuple[str, io.IOBase],
    *,
    root: pathlib.Path,
    categories: List[str],
    decoder,
) -> Dict[str, Any]:
    """Turn a ``(path, buffer)`` pair into a sample dictionary.

    The category is the first component of ``path`` relative to ``root``; the label
    is the position of that category in ``categories``, wrapped in a tensor.

    Args:
        data: Pair of the file path and the corresponding file handle.
        root: Folder that ``path`` is located in.
        categories: Known category names; the index in this list becomes the label.
        decoder: Optional callable applied to the file handle. If it is falsy, the
            file handle itself is stored in the sample.

    Returns:
        Dictionary with the keys ``"path"``, ``"data"``, ``"label"``, and ``"category"``.

    Raises:
        ValueError: If ``path`` is not located inside ``root`` or the category is
            not part of ``categories``.
    """
    path, buffer = data

    # Decode first so a failing decoder surfaces before any path handling.
    if decoder:
        payload = decoder(buffer)
    else:
        payload = buffer

    relative_path = pathlib.Path(path).relative_to(root)
    category = relative_path.parts[0]
    class_index = categories.index(category)

    return {
        "path": path,
        "data": payload,
        "label": torch.tensor(class_index),
        "category": category,
    }
43+
44+
45+
def from_data_folder(
    root: Union[str, pathlib.Path],
    *,
    shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE),
    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
    valid_extensions: Optional[Collection[str]] = None,
    recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
    """Build a datapipe over the files of a folder that has one sub-folder per category.

    Args:
        root: Folder that contains the category sub-folders. ``~`` is expanded and the
            path is resolved before use.
        shuffler: Optional callable that wraps the datapipe of file paths before the
            files are loaded. Pass a falsy value to skip this step.
        decoder: Optional callable applied to every loaded file handle. If omitted,
            the samples carry the file handle itself.
        valid_extensions: Optional extensions (without the leading dot) used to build
            ``*.<ext>`` masks for the file listing. ``None`` passes an empty mask.
        recursive: Forwarded to ``FileLister``.

    Returns:
        Tuple of the datapipe yielding the sample dictionaries built by
        ``_collate_and_decode_data`` and the sorted list of category names, i.e. the
        names of the directories placed directly in ``root``.
    """
    # Imported locally so this function does not rely on a module-level import.
    from functools import partial

    root = pathlib.Path(root).expanduser().resolve()
    with os.scandir(root) as entries:
        categories = sorted(entry.name for entry in entries if entry.is_dir())

    masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
    dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks)

    # Keyword arguments are bound with ``partial`` instead of the ``fn_kwargs`` parameter
    # of ``Filter`` / ``Mapper``: newer PyTorch releases no longer accept ``fn_kwargs``,
    # while a partial of a module-level function works everywhere and stays picklable.
    dp = Filter(dp, partial(_is_not_top_level_file, root=root))
    if shuffler:
        dp = shuffler(dp)
    dp = FileLoader(dp)
    dp = Mapper(dp, partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder))
    return dp, categories
65+
66+
67+
def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Rename the ``"data"`` entry of ``sample`` to ``"image"``.

    The dictionary is modified in place and the very same object is returned.

    Raises:
        KeyError: If ``sample`` has no ``"data"`` entry.
    """
    payload = sample.pop("data")
    sample["image"] = payload
    return sample
70+
71+
72+
def from_image_folder(
    root: Union[str, pathlib.Path],
    *,
    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
    valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
    **kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
    """Build a datapipe over an image folder that has one sub-folder per category.

    This wraps ``from_data_folder`` and stores the loaded payload under the
    ``"image"`` key instead of ``"data"``.

    Args:
        root: Folder that contains the category sub-folders.
        decoder: Optional callable applied to every loaded file handle.
        valid_extensions: Extensions (without the leading dot) of the files to
            include. Each one is used in its lower-case and upper-case spelling.
        **kwargs: Additional keyword arguments forwarded to ``from_data_folder``.

    Returns:
        Tuple of the datapipe yielding the sample dictionaries and the list of
        categories returned by ``from_data_folder``.
    """
    # Expand every extension into its all-lowercase and all-uppercase variant.
    cased_extensions: List[str] = []
    for extension in valid_extensions:
        cased_extensions.extend((extension.lower(), extension.upper()))

    samples, categories = from_data_folder(root, decoder=decoder, valid_extensions=cased_extensions, **kwargs)
    image_samples = Mapper(samples, _data_to_image_key)
    return image_samples, categories
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import io
2+
3+
import numpy as np
4+
import PIL.Image
5+
import torch
6+
7+
__all__ = ["pil"]
8+
9+
10+
def pil(file: io.IOBase, mode="RGB") -> torch.Tensor:
    """Decode an image from a file handle into a channels-first tensor.

    Args:
        file: File-like object holding the encoded image.
        mode: PIL mode the image is converted to. It is upper-cased before use.

    Returns:
        Tensor with the pixel data of the converted image. The last axis of the
        decoded array (the channels) is moved to the front. Single-channel modes
        produce a tensor with a leading channel dimension of size one.
    """
    image = PIL.Image.open(file).convert(mode.upper())
    pixels = np.array(image, copy=True)
    # Single-channel modes such as "L" yield a 2D array without a channel axis.
    # Add it so the permutation below does not fail for them.
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    return torch.from_numpy(pixels).permute((2, 0, 1))

0 commit comments

Comments
 (0)