
Commit aedd397

return features instead of vanilla tensors from prototype datasets (#4864)
* return features instead of vanilla tensors from prototype datasets
* fix tests
* remove inplace
* add explanation for __init_subclass__
* fix label for test split
* relax test
* remove pixels
1 parent 775129b commit aedd397
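
In practice, this change means samples yielded by the prototype datasets carry feature objects (Label, Image, BoundingBox) instead of plain tensors plus separate metadata keys. A minimal consumer-side sketch, assuming load() is exposed as torchvision.prototype.datasets.load and that Label exposes its metadata as attributes (the dataset name and sample keys are illustrative, taken from the diffs below):

from torchvision.prototype import datasets, features

# Default decoders are now looked up per dataset type (see _api.py below).
datapipe = datasets.load("cifar10", split="train")
sample = next(iter(datapipe))

# The label is a features.Label that carries its category string as metadata,
# so there is no longer a separate "category" key in the sample.
assert isinstance(sample["label"], features.Label)
print(sample["label"].category)  # e.g. "airplane"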

11 files changed: +122 −92 lines


test/builtin_dataset_mocks.py

Lines changed: 3 additions & 14 deletions
@@ -14,6 +14,7 @@
 from torch.testing import make_tensor as _make_tensor
 from torchdata.datapipes.iter import IterDataPipe
 from torchvision.prototype import datasets
+from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import add_suggestion

@@ -99,28 +100,16 @@ def _get(self, dataset, config):
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info

-    def _decoder(self, dataset_type):
-        def to_bytes(file):
-            try:
-                return file.read()
-            finally:
-                file.close()
-
-        if dataset_type == datasets.utils.DatasetType.RAW:
-            return datasets.decoder.raw
-        else:
-            return to_bytes
-
     def load(
-        self, name: str, decoder=DEFAULT_TEST_DECODER, split="train", **options: Any
+        self, name: str, decoder=DEFAULT_DECODER, split="train", **options: Any
     ) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
         config = dataset.info.make_config(split=split, **options)
         resources, mock_info = self._get(dataset, config)
         datapipe = dataset._make_datapipe(
             [resource.to_datapipe() for resource in resources],
             config=config,
-            decoder=self._decoder(dataset.info.type) if decoder is DEFAULT_TEST_DECODER else decoder,
+            decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
         )
         return datapipe, mock_info

test/test_prototype_builtin_datasets.py

Lines changed: 45 additions & 36 deletions
@@ -1,52 +1,55 @@
-import functools
 import io

 import builtin_dataset_mocks
 import pytest
 from torchdata.datapipes.iter import IterDataPipe
-from torchvision.prototype import datasets
+from torchvision.prototype import datasets, features
+from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str


-_loaders = []
-_datasets = []
-
-# TODO: this can be replaced by torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
-TMP = [
-    "mnist",
-    "fashionmnist",
-    "kmnist",
-    "emnist",
-    "qmnist",
-    "cifar10",
-    "cifar100",
-    "caltech256",
-    "caltech101",
-    "imagenet",
-]
-for name in TMP:
-    loader = functools.partial(builtin_dataset_mocks.load, name)
-    _loaders.append(pytest.param(loader, id=name))
-
-    info = datasets.info(name)
-    _datasets.extend(
-        [
-            pytest.param(*loader(**config), id=f"{name}-{'-'.join([str(value) for value in config.values()])}")
-            for config in info._configs
-        ]
-    )
-
-loaders = pytest.mark.parametrize("loader", _loaders)
-builtin_datasets = pytest.mark.parametrize(("dataset", "mock_info"), _datasets)
+def to_bytes(file):
+    try:
+        return file.read()
+    finally:
+        file.close()
+
+
+def dataset_parametrization(*names, decoder=to_bytes):
+    if not names:
+        # TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
+        names = (
+            "mnist",
+            "fashionmnist",
+            "kmnist",
+            "emnist",
+            "qmnist",
+            "cifar10",
+            "cifar100",
+            "caltech256",
+            "caltech101",
+            "imagenet",
+        )
+
+    params = []
+    for name in names:
+        for config in datasets.info(name)._configs:
+            if name == "imagenet" and config.split == "test":
+                print()
+            id = f"{name}-{'-'.join([str(value) for value in config.values()])}"
+            dataset, mock_info = builtin_dataset_mocks.load(name, decoder=decoder, **config)
+            params.append(pytest.param(dataset, mock_info, id=id))
+
+    return pytest.mark.parametrize(("dataset", "mock_info"), params)


 class TestCommon:
-    @builtin_datasets
+    @dataset_parametrization()
     def test_smoke(self, dataset, mock_info):
         if not isinstance(dataset, IterDataPipe):
             raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

-    @builtin_datasets
+    @dataset_parametrization()
     def test_sample(self, dataset, mock_info):
         try:
             sample = next(iter(dataset))
@@ -59,15 +62,15 @@ def test_sample(self, dataset, mock_info):
         if not sample:
             raise AssertionError("Sample dictionary is empty.")

-    @builtin_datasets
+    @dataset_parametrization()
     def test_num_samples(self, dataset, mock_info):
         num_samples = 0
         for _ in dataset:
             num_samples += 1

         assert num_samples == mock_info["num_samples"]

-    @builtin_datasets
+    @dataset_parametrization()
     def test_decoding(self, dataset, mock_info):
         undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
         if undecoded_features:
@@ -76,6 +79,12 @@ def test_decoding(self, dataset, mock_info):
                 f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
             )

+    @dataset_parametrization(decoder=DEFAULT_DECODER)
+    def test_at_least_one_feature(self, dataset, mock_info):
+        sample = next(iter(dataset))
+        if not any(isinstance(value, features.Feature) for value in sample.values()):
+            raise AssertionError("The sample contained no feature.")
+

 class TestQMNIST:
     @pytest.mark.parametrize(
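
A hedged sketch of how the new helper could parametrize an additional test over a subset of datasets while using the library's per-type default decoders. It relies on the imports and helpers defined in the module above; the chosen dataset names and sample keys are assumptions, not part of this commit:

@dataset_parametrization("cifar10", "cifar100", decoder=DEFAULT_DECODER)
def test_cifar_sample_types(dataset, mock_info):
    # With the default decoders the image should already be decoded into a
    # features.Image, and the label should carry its category string
    # (see the cifar.py hunk further down).
    sample = next(iter(dataset))
    assert isinstance(sample["image"], features.Image)
    assert isinstance(sample["label"], features.Label)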

torchvision/prototype/datasets/_api.py

Lines changed: 5 additions & 5 deletions
@@ -49,9 +49,9 @@ def info(name: str) -> DatasetInfo:
     return find(name).info


-default = object()
+DEFAULT_DECODER = object()

-DEFAULT_DECODER: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
+DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
     DatasetType.RAW: raw,
     DatasetType.IMAGE: pil,
 }
@@ -60,15 +60,15 @@ def info(name: str) -> DatasetInfo:
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = default,  # type: ignore[assignment]
+    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     name = name.lower()
     dataset = find(name)

-    if decoder is default:
-        decoder = DEFAULT_DECODER.get(dataset.info.type)
+    if decoder is DEFAULT_DECODER:
+        decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)

     config = dataset.info.make_config(split=split, **options)
     root = os.path.join(home(), name)
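
Renaming the module-level default to DEFAULT_DECODER makes the sentinel pattern explicit: it distinguishes "argument omitted" (use the per-type default) from an explicit decoder=None (skip decoding). A generic sketch of that pattern with hypothetical names, independent of the torchvision code above:

_SENTINEL = object()  # unique marker; no caller-supplied value is identical to it

def decode(data, decoder=_SENTINEL):
    if decoder is _SENTINEL:
        # Argument omitted: fall back to a default decoder.
        decoder = bytes.decode
    elif decoder is None:
        # Explicit opt-out: return the data untouched.
        return data
    return decoder(data)

print(decode(b"abc"))        # "abc"  (default decoder applied)
print(decode(b"abc", None))  # b'abc' (decoding skipped)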

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 5 additions & 4 deletions
@@ -22,6 +22,7 @@
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
+from torchvision.prototype.features import Label, BoundingBox


 class Caltech101(Dataset):
@@ -95,8 +96,8 @@ def _collate_and_decode_sample(
         image = decoder(image_buffer) if decoder else image_buffer

         ann = read_mat(ann_buffer)
-        bbox = torch.as_tensor(ann["box_coord"].astype(np.int64))
-        contour = torch.as_tensor(ann["obj_contour"])
+        bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
+        contour = torch.tensor(ann["obj_contour"].T)

         return dict(
             category=category,
@@ -171,9 +172,9 @@ def _collate_and_decode_sample(

         dir_name = pathlib.Path(path).parent.name
         label_str, category = dir_name.split(".")
-        label = torch.tensor(int(label_str))
+        label = Label(int(label_str), category=category)

-        return dict(label=label, category=category, image=decoder(buffer) if decoder else buffer)
+        return dict(label=label, image=decoder(buffer) if decoder else buffer)

     def _make_datapipe(
         self,
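
For context on the [[2, 0, 3, 1]] indexing: the Caltech-101 annotation files appear to store box_coord as [top, bottom, left, right], so the reindexing rearranges it into xyxy order for the BoundingBox feature. A small worked example of just that reordering (the coordinates are made up):

import numpy as np

box_coord = np.array([[10, 110, 20, 220]])  # assumed order: top, bottom, left, right
xyxy = box_coord.astype(np.int64).squeeze()[[2, 0, 3, 1]]
print(xyxy)  # [ 20  10 220 110] -> x1, y1, x2, y2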

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 7 additions & 7 deletions
@@ -28,6 +28,7 @@
     image_buffer_from_array,
     path_comparator,
 )
+from torchvision.prototype.features import Label, Image

 __all__ = ["Cifar10", "Cifar100"]

@@ -65,17 +66,16 @@ def _collate_and_decode(
     ) -> Dict[str, Any]:
         image_array, category_idx = data

-        category = self.categories[category_idx]
-        label = torch.tensor(category_idx)
-
-        image: Union[torch.Tensor, io.BytesIO]
+        image: Union[Image, io.BytesIO]
         if decoder is raw:
-            image = torch.from_numpy(image_array)
+            image = Image(image_array)
         else:
             image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
-            image = decoder(image_buffer) if decoder else image_buffer
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
+
+        label = Label(category_idx, category=self.categories[category_idx])

-        return dict(label=label, category=category, image=image)
+        return dict(image=image, label=label)

     def _make_datapipe(
         self,
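
Note that the returned dict no longer has a separate category entry; the category string now rides on the Label feature itself. A hedged sketch of that construction, mirroring the hunk above (the category list is an illustrative subset, and it is assumed that Label behaves like a scalar tensor that exposes its metadata as an attribute):

from torchvision.prototype.features import Label

categories = ["airplane", "automobile", "bird"]  # illustrative subset of the CIFAR-10 categories
label = Label(2, category=categories[2])

print(int(label))      # 2 -- still usable wherever an integer label was expected
print(label.category)  # "bird"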

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 22 additions & 13 deletions
@@ -21,9 +21,22 @@
     getitem,
     read_mat,
 )
+from torchvision.prototype.features import Label, DEFAULT
 from torchvision.prototype.utils._internal import FrozenMapping


+class ImageNetLabel(Label):
+    wnid: Optional[str]
+
+    @classmethod
+    def _parse_meta_data(
+        cls,
+        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
+        wnid: Optional[str] = DEFAULT,  # type: ignore[assignment]
+    ) -> Dict[str, Tuple[Any, Any]]:
+        return dict(category=(category, None), wnid=(wnid, None))
+
+
 class ImageNet(Dataset):
     def _make_info(self) -> DatasetInfo:
         name = "imagenet"
@@ -78,12 +91,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:

     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

-    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
+    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
         path = pathlib.Path(data[0])
         wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
         category = self.wnid_to_category[wnid]
-        label = self.categories.index(category)
-        return (label, category, wnid), data
+        label = ImageNetLabel(self.categories.index(category), category=category, wnid=wnid)
+        return label, data

     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")

@@ -93,31 +106,27 @@ def _val_test_image_key(self, data: Tuple[str, Any]) -> int:

     def _collate_val_data(
         self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
-    ) -> Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]]:
+    ) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
         label_data, image_data = data
         _, label = label_data
         category = self.categories[label]
         wnid = self.category_to_wnid[category]
-        return (label, category, wnid), image_data
+        return ImageNetLabel(label, category=category, wnid=wnid), image_data

-    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
-        return (None, None, None), data
+    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
+        return None, data

     def _collate_and_decode_sample(
         self,
-        data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
+        data: Tuple[Optional[ImageNetLabel], Tuple[str, io.IOBase]],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        ann_data, image_data = data
-        label, category, wnid = ann_data
-        path, buffer = image_data
+        label, (path, buffer) = data
         return dict(
             path=path,
             image=decoder(buffer) if decoder else buffer,
             label=label,
-            category=category,
-            wnid=wnid,
         )

     def _make_datapipe(
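
ImageNetLabel extends the prototype Label with a second piece of metadata, the WordNet id, so it can replace the old (label, category, wnid) tuple. A hedged sketch of constructing and reading one (the values are made up; attribute access on the metadata is assumed to work as for category above):

label = ImageNetLabel(7, category="tench", wnid="n01440764")

print(int(label))      # 7
print(label.category)  # "tench"
print(label.wnid)      # "n01440764"
# The test split has no annotations, so _collate_test_data now returns None in place of a label.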

torchvision/prototype/datasets/decoder.py

Lines changed: 4 additions & 1 deletion
@@ -13,4 +13,7 @@ def raw(buffer: io.IOBase) -> torch.Tensor:


 def pil(buffer: io.IOBase) -> features.Image:
-    return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
+    try:
+        return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
+    finally:
+        buffer.close()
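
Closing the buffer in a finally block releases the underlying file handle even if decoding raises. The same shape works for any decoder that fully consumes a file-like object (a generic sketch, not part of the commit):

import io

def to_bytes(buffer: io.IOBase) -> bytes:
    # Read everything, then close the buffer no matter what happened --
    # the same try/finally shape as the pil() change above.
    try:
        return buffer.read()
    finally:
        buffer.close()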
torchvision/prototype/features/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from ._bounding_box import BoundingBoxFormat, BoundingBox
-from ._feature import Feature
+from ._feature import Feature, DEFAULT
 from ._image import Image, ColorSpace
 from ._label import Label

torchvision/prototype/features/_bounding_box.py

Lines changed: 4 additions & 1 deletion
@@ -115,7 +115,10 @@ def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> T
         data = cls._TO_XYXY_MAP[format](data)
         data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data)
         *_, w, h = to_parts(data)
-        return int(h.ceil()), int(w.ceil())
+        if data.dtype.is_floating_point:
+            w = w.ceil()
+            h = h.ceil()
+        return int(h), int(w)

     @classmethod
     def from_parts(
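
The dtype guard is presumably there because integer coordinates need no rounding and .ceil() on an integer tensor raises in PyTorch. The check in isolation, as a small sketch outside the BoundingBox class:

import torch

def ceil_to_int(value: torch.Tensor) -> int:
    # Only floating-point values need rounding up; integer values pass through unchanged.
    if value.dtype.is_floating_point:
        value = value.ceil()
    return int(value)

print(ceil_to_int(torch.tensor(3.2)))  # 4
print(ceil_to_int(torch.tensor(5)))    # 5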
