from __future__ import print_function
import os
import shutil
import torch
from .folder import ImageFolder
from .utils import check_integrity, download_url

ARCHIVE_DICT = {
    'train': {
        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
        'md5': '1d675b47d978889d74fa0da5fadfb00e',
    },
    'val': {
        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
        'md5': '29b22e2961454d5413ddabcf34fc5622',
    },
    'devkit': {
        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
        'md5': 'fa75699e90414af021442c21a62c3abf',
    }
}

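# meta.bin caches the information parsed from the devkit archive (a
# (wnid_to_classes, val_wnids) tuple, see _save_meta_file below), so the
# devkit only has to be downloaded and parsed once.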
META_DICT = {
    'filename': 'meta.bin',
    'md5': '7e0d3cf156177e4fc47011cdd30ce706',
}


class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train`` or ``val``.
        download (bool, optional): If True, downloads the dataset from the internet
            and puts it in the root directory. If the dataset is already downloaded,
            it is not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in
            the target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
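
    Example (the root path below is illustrative; with ``download=True`` the
    archives are fetched and unpacked on first use)::

        >>> dataset = ImageNet('path/to/imagenet_root/', split='val', download=True)
        >>> img, target = dataset[0]  # PIL image and integer class index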
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        root = self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)

        if download:
            self.download()
        wnid_to_classes = self._load_meta_file()[0]

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        self.root = root

        # ImageFolder indexed the WordNet IDs (the split folder names). Keep
        # that mapping under wnid_to_idx and remap classes/class_to_idx to the
        # human-readable names from the devkit meta file. (Zipping per-image
        # indices against the per-class wnids, as before, misaligns as soon as
        # a class has more than one image.)
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        self.class_to_idx = {cls: idx
                             for idx, clss in enumerate(self.classes)
                             for cls in clss}

    def download(self):
        if not self._check_meta_file_integrity():
            tmpdir = os.path.join(self.root, 'tmp')

            archive_dict = ARCHIVE_DICT['devkit']
            download_and_extract_tar(archive_dict['url'], self.root,
                                     extract_root=tmpdir,
                                     md5=archive_dict['md5'])
            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
            meta = parse_devkit(os.path.join(tmpdir, devkit_folder))
            self._save_meta_file(*meta)

            shutil.rmtree(tmpdir)

        if not os.path.isdir(self.split_folder):
            archive_dict = ARCHIVE_DICT[self.split]
            download_and_extract_tar(archive_dict['url'], self.root,
                                     extract_root=self.split_folder,
                                     md5=archive_dict['md5'])

            if self.split == 'train':
                prepare_train_folder(self.split_folder)
            elif self.split == 'val':
                val_wnids = self._load_meta_file()[1]
                prepare_val_folder(self.split_folder, val_wnids)
        else:
            msg = ("You set download=True, but a folder '{}' already exists "
                   "in the root directory. If you want to re-download or "
                   "re-extract the archive, delete the folder.")
            print(msg.format(self.split))

    @property
    def meta_file(self):
        return os.path.join(self.root, META_DICT['filename'])

    def _check_meta_file_integrity(self):
        return check_integrity(self.meta_file, META_DICT['md5'])

    def _load_meta_file(self):
        if self._check_meta_file_integrity():
            return torch.load(self.meta_file)
        else:
            raise RuntimeError("Meta file not found or corrupted. "
                               "You can use download=True to create it.")

    def _save_meta_file(self, wnid_to_class, val_wnids):
        torch.save((wnid_to_class, val_wnids), self.meta_file)

    def _verify_split(self, split):
        if split not in self.valid_splits:
            msg = "Unknown split {}. Valid splits are {}.".format(
                split, ", ".join(self.valid_splits))
            raise ValueError(msg)
        return split

    @property
    def valid_splits(self):
        return 'train', 'val'

    @property
    def split_folder(self):
        return os.path.join(self.root, self.split)

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += ["Split: {}".format(self.split)]
        if hasattr(self, 'transform') and self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transforms: ")
        if hasattr(self, 'target_transform') and self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * 4 + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])


def extract_tar(src, dest=None, gzip=None, delete=False):
    import tarfile

    if dest is None:
        dest = os.path.dirname(src)
    if gzip is None:
        gzip = src.lower().endswith('.gz')

    mode = 'r:gz' if gzip else 'r'
    with tarfile.open(src, mode) as tarfh:
        tarfh.extractall(path=dest)

    if delete:
        os.remove(src)


def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
                             md5=None, **kwargs):
    download_root = os.path.expanduser(download_root)
    # Default to extracting next to the downloaded archive (the original
    # self-assignment 'extract_root = extract_root' left it as None).
    if extract_root is None:
        extract_root = download_root
    if filename is None:
        filename = os.path.basename(url)

    if not check_integrity(os.path.join(download_root, filename), md5):
        download_url(url, download_root, filename=filename, md5=md5)

    extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)


def parse_devkit(root):
    idx_to_wnid, wnid_to_classes = parse_meta(root)
    val_idcs = parse_val_groundtruth(root)
    val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
    return wnid_to_classes, val_wnids


def parse_meta(devkit_root, path='data', filename='meta.mat'):
    import scipy.io as sio

    metafile = os.path.join(devkit_root, path, filename)
    meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
    # Keep only the leaf synsets (num_children == 0): these are the 1000
    # classification classes; the remaining entries are internal WordNet nodes.
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
            if num_children == 0]
    idcs, wnids, classes = list(zip(*meta))[:3]
    # Each synset's names field is a comma-separated list,
    # e.g. 'kit fox, Vulpes macrotis'.
    classes = [tuple(clss.split(', ')) for clss in classes]
    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
    wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
    return idx_to_wnid, wnid_to_classes


def parse_val_groundtruth(devkit_root, path='data',
                          filename='ILSVRC2012_validation_ground_truth.txt'):
    # The ground truth file contains one integer class ID per line, one line
    # per validation image.
    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
        val_idcs = txtfh.readlines()
    return [int(val_idx) for val_idx in val_idcs]


def prepare_train_folder(folder):
    # The train archive contains one inner tar per class, named '<wnid>.tar';
    # unpack each into a folder of the same name so ImageFolder can read it.
    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
        extract_tar(archive, os.path.splitext(archive)[0], delete=True)


def prepare_val_folder(folder, wnids):
    # The val archive is flat; sort the image files so they line up with the
    # ground truth order, then move each into a folder named after its wnid.
    img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])

    for wnid in set(wnids):
        os.mkdir(os.path.join(folder, wnid))

    for wnid, img_file in zip(wnids, img_files):
        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))


def _splitexts(root):
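    """Split off all extensions at once, unlike os.path.splitext() which only
    strips the last one, e.g. 'ILSVRC2012_devkit_t12.tar.gz'
    -> ('ILSVRC2012_devkit_t12', '.tar.gz').
    """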
    exts = []
    ext = '.'
    while ext:
        root, ext = os.path.splitext(root)
        exts.append(ext)
    return root, ''.join(reversed(exts))