Commit ef71159

dgenzel and fmassa authored
Add iNaturalist dataset (#4123)
* Add iNaturalist dataset
* Add download support
* Address comments

Co-authored-by: dgenzel <[email protected]>
Co-authored-by: Francisco Massa <[email protected]>
1 parent ac83a2c commit ef71159

File tree

4 files changed: +297 / -1 lines changed

docs/source/datasets.rst

Lines changed: 6 additions & 0 deletions
@@ -122,6 +122,12 @@ ImageNet
 .. note ::
     This requires `scipy` to be installed
 
+iNaturalist
+~~~~~~~~~~~
+
+.. autoclass:: INaturalist
+  :members: __getitem__, category_name
+
 Kinetics-400
 ~~~~~~~~~~~~

test/test_datasets.py

Lines changed: 38 additions & 0 deletions
@@ -1755,5 +1755,43 @@ def test_images_download_preexisting(self):
         pass
 
 
+class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.INaturalist
+    FEATURE_TYPES = (PIL.Image.Image, (int, tuple))
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]),
+        version=("2021_train",),
+    )
+
+    def inject_fake_data(self, tmpdir, config):
+        categories = [
+            "00000_Akingdom_0phylum_Aclass_Aorder_Afamily_Agenus_Aspecies",
+            "00001_Akingdom_1phylum_Aclass_Border_Afamily_Bgenus_Aspecies",
+            "00002_Akingdom_2phylum_Cclass_Corder_Cfamily_Cgenus_Cspecies",
+        ]
+
+        num_images_per_category = 3
+        for category in categories:
+            datasets_utils.create_image_folder(
+                root=os.path.join(tmpdir, config["version"]),
+                name=category,
+                file_name_fn=lambda idx: f"image_{idx + 1:04d}.jpg",
+                num_examples=num_images_per_category,
+            )
+
+        return num_images_per_category * len(categories)
+
+    def test_targets(self):
+        target_types = ["kingdom", "phylum", "class", "order", "family", "genus", "full"]
+
+        with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _):
+            items = [d[1] for d in dataset]
+            for i, item in enumerate(items):
+                self.assertEqual(dataset.category_name("kingdom", item[0]), "Akingdom")
+                self.assertEqual(dataset.category_name("phylum", item[1]), f"{i // 3}phylum")
+                self.assertEqual(item[6], i // 3)
+
+
 if __name__ == "__main__":
     unittest.main()
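For orientation, the fake category folders above follow the 2021 directory-name convention that INaturalist._init_2021 parses: a zero-padded category index followed by one name per taxonomic level plus the species epithet, joined by underscores. A minimal standalone sketch of that decomposition (using one of the fixture names above; not part of the commit):

# Sketch: split a 2021-style category directory name into its taxonomic pieces,
# mirroring the parsing done in INaturalist._init_2021.
CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

dir_name = "00001_Akingdom_1phylum_Aclass_Border_Afamily_Bgenus_Aspecies"
pieces = dir_name.split("_")
assert len(pieces) == 8  # index + six taxonomic levels + species epithet

full_category_id = int(pieces[0])                 # species-level id: 1
levels = dict(zip(CATEGORIES_2021, pieces[1:7]))  # e.g. {"kingdom": "Akingdom", ..., "genus": "Bgenus"}
print(full_category_id, levels)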

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from .ucf101 import UCF101
 from .places365 import Places365
 from .kitti import Kitti
+from .inaturalist import INaturalist
 
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -35,5 +36,5 @@
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
            'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
            'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
-           'Places365', 'Kitti',
+           'Places365', 'Kitti', "INaturalist"
            )
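With the export above in place, the class should be importable from the public datasets namespace; a quick hedged sanity check (assuming a torchvision build that includes this commit):

from torchvision import datasets

# INaturalist is now re-exported alongside the other datasets.
assert "INaturalist" in datasets.__all__
print(datasets.INaturalist)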

torchvision/datasets/inaturalist.py

Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
from PIL import Image
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from .vision import VisionDataset
from .utils import download_and_extract_archive, verify_str_arg

CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

DATASET_URLS = {
    '2017': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz',
    '2018': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz',
    '2019': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz',
    '2021_train': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz',
    '2021_train_mini': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz',
    '2021_valid': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz',
}

DATASET_MD5 = {
    '2017': '7c784ea5e424efaec655bd392f87301f',
    '2018': 'b1c6952ce38f31868cc50ea72d066cc3',
    '2019': 'c60a6e2962c9b8ccbd458d12c8582644',
    '2021_train': '38a7bb733f7a09214d44293460ec0021',
    '2021_train_mini': 'db6ed8330e634445efc8fec83ae81442',
    '2021_valid': 'f6f6e0e242e3d4c9569ba56400938afc',
}


class INaturalist(VisionDataset):
    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (string): Root directory of dataset where the image files are stored.
            This class does not require/use annotation files.
        version (string, optional): Which version of the dataset to download/use. One of
            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
            Default: `2021_train`.
        target_type (string or list, optional): Type of target to use. For the 2021 versions, one of:

            - ``full``: the full category (species)
            - ``kingdom``: e.g. "Animalia"
            - ``phylum``: e.g. "Arthropoda"
            - ``class``: e.g. "Insecta"
            - ``order``: e.g. "Coleoptera"
            - ``family``: e.g. "Cleridae"
            - ``genus``: e.g. "Trichodes"

            For the 2017-2019 versions, one of:

            - ``full``: the full (numeric) category
            - ``super``: the super category, e.g. "Amphibians"

            Can also be a list to output a tuple with all specified target types.
            Defaults to ``full``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        version: str = "2021_train",
        target_type: Union[List[str], str] = "full",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

        super(INaturalist, self).__init__(os.path.join(root, version),
                                          transform=transform,
                                          target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.all_categories: List[str] = []

        # map: category type -> name of category -> index
        self.categories_index: Dict[str, Dict[str, int]] = {}

        # list indexed by category id, containing mapping from category type -> index
        self.categories_map: List[Dict[str, int]] = []

        if not isinstance(target_type, list):
            target_type = [target_type]
        if self.version[:4] == "2021":
            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021))
                                for t in target_type]
            self._init_2021()
        else:
            self.target_type = [verify_str_arg(t, "target_type", ("full", "super"))
                                for t in target_type]
            self._init_pre2021()

        # index of all files: (full category id, filename)
        self.index: List[Tuple[int, str]] = []

        for dir_index, dir_name in enumerate(self.all_categories):
            files = os.listdir(os.path.join(self.root, dir_name))
            for fname in files:
                self.index.append((dir_index, fname))

    def _init_2021(self) -> None:
        """Initialize based on the 2021 layout"""

        self.all_categories = sorted(os.listdir(self.root))

        # map: category type -> name of category -> index
        self.categories_index = {
            k: {} for k in CATEGORIES_2021
        }

        for dir_index, dir_name in enumerate(self.all_categories):
            pieces = dir_name.split('_')
            if len(pieces) != 8:
                raise RuntimeError(f'Unexpected category name {dir_name}, wrong number of pieces')
            if pieces[0] != f'{dir_index:05d}':
                raise RuntimeError(f'Unexpected category id {pieces[0]}, expecting {dir_index:05d}')
            cat_map = {}
            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                if name in self.categories_index[cat]:
                    cat_id = self.categories_index[cat][name]
                else:
                    cat_id = len(self.categories_index[cat])
                    self.categories_index[cat][name] = cat_id
                cat_map[cat] = cat_id
            self.categories_map.append(cat_map)

    def _init_pre2021(self) -> None:
        """Initialize based on the 2017-2019 layout"""

        # map: category type -> name of category -> index
        self.categories_index = {'super': {}}

        cat_index = 0
        super_categories = sorted(os.listdir(self.root))
        for sindex, scat in enumerate(super_categories):
            self.categories_index["super"][scat] = sindex
            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
            for subcat in subcategories:
                if self.version == "2017":
                    # this version does not use ids as directory names
                    subcat_i = cat_index
                    cat_index += 1
                else:
                    try:
                        subcat_i = int(subcat)
                    except ValueError:
                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
                if subcat_i >= len(self.categories_map):
                    old_len = len(self.categories_map)
                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
                if self.categories_map[subcat_i]:
                    raise RuntimeError(f"Duplicate category {subcat}")
                self.categories_map[subcat_i] = {'super': sindex}
                self.all_categories[subcat_i] = os.path.join(scat, subcat)

        # validate the dictionary
        for cindex, c in enumerate(self.categories_map):
            if not c:
                raise RuntimeError(f"Missing category {cindex}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target is specified by target_type.
        """

        cat_id, fname = self.index[index]
        img = Image.open(os.path.join(self.root,
                                      self.all_categories[cat_id],
                                      fname))

        target: Any = []
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.index)

    def category_name(self, category_type: str, category_id: int) -> str:
        """
        Args:
            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
            category_id(int): an index (class id) from this category

        Returns:
            the name of the category
        """
        if category_type == "full":
            return self.all_categories[category_id]
        else:
            if category_type not in self.categories_index:
                raise ValueError(f"Invalid category type '{category_type}'")
            else:
                for name, id in self.categories_index[category_type].items():
                    if id == category_id:
                        return name
                raise ValueError(f"Invalid category id {category_id} for {category_type}")

    def _check_integrity(self) -> bool:
        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0

    def download(self) -> None:
        if self._check_integrity():
            raise RuntimeError(
                f"The directory {self.root} already exists. "
                f"If you want to re-download or re-extract the images, delete the directory."
            )

        base_root = os.path.dirname(self.root)

        download_and_extract_archive(
            DATASET_URLS[self.version],
            base_root,
            filename=f"{self.version}.tgz",
            md5=DATASET_MD5[self.version])

        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
        if not os.path.exists(orig_dir_name):
            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
        os.rename(orig_dir_name, self.root)
        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
