
Commit 8ba482a

Add support for Stanford cars dataset (#5166)
* [WIP] added stanford_cars
* [WIP] added stanfordCars to docs
* [WIP] minor edits
* [WIP] minor edits
* edited StanfordCars class
* Adding Testcase for stanford cars
* Added Testcase for stanford cars
* Added Testcase for stanford cars
* minor edit
* made changes as per the suggestions
* fixed typo in naming stanford_cars.py
* cars_meta.mat file will be created in test
* Some cleanups
* Sigh
* don't convert to strings

Co-authored-by: Nicolas Hug <[email protected]>
1 parent 57a77c4 commit 8ba482a

4 files changed: 168 additions & 0 deletions


docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     SBU
     SEMEION
     Sintel
+    StanfordCars
     STL10
     SUN397
     SVHN

test/test_datasets.py

Lines changed: 44 additions & 0 deletions
@@ -2535,6 +2535,50 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
         return (image_id, class_id, species, breed_id)


+class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.StanfordCars
+    REQUIRED_PACKAGES = ("scipy",)
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
+
+    def inject_fake_data(self, tmpdir, config):
+        import scipy.io as io
+        from numpy.core.records import fromarrays
+
+        num_examples = {"train": 5, "test": 7}[config["split"]]
+        num_classes = 3
+        base_folder = pathlib.Path(tmpdir) / "stanford_cars"
+
+        devkit = base_folder / "devkit"
+        devkit.mkdir(parents=True)
+
+        if config["split"] == "train":
+            images_folder_name = "cars_train"
+            annotations_mat_path = devkit / "cars_train_annos.mat"
+        else:
+            images_folder_name = "cars_test"
+            annotations_mat_path = base_folder / "cars_test_annos_withlabels.mat"
+
+        datasets_utils.create_image_folder(
+            root=base_folder,
+            name=images_folder_name,
+            file_name_fn=lambda image_index: f"{image_index:5d}.jpg",
+            num_examples=num_examples,
+        )
+
+        classes = np.random.randint(1, num_classes + 1, num_examples, dtype=np.uint8)
+        fnames = [f"{i:5d}.jpg" for i in range(num_examples)]
+        rec_array = fromarrays(
+            [classes, fnames],
+            names=["class", "fname"],
+        )
+        io.savemat(annotations_mat_path, {"annotations": rec_array})
+
+        random_class_names = ["random_name"] * num_classes
+        io.savemat(devkit / "cars_meta.mat", {"class_names": random_class_names})
+
+        return num_examples
+
+
 class Country211TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Country211
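
For reference, here is a rough standalone sketch of the on-disk layout and annotation format that inject_fake_data mimics (the root path, class ids, and file names below are made up for illustration; scipy and numpy are assumed to be installed):

    import pathlib

    import numpy as np
    import scipy.io as io

    base = pathlib.Path("/tmp/fake_cars") / "stanford_cars"   # hypothetical root; the dataset reads <root>/stanford_cars
    (base / "devkit").mkdir(parents=True, exist_ok=True)
    (base / "cars_train").mkdir(exist_ok=True)                # the JPEG files themselves go in here

    # One record per image: a 1-based class id plus the image file name.
    records = np.core.records.fromarrays(
        [np.array([1, 2], dtype=np.uint8), ["00000.jpg", "00001.jpg"]],
        names=["class", "fname"],
    )
    io.savemat(str(base / "devkit" / "cars_train_annos.mat"), {"annotations": records})

    # Class-id-to-name mapping, later exposed as StanfordCars.classes.
    io.savemat(str(base / "devkit" / "cars_meta.mat"), {"class_names": ["sedan", "coupe"]})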

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
 from .sbd import SBDataset
 from .sbu import SBU
 from .semeion import SEMEION
+from .stanford_cars import StanfordCars
 from .stl10 import STL10
 from .sun397 import SUN397
 from .svhn import SVHN
@@ -56,6 +57,7 @@
     "QMNIST",
     "MNIST",
     "KMNIST",
+    "StanfordCars",
     "STL10",
     "SUN397",
     "SVHN",

torchvision/datasets/stanford_cars.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+import pathlib
+from typing import Callable, Optional, Any, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, where each class
+    has been split roughly 50-50.
+
+    .. note::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of dataset
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in the root directory. If dataset is already downloaded, it is not
+            downloaded again."""
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        try:
+            import scipy.io as sio
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = [
+            (
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
+            )
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+        ]
+
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        """Returns pil_image and class_id for given index"""
+        image_path, target = self._samples[idx]
+        pil_image = Image.open(image_path).convert("RGB")
+
+        if self.transform is not None:
+            pil_image = self.transform(pil_image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return pil_image, target
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(
+            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+            download_root=str(self._base_folder),
+            md5="c3b158d763b6e2245038c8ad08e45376",
+        )
+        if self._split == "train":
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                download_root=str(self._base_folder),
+                md5="065e5b463ae28d29e77c1b4b166cfe61",
+            )
+        else:
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                download_root=str(self._base_folder),
+                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+            )
+            download_url(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                root=str(self._base_folder),
+                md5="b0a2b23655a3edd16d84508592a98d10",
+            )
+
+    def _check_exists(self) -> bool:
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
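
Putting it together, a minimal usage sketch for the new class (a sketch only: it assumes a torchvision build containing this commit, scipy installed, and the Stanford download URLs listed above still being reachable; the root directory is arbitrary):

    from torchvision import datasets, transforms

    cars = datasets.StanfordCars(
        root="data",                      # files end up under data/stanford_cars/
        split="train",
        transform=transforms.ToTensor(),
        download=True,                    # fetches the devkit plus the split-specific archive via download()
    )

    image, target = cars[0]               # __getitem__ returns the (transformed) image and a 0-based class id
    print(len(cars), cars.classes[target])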
