Commit f9bee39

relax split requirement for prototype datasets (#5186)

* relax split requirement for prototype datasets
* remove obsolete tests
* appease mypy
* fix failing test
* fix load config test with default config

1 parent 28f72f1 commit f9bee39
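In effect, `DatasetInfo` previously injected a default `split=("train",)` option and rejected any dataset whose `split` values did not include `"train"`; this commit drops that requirement entirely (see the `torchvision/prototype/datasets/utils/_dataset.py` diff below). A minimal sketch of the removed guard, simplified from the deleted code and with a shortened error message:

```python
from typing import Any, Dict, Sequence


def old_split_check(valid_options: Dict[str, Sequence[Any]]) -> None:
    # Replicates (simplified) the guard this commit removes from
    # DatasetInfo.__init__; the original also injected split=["train"]
    # when the "split" option was missing entirely.
    if "split" in valid_options and "train" not in valid_options["split"]:
        raise ValueError("'train' has to be a valid argument for option 'split'")


try:
    old_split_check(dict(split=("trainval", "test")))
except ValueError as e:
    # Before this commit, datasets shipping no "train" split (e.g. the
    # oxford_iiit_pet change below) were rejected like this:
    print(e)
```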

File tree

4 files changed: +35 -72 lines

- test/test_prototype_datasets_api.py
- torchvision/prototype/datasets/_builtin/cifar.py
- torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
- torchvision/prototype/datasets/utils/_dataset.py

test/test_prototype_datasets_api.py

Lines changed: 2 additions & 17 deletions

```diff
@@ -125,21 +125,6 @@ def test_default_config(self, info):
 
         assert info.default_config == default_config
 
-    @pytest.mark.parametrize(
-        "valid_options",
-        [
-            pytest.param(None, id="default"),
-            pytest.param(dict(option=("value",)), id="no_split"),
-        ],
-    )
-    def test_default_config_split_train(self, valid_options):
-        info = make_minimal_dataset_info(valid_options=valid_options)
-        assert info.default_config.split == "train"
-
-    def test_valid_options_split_but_no_train(self):
-        with pytest.raises(ValueError, match="'train' has to be a valid argument for option 'split'"):
-            make_minimal_dataset_info(valid_options=dict(split=("test",)))
-
     @pytest.mark.parametrize(
         ("options", "expected_error_msg"),
         [
@@ -208,7 +193,7 @@ def test_default_config(self):
         ("config", "kwarg"),
         [
             pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"),
-            pytest.param(make_minimal_dataset_info().default_config, None, id="default"),
+            pytest.param(DatasetMock().default_config, None, id="default"),
         ],
     )
     def test_load_config(self, config, kwarg):
@@ -218,7 +203,7 @@ def test_load_config(self, config, kwarg):
 
         dataset.resources.assert_called_with(config)
 
-        (_, call_kwargs) = dataset._make_datapipe.call_args
+        _, call_kwargs = dataset._make_datapipe.call_args
         assert call_kwargs["config"] == config
 
     def test_missing_dependencies(self):
```

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 28 additions & 39 deletions

```diff
@@ -45,14 +45,32 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
 
 
 class _CifarBase(Dataset):
+    _FILE_NAME: str
+    _SHA256: str
     _LABELS_KEY: str
     _META_FILE_NAME: str
     _CATEGORIES_KEY: str
 
     @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, split: str) -> Optional[int]:
         pass
 
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            type(self).__name__.lower(),
+            type=DatasetType.RAW,
+            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
+            valid_options=dict(split=("train", "test")),
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        return [
+            HttpResource(
+                f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
+                sha256=self._SHA256,
+            )
+        ]
+
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
         return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
@@ -84,7 +102,7 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = Filter(dp, functools.partial(self._is_data_file, config=config))
+        dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = hint_sharding(dp)
@@ -102,53 +120,24 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
 
 
 class Cifar10(_CifarBase):
+    _FILE_NAME = "cifar-10-python.tar.gz"
+    _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
     _LABELS_KEY = "labels"
     _META_FILE_NAME = "batches.meta"
     _CATEGORIES_KEY = "label_names"
 
-    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
         path = pathlib.Path(data[0])
-        return path.name.startswith("data" if config.split == "train" else "test")
-
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "cifar10",
-            type=DatasetType.RAW,
-            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        return [
-            HttpResource(
-                "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
-                sha256="6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce",
-            )
-        ]
+        return path.name.startswith("data" if split == "train" else "test")
 
 
 class Cifar100(_CifarBase):
+    _FILE_NAME = "cifar-100-python.tar.gz"
+    _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
     _LABELS_KEY = "fine_labels"
     _META_FILE_NAME = "meta"
     _CATEGORIES_KEY = "fine_label_names"
 
-    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
         path = pathlib.Path(data[0])
-        return path.name == cast(str, config.split)
-
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "cifar100",
-            type=DatasetType.RAW,
-            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-        )
-
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        return [
-            HttpResource(
-                "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
-                sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7",
-            )
-        ]
+        return path.name == split
```
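This refactor replaces two near-identical `_make_info`/`resources` implementations with a single base-class version that reads per-subclass class attributes. A standalone sketch of that pattern, with names mirroring the diff; `Archive` is a hypothetical stand-in for torchvision's `HttpResource`:

```python
from typing import NamedTuple


class Archive(NamedTuple):
    # Hypothetical stand-in for HttpResource: just a URL plus checksum.
    url: str
    sha256: str


class _CifarBase:
    _FILE_NAME: str  # subclasses must define these class attributes
    _SHA256: str

    @property
    def name(self) -> str:
        # type(self) resolves to the concrete subclass, so Cifar10 -> "cifar10"
        return type(self).__name__.lower()

    def resource(self) -> Archive:
        # self._FILE_NAME / self._SHA256 likewise resolve on the subclass
        return Archive(f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}", self._SHA256)


class Cifar10(_CifarBase):
    _FILE_NAME = "cifar-10-python.tar.gz"
    _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"


print(Cifar10().name)            # cifar10
print(Cifar10().resource().url)  # https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
```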

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -37,8 +37,7 @@ def _make_info(self) -> DatasetInfo:
             type=DatasetType.IMAGE,
             homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
             valid_options=dict(
-                # FIXME
-                split=("trainval", "test", "train"),
+                split=("trainval", "test"),
             ),
         )
 
```
torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 4 additions & 14 deletions

```diff
@@ -38,7 +38,7 @@ def __init__(
         citation: Optional[str] = None,
         homepage: Optional[str] = None,
         license: Optional[str] = None,
-        valid_options: Optional[Dict[str, Sequence]] = None,
+        valid_options: Optional[Dict[str, Sequence[Any]]] = None,
         extra: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.name = name.lower()
@@ -60,20 +60,10 @@ def __init__(
         self.homepage = homepage
         self.license = license
 
-        valid_split: Dict[str, Sequence] = dict(split=["train"])
-        if valid_options is None:
-            valid_options = valid_split
-        elif "split" not in valid_options:
-            valid_options.update(valid_split)
-        elif "train" not in valid_options["split"]:
-            raise ValueError(
-                f"'train' has to be a valid argument for option 'split', "
-                f"but found only {sequence_to_str(valid_options['split'], separate_last='and ')}."
-            )
-        self._valid_options: Dict[str, Sequence] = valid_options
+        self._valid_options = valid_options or dict()
         self._configs = tuple(
-            DatasetConfig(**dict(zip(valid_options.keys(), combination)))
-            for combination in itertools.product(*valid_options.values())
+            DatasetConfig(**dict(zip(self._valid_options.keys(), combination)))
+            for combination in itertools.product(*self._valid_options.values())
         )
 
         self.extra = FrozenBunch(extra or dict())
```
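The config generation above takes the Cartesian product of all option values. Notably, when `valid_options` is empty, `itertools.product()` of zero iterables yields exactly one empty combination, so a dataset with no options still gets a single (empty) default config. A standalone sketch, with plain dicts standing in for `DatasetConfig` and a hypothetical `fold` option for illustration:

```python
import itertools
from typing import Any, Dict, Sequence, Tuple


def make_configs(valid_options: Dict[str, Sequence[Any]]) -> Tuple[Dict[str, Any], ...]:
    # Same product logic as DatasetInfo.__init__ above, with a plain dict
    # standing in for DatasetConfig.
    return tuple(
        dict(zip(valid_options.keys(), combination))
        for combination in itertools.product(*valid_options.values())
    )


print(make_configs(dict(split=("train", "test"), fold=(1, 2))))
# ({'split': 'train', 'fold': 1}, {'split': 'train', 'fold': 2},
#  {'split': 'test', 'fold': 1}, {'split': 'test', 'fold': 2})

# With no options at all, the product still yields one empty combination,
# so there is exactly one (empty) config rather than none:
print(make_configs(dict()))  # ({},)
```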
