Commit 65d3a87
add instructions how to add tests for prototype datasets (#5666)
1 parent be462be · 1 file changed in torchvision/prototype/datasets/_builtin (+148, -108 lines)

# How to add new built-in prototype datasets

As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means
that this document will also change a lot.

If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented
there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out.

Finally, `from torchvision.prototype import datasets` is implied below.

## Implementation

Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that
module, create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum the three methods that
will be discussed in detail below:

```python
from typing import Any, Dict, List
...
```

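Concretely, such a skeleton boils down to stubs for the three methods discussed in the following sections. The sketch
below is only an illustration; the import paths and annotations (`DatasetConfig`, `OnlineResource`, `IterDataPipe`) are
assumptions and may differ slightly from the actual prototype API:

```python
from typing import Any, Dict, List

from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, OnlineResource


class MyDataset(Dataset):
    def _make_info(self) -> DatasetInfo:
        ...

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        ...

    def _make_datapipe(
        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
    ) -> IterDataPipe[Dict[str, Any]]:
        ...
```
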
### `_make_info(self)`

The `DatasetInfo` carries static information about the dataset. There are two required fields:

- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain
  lowercase characters.

There are more optional parameters that can be passed:

- `dependencies`: Collection of third-party dependencies that are needed to load the dataset, e.g. `("scipy",)`. Their
  availability will be automatically checked if a user tries to load the dataset. Within the implementation, import
  these packages lazily to avoid missing dependencies at import time.
- `categories`: Sequence of human-readable category names for each label. The index of each category has to match the
  corresponding label returned in the dataset samples.
  [See below](#how-do-i-handle-a-dataset-that-defines-many-categories) for how to handle cases with many categories.
- `valid_options`: Configures valid options that can be passed to the dataset. It should be `Dict[str, Sequence[Any]]`.
  The options are accessible through the `config` namespace in the other two methods. The first value of each sequence
  is taken as the default if the user passes no option to `torchvision.prototype.datasets.load()`.
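
For illustration, a `_make_info()` for a hypothetical dataset with a train/test split could look roughly like this;
the name, categories, and options are made up, and only the parameter names above are taken from the prototype API:

```python
def _make_info(self) -> DatasetInfo:
    return DatasetInfo(
        "my-dataset",
        dependencies=("scipy",),
        categories=("cat", "dog"),
        # the first value, "train", is the default split
        valid_options=dict(split=("train", "test")),
    )
```

With this, `datasets.load("my-dataset")` would use the default `"train"` split, while
`datasets.load("my-dataset", split="test")` selects the other one.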

### `resources(self, config)`

Returns a `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset
with a specific `config` can be built. The download will happen automatically.

Currently, the following `OnlineResource`'s are supported:

- `HttpResource`: Used for files that are directly exposed through HTTP(S) and only requires the URL.
- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`.
- `ManualDownloadResource`: Used for files that are not publicly accessible and requires instructions on how to
  download them manually. If the file does not exist, an error will be raised with the supplied instructions.
- `KaggleDownloadResource`: Used for files that are available on Kaggle. This inherits from `ManualDownloadResource`.

Although optional in general, all resources used in the built-in datasets should include a
[SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the
download. You can compute the checksum with system utilities, e.g. `sha256sum`, or this snippet:

```python
import hashlib


def sha256sum(path, chunk_size=1024 * 1024):
    checksum = hashlib.sha256()
    with open(path, "rb") as f:
        # read the file in chunks so that large archives do not need to fit into memory
        for chunk in iter(lambda: f.read(chunk_size), b""):
            checksum.update(chunk)
    print(checksum.hexdigest())
```

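For example, a `resources()` implementation for a dataset that is distributed as a single archive over HTTP might look
like this sketch; the URL and checksum are placeholders, and it assumes `HttpResource` accepts the URL plus a `sha256`
keyword:

```python
def resources(self, config):
    images = HttpResource(
        "https://my-dataset.example.com/images.tar.gz",  # placeholder URL
        sha256="...",  # fill in the checksum computed with the snippet above
    )
    return [images]
```
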
### `_make_datapipe(resource_dps, *, config)`

This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference
compared to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective
of someone that is working with them rather than on them, `IterDataPipe`'s behave just like generators, i.e. you can't
do anything with them besides iterating.

Of course, there are some common building blocks that should suffice in 95% of the cases. The most used are:

- `Mapper`: Apply a callable to every item in the datapipe.
- `Filter`: Keep only items that satisfy a condition.
- `Demultiplexer`: Split a datapipe into multiple ones.
- `IterKeyZipper`: Merge two datapipes into one.

All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable
needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
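
For instance, using `functools.partial` to pass an extra argument to a `Mapper` callable looks like this; a toy,
self-contained example rather than dataset code:

```python
import functools

from torchdata.datapipes.iter import IterableWrapper, Mapper


def add_label(path, *, label):
    return dict(path=path, label=label)


dp = IterableWrapper(["a.jpg", "b.jpg"])
dp = Mapper(dp, functools.partial(add_label, label=0))
print(list(dp))  # [{'path': 'a.jpg', 'label': 0}, {'path': 'b.jpg', 'label': 0}]
```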

`_make_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the
return value of `resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will
contain tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only
contain one such tuple for the file specified by the resource.

Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
`Grouper`. There are two issues with that:

1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.

Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather
than trying to zip already loaded images.

There are two special datapipes that are not used through their class, but through the functions `hint_sharding` and
`hint_shuffling`. As the name implies, they only hint at the part in the datapipe graph where sharding and shuffling
should take place, but are no-ops by default. They can be imported from
`torchvision.prototype.datasets.utils._internal` and are required in each dataset.

Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
names (yet!).

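Putting these pieces together, a `_make_datapipe()` for a hypothetical dataset whose single archive contains one folder
per split might look roughly like the sketch below. The `split` option, the helper names, and writing the method as a
free function are illustrative choices, not the prototype API:

```python
import functools
import pathlib

from torchdata.datapipes.iter import Filter, Mapper
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling


def _filter_split(data, *, split):
    # `data` is a (path, file handle) tuple coming from the archive datapipe
    path, _ = data
    return pathlib.Path(path).parent.name == split


def _prepare_sample(data):
    # convert the (path, file handle) tuple into the dictionary format described above
    path, handle = data
    return dict(path=path, image=handle)


# inside the dataset class this would be _make_datapipe(self, resource_dps, *, config)
def make_datapipe(resource_dps, *, config):
    dp = resource_dps[0]  # 1-to-1 correspondence with the return value of resources()
    dp = Filter(dp, functools.partial(_filter_split, split=config.split))
    dp = hint_shuffling(dp)
    dp = hint_sharding(dp)
    return Mapper(dp, _prepare_sample)
```
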
## Tests

To test the dataset implementation, you usually don't need to add any tests, but you do need to provide a mock-up of
the data. This mock-up should resemble the original data as closely as necessary, while containing only a few examples.

To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the
same name as you have defined in `_make_info()` (if the name includes hyphens `-`, replace them with underscores `_`)
and decorate it with `@register_mock`:

```py
# this is defined in torchvision/prototype/datasets/_builtin
class MyDataset(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "my-dataset",
            ...
        )


# this is defined in test/builtin_dataset_mocks.py
@register_mock
def my_dataset(info, root, config):
    ...
```

The function receives three arguments:

- `info`: The return value of `_make_info()`.
- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the
  data needs to be placed.
- `config`: The configuration to generate the data for. This is the same value that `_make_datapipe()` receives.

The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate all of them regardless of the
current `config`. Although this seems odd at first, this is important. Consider the following original data setup:

```
root
├── test
│   ├── test_image0.jpg
│   ...
└── train
    ├── train_image0.jpg
    ...
```

For map-style datasets (like the ones currently in `torchvision.datasets`), one explicitly selects the files they want
to load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in
`_make_datapipe()` and need to manually `Filter` it to only keep the files we want. If we only generated the data for
the current `config`, the test would also pass if the dataset were missing the filtering, but it would fail on the real
data.

For datasets that are ported from the old API, we already have some mock data in
[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the corresponding test case there and have a
look at the `inject_fake_data` function. There are a few differences though:

- `tmp_dir` corresponds to `root`, but is a `str` rather than a
  [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like
  `folder = pathlib.Path(tmp_dir)`. This is not needed.
- Although both parameters are called `config`, the value in the new tests is a namespace. Thus, please use
  `config.foo` over `config["foo"]` to enhance readability.
- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the
  new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files
  specified in the dataset.
- As explained in the paragraph above, the data generated by `inject_fake_data` is often "incomplete" and only valid
  for the given config. Make sure you follow the instructions above.

The function should return an integer indicating the number of samples in the dataset for the current `config`.
Preferably, this number should be different for different `config`'s to have more confidence in the dataset
implementation.
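
Putting the above together, a mock for the hypothetical folder layout shown earlier could look roughly like this
sketch. It is meant to live in `test/builtin_dataset_mocks.py`; the `split` option and the dummy `.jpg` files are
illustrative, and a real mock would create valid images and pack them with helpers like `make_tar` or `make_zip` if the
original data ships as archives:

```python
@register_mock
def my_dataset(info, root, config):
    # generate a tiny version of the original layout for *all* splits,
    # regardless of the currently requested config
    num_samples_per_split = dict(train=3, test=2)
    for split, num_samples in num_samples_per_split.items():
        split_folder = root / split
        split_folder.mkdir()
        for idx in range(num_samples):
            (split_folder / f"{split}_image{idx}.jpg").touch()

    # return the number of samples for the *current* config only
    return num_samples_per_split[config.split]
```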

Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets.py -k {name}`.

## FAQ

### How do I start?

Get the skeleton of your dataset class ready with all 3 methods. For `_make_datapipe()`, you can just do
`return resource_dps[0]` to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset and it will be
instantiable via `datasets.load("mydataset")`. In a separate script, try something like

```py
from torchvision.prototype import datasets

dataset = datasets.load("mydataset")
for sample in dataset:
    print(sample)
    break
# Or you can also inspect the sample in a debugger
```

This will give you an idea of what the first datapipe in `resource_dps` contains. You can also do that with
`resource_dps[1]` or `resource_dps[2]` (etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.

### How do I handle a dataset that defines many categories?

As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only be set directly for ten categories or
fewer. If more categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line
specifies a category. If `$NAME` matches the name of the dataset (which it definitely should!) it will be automatically
loaded if `categories=` is not set.

In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. It gets
passed the `root` path to the resources, but they have to be manually loaded, e.g.
`self.resources(config)[0].load(root)`. The method should return a sequence of strings representing the category names.
To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.
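
For an image-folder style dataset, a `_generate_categories` implementation could be sketched as follows; obtaining the
config via `self.default_config` and the folder-name logic are assumptions for illustration, not the exact prototype
API:

```python
import pathlib


def _generate_categories(self, root):
    # `root` is the folder that contains the (already downloaded) resources
    dp = self.resources(self.default_config)[0].load(root)
    # each item is a (path, file handle) tuple; in an image folder layout
    # the name of the parent folder is the category
    return sorted({pathlib.Path(path).parent.name for path, _ in dp})
```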

### What if a resource file forms an I/O bottleneck?

In general, we are OK with the small performance hit of iterating archives rather than their extracted content.
However, if the performance hit becomes significant, the archives can still be decompressed or extracted. To do this,
the `decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex
cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should
return the `pathlib.Path` of the preprocessed file or folder.
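
For example, a resource that should be extracted once after download rather than iterated as an archive might be
declared like this; the URL and checksum are placeholders, and the `extract` flag is the one described above:

```python
archive = HttpResource(
    "https://my-dataset.example.com/images.tar.gz",  # placeholder URL
    sha256="...",  # fill in the real checksum
    extract=True,  # extract once after download instead of iterating the archive directly
)
```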