|
| 1 | +from PIL import Image |
| 2 | +import os |
| 3 | +import os.path |
| 4 | +from typing import Any, Callable, Dict, List, Optional, Union, Tuple |
| 5 | + |
| 6 | +from .vision import VisionDataset |
| 7 | +from .utils import download_and_extract_archive, verify_str_arg |
| 8 | + |
# Taxonomic ranks encoded in the 2021 directory names, from most general to most specific.
CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

# All archives are hosted in the same S3 bucket; only the trailing path differs.
_INAT_S3_BASE = "https://ml-inat-competition-datasets.s3.amazonaws.com"

# Dataset version -> URL of the image archive for that version.
DATASET_URLS = {
    "2017": f"{_INAT_S3_BASE}/2017/train_val_images.tar.gz",
    "2018": f"{_INAT_S3_BASE}/2018/train_val2018.tar.gz",
    "2019": f"{_INAT_S3_BASE}/2019/train_val2019.tar.gz",
    "2021_train": f"{_INAT_S3_BASE}/2021/train.tar.gz",
    "2021_train_mini": f"{_INAT_S3_BASE}/2021/train_mini.tar.gz",
    "2021_valid": f"{_INAT_S3_BASE}/2021/val.tar.gz",
}

# Dataset version -> expected MD5 checksum of the archive at DATASET_URLS[version].
DATASET_MD5 = {
    "2017": "7c784ea5e424efaec655bd392f87301f",
    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
    "2021_train": "38a7bb733f7a09214d44293460ec0021",
    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
}
| 28 | + |
| 29 | + |
class INaturalist(VisionDataset):
    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (string): Root directory of dataset where the image files are stored.
            The images of a given version live in the ``<root>/<version>`` sub-directory.
            This class does not require/use annotation files.
        version (string, optional): Which version of the dataset to download/use. One of
            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
            Default: `2021_train`.
        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:

            - ``full``: the full category (species)
            - ``kingdom``: e.g. "Animalia"
            - ``phylum``: e.g. "Arthropoda"
            - ``class``: e.g. "Insecta"
            - ``order``: e.g. "Coleoptera"
            - ``family``: e.g. "Cleridae"
            - ``genus``: e.g. "Trichodes"

            for 2017-2019 versions, one of:

            - ``full``: the full (numeric) category
            - ``super``: the super category, e.g. "Amphibians"

            Can also be a list to output a tuple with all specified target types.
            Defaults to ``full``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. A ``RuntimeError`` is raised if the
            ``<root>/<version>`` directory already exists and is not empty; delete that
            directory to re-download or re-extract the images.
    """

    def __init__(
        self,
        root: str,
        version: str = "2021_train",
        target_type: Union[List[str], str] = "full",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

        # The dataset root handed to the base class is the per-version sub-directory.
        super().__init__(os.path.join(root, version),
                         transform=transform,
                         target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # list indexed by full category id, containing the category directory
        # (relative to self.root)
        self.all_categories: List[str] = []

        # map: category type -> name of category -> index
        self.categories_index: Dict[str, Dict[str, int]] = {}

        # list indexed by category id, containing mapping from category type -> index
        self.categories_map: List[Dict[str, int]] = []

        if not isinstance(target_type, list):
            target_type = [target_type]
        if self.version[:4] == "2021":
            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021))
                                for t in target_type]
            self._init_2021()
        else:
            self.target_type = [verify_str_arg(t, "target_type", ("full", "super"))
                                for t in target_type]
            self._init_pre2021()

        # index of all files: (full category id, filename)
        self.index: List[Tuple[int, str]] = []

        for dir_index, dir_name in enumerate(self.all_categories):
            files = os.listdir(os.path.join(self.root, dir_name))
            for fname in files:
                self.index.append((dir_index, fname))

    def _init_2021(self) -> None:
        """Initialize based on 2021 layout

        Every entry of ``self.root`` is expected to be named with 8 ``_``-separated
        pieces: a zero-padded 5-digit category id equal to the entry's position in
        sorted order, followed by the six ranks of ``CATEGORIES_2021`` and one
        trailing piece that is not used here.

        Raises:
            RuntimeError: if a directory name does not follow this scheme.
        """

        self.all_categories = sorted(os.listdir(self.root))

        # map: category type -> name of category -> index
        self.categories_index = {
            k: {} for k in CATEGORIES_2021
        }

        for dir_index, dir_name in enumerate(self.all_categories):
            pieces = dir_name.split('_')
            if len(pieces) != 8:
                raise RuntimeError(f'Unexpected category name {dir_name}, wrong number of pieces')
            if pieces[0] != f'{dir_index:05d}':
                raise RuntimeError(f'Unexpected category id {pieces[0]}, expecting {dir_index:05d}')
            cat_map: Dict[str, int] = {}
            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                if name in self.categories_index[cat]:
                    cat_id = self.categories_index[cat][name]
                else:
                    # first time this name is seen for this rank: assign the next free index
                    cat_id = len(self.categories_index[cat])
                    self.categories_index[cat][name] = cat_id
                cat_map[cat] = cat_id
            self.categories_map.append(cat_map)

    def _init_pre2021(self) -> None:
        """Initialize based on 2017-2019 layout

        ``self.root`` is expected to contain one directory per super category, each
        holding one sub-directory per category. For 2018 and 2019 the sub-directory
        name is the numeric category id; for 2017 ids are assigned sequentially in
        sorted order.

        Raises:
            RuntimeError: if a 2018/2019 sub-directory name is not numeric, if a
                category id occurs twice, or if the ids do not cover a contiguous
                range starting at 0.
        """

        # map: category type -> name of category -> index
        self.categories_index = {'super': {}}

        cat_index = 0
        super_categories = sorted(os.listdir(self.root))
        for sindex, scat in enumerate(super_categories):
            self.categories_index["super"][scat] = sindex
            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
            for subcat in subcategories:
                if self.version == "2017":
                    # this version does not use ids as directory names
                    subcat_i = cat_index
                    cat_index += 1
                else:
                    try:
                        subcat_i = int(subcat)
                    except ValueError as err:
                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}") from err
                if subcat_i >= len(self.categories_map):
                    # grow both lists with empty placeholders up to and including subcat_i
                    missing = subcat_i - len(self.categories_map) + 1
                    self.categories_map.extend({} for _ in range(missing))
                    self.all_categories.extend([""] * missing)
                if self.categories_map[subcat_i]:
                    raise RuntimeError(f"Duplicate category {subcat}")
                self.categories_map[subcat_i] = {'super': sindex}
                self.all_categories[subcat_i] = os.path.join(scat, subcat)

        # validate the dictionary: a placeholder left empty means an id was never seen
        for cindex, c in enumerate(self.categories_map):
            if not c:
                raise RuntimeError(f"Missing category {cindex}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """

        cat_id, fname = self.index[index]
        img = Image.open(os.path.join(self.root,
                                      self.all_categories[cat_id],
                                      fname))

        target: Any = []
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
        # a single requested target type is returned bare, several as a tuple
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.index)

    def category_name(self, category_type: str, category_id: int) -> str:
        """
        Args:
            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
            category_id(int): an index (class id) from this category

        Returns:
            the name of the category

        Raises:
            ValueError: if ``category_type`` is not available for this dataset version
                or ``category_id`` is not a known index of that category type.
        """
        if category_type == "full":
            return self.all_categories[category_id]

        if category_type not in self.categories_index:
            raise ValueError(f"Invalid category type '{category_type}'")

        # categories_index maps name -> index, so the reverse lookup is a linear scan
        for name, known_id in self.categories_index[category_type].items():
            if known_id == category_id:
                return name
        raise ValueError(f"Invalid category id {category_id} for {category_type}")

    def _check_integrity(self) -> bool:
        """Return True if the version directory exists and contains at least one entry."""
        # isdir (rather than exists) so that a stray regular file at this path is
        # reported as "not found" instead of making os.listdir raise
        return os.path.isdir(self.root) and len(os.listdir(self.root)) > 0

    def download(self) -> None:
        """Download and extract the archive of this version, then move it to ``self.root``.

        Raises:
            RuntimeError: if ``self.root`` already exists and is not empty, or if the
                extracted directory is not found where it is expected.
        """
        if self._check_integrity():
            raise RuntimeError(
                f"The directory {self.root} already exists. "
                f"If you want to re-download or re-extract the images, delete the directory."
            )

        base_root = os.path.dirname(self.root)

        download_and_extract_archive(
            DATASET_URLS[self.version],
            base_root,
            filename=f"{self.version}.tgz",
            md5=DATASET_MD5[self.version])

        # The archive is expected to extract into a directory named after the archive
        # file without its ".tar.gz" extension. str.rstrip(".tar.gz") must not be used
        # for this: it strips any trailing run of those characters, not the suffix.
        extracted_name = os.path.basename(DATASET_URLS[self.version])
        archive_suffix = ".tar.gz"
        if extracted_name.endswith(archive_suffix):
            extracted_name = extracted_name[:-len(archive_suffix)]

        orig_dir_name = os.path.join(base_root, extracted_name)
        if not os.path.exists(orig_dir_name):
            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
        os.rename(orig_dir_name, self.root)
        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
0 commit comments