|
| 1 | +import pathlib |
| 2 | +from typing import Callable, Optional, Any, Tuple |
| 3 | + |
| 4 | +from PIL import Image |
| 5 | + |
| 6 | +from .utils import download_and_extract_archive, download_url, verify_str_arg |
| 7 | +from .vision import VisionDataset |
| 8 | + |
| 9 | + |
| 10 | +class StanfordCars(VisionDataset): |
| 11 | + """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset |
| 12 | +
|
| 13 | + The Cars dataset contains 16,185 images of 196 classes of cars. The data is |
| 14 | + split into 8,144 training images and 8,041 testing images, where each class |
| 15 | + has been split roughly in a 50-50 split |
| 16 | +
|
| 17 | + .. note:: |
| 18 | +
|
| 19 | + This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format. |
| 20 | +
|
| 21 | + Args: |
| 22 | + root (string): Root directory of dataset |
| 23 | + split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. |
| 24 | + transform (callable, optional): A function/transform that takes in an PIL image |
| 25 | + and returns a transformed version. E.g, ``transforms.RandomCrop`` |
| 26 | + target_transform (callable, optional): A function/transform that takes in the |
| 27 | + target and transforms it. |
| 28 | + download (bool, optional): If True, downloads the dataset from the internet and |
| 29 | + puts it in root directory. If dataset is already downloaded, it is not |
| 30 | + downloaded again.""" |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + root: str, |
| 35 | + split: str = "train", |
| 36 | + transform: Optional[Callable] = None, |
| 37 | + target_transform: Optional[Callable] = None, |
| 38 | + download: bool = False, |
| 39 | + ) -> None: |
| 40 | + |
| 41 | + try: |
| 42 | + import scipy.io as sio |
| 43 | + except ImportError: |
| 44 | + raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") |
| 45 | + |
| 46 | + super().__init__(root, transform=transform, target_transform=target_transform) |
| 47 | + |
| 48 | + self._split = verify_str_arg(split, "split", ("train", "test")) |
| 49 | + self._base_folder = pathlib.Path(root) / "stanford_cars" |
| 50 | + devkit = self._base_folder / "devkit" |
| 51 | + |
| 52 | + if self._split == "train": |
| 53 | + self._annotations_mat_path = devkit / "cars_train_annos.mat" |
| 54 | + self._images_base_path = self._base_folder / "cars_train" |
| 55 | + else: |
| 56 | + self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" |
| 57 | + self._images_base_path = self._base_folder / "cars_test" |
| 58 | + |
| 59 | + if download: |
| 60 | + self.download() |
| 61 | + |
| 62 | + if not self._check_exists(): |
| 63 | + raise RuntimeError("Dataset not found. You can use download=True to download it") |
| 64 | + |
| 65 | + self._samples = [ |
| 66 | + ( |
| 67 | + str(self._images_base_path / annotation["fname"]), |
| 68 | + annotation["class"] - 1, # Original target mapping starts from 1, hence -1 |
| 69 | + ) |
| 70 | + for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] |
| 71 | + ] |
| 72 | + |
| 73 | + self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() |
| 74 | + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} |
| 75 | + |
| 76 | + def __len__(self) -> int: |
| 77 | + return len(self._samples) |
| 78 | + |
| 79 | + def __getitem__(self, idx: int) -> Tuple[Any, Any]: |
| 80 | + """Returns pil_image and class_id for given index""" |
| 81 | + image_path, target = self._samples[idx] |
| 82 | + pil_image = Image.open(image_path).convert("RGB") |
| 83 | + |
| 84 | + if self.transform is not None: |
| 85 | + pil_image = self.transform(pil_image) |
| 86 | + if self.target_transform is not None: |
| 87 | + target = self.target_transform(target) |
| 88 | + return pil_image, target |
| 89 | + |
| 90 | + def download(self) -> None: |
| 91 | + if self._check_exists(): |
| 92 | + return |
| 93 | + |
| 94 | + download_and_extract_archive( |
| 95 | + url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", |
| 96 | + download_root=str(self._base_folder), |
| 97 | + md5="c3b158d763b6e2245038c8ad08e45376", |
| 98 | + ) |
| 99 | + if self._split == "train": |
| 100 | + download_and_extract_archive( |
| 101 | + url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", |
| 102 | + download_root=str(self._base_folder), |
| 103 | + md5="065e5b463ae28d29e77c1b4b166cfe61", |
| 104 | + ) |
| 105 | + else: |
| 106 | + download_and_extract_archive( |
| 107 | + url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", |
| 108 | + download_root=str(self._base_folder), |
| 109 | + md5="4ce7ebf6a94d07f1952d94dd34c4d501", |
| 110 | + ) |
| 111 | + download_url( |
| 112 | + url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", |
| 113 | + root=str(self._base_folder), |
| 114 | + md5="b0a2b23655a3edd16d84508592a98d10", |
| 115 | + ) |
| 116 | + |
| 117 | + def _check_exists(self) -> bool: |
| 118 | + if not (self._base_folder / "devkit").is_dir(): |
| 119 | + return False |
| 120 | + |
| 121 | + return self._annotations_mat_path.exists() and self._images_base_path.is_dir() |
0 commit comments