
Commit 5130e47

Merge branch 'main' into improved-affine
2 parents: 2e04aa2 + bf073e7

File tree

6 files changed: +251 −3 lines changed

docs/source/datasets.rst

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
 Country211
 DTD
 EMNIST
+EuroSAT
 FakeData
 FashionMNIST
 FER2013
@@ -74,6 +75,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
 SBU
 SEMEION
 Sintel
+StanfordCars
 STL10
 SUN397
 SVHN

docs/source/index.rst

Lines changed: 3 additions & 3 deletions
@@ -31,13 +31,13 @@ architectures, and common image transformations for computer vision.
 :maxdepth: 2
 :caption: Package Reference

-datasets
 transforms
 models
-feature_extraction
+datasets
+utils
 ops
 io
-utils
+feature_extraction

 .. toctree::
 :maxdepth: 1

test/test_datasets.py

Lines changed: 65 additions & 0 deletions
@@ -2169,6 +2169,27 @@ def inject_fake_data(self, tmpdir, config):
         return num_sequences * (num_examples_per_sequence - 1)


+class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.EuroSAT
+    FEATURE_TYPES = (PIL.Image.Image, int)
+
+    def inject_fake_data(self, tmpdir, config):
+        data_folder = os.path.join(tmpdir, "eurosat", "2750")
+        os.makedirs(data_folder)
+
+        num_examples_per_class = 3
+        classes = ("AnnualCrop", "Forest")
+        for cls in classes:
+            datasets_utils.create_image_folder(
+                root=data_folder,
+                name=cls,
+                file_name_fn=lambda idx: f"{cls}_{idx}.jpg",
+                num_examples=num_examples_per_class,
+            )
+
+        return len(classes) * num_examples_per_class
+
+
 class Food101TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Food101
     FEATURE_TYPES = (PIL.Image.Image, int)

@@ -2514,6 +2535,50 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
         return (image_id, class_id, species, breed_id)


+class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.StanfordCars
+    REQUIRED_PACKAGES = ("scipy",)
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
+
+    def inject_fake_data(self, tmpdir, config):
+        import scipy.io as io
+        from numpy.core.records import fromarrays
+
+        num_examples = {"train": 5, "test": 7}[config["split"]]
+        num_classes = 3
+        base_folder = pathlib.Path(tmpdir) / "stanford_cars"
+
+        devkit = base_folder / "devkit"
+        devkit.mkdir(parents=True)
+
+        if config["split"] == "train":
+            images_folder_name = "cars_train"
+            annotations_mat_path = devkit / "cars_train_annos.mat"
+        else:
+            images_folder_name = "cars_test"
+            annotations_mat_path = base_folder / "cars_test_annos_withlabels.mat"
+
+        datasets_utils.create_image_folder(
+            root=base_folder,
+            name=images_folder_name,
+            file_name_fn=lambda image_index: f"{image_index:5d}.jpg",
+            num_examples=num_examples,
+        )
+
+        classes = np.random.randint(1, num_classes + 1, num_examples, dtype=np.uint8)
+        fnames = [f"{i:5d}.jpg" for i in range(num_examples)]
+        rec_array = fromarrays(
+            [classes, fnames],
+            names=["class", "fname"],
+        )
+        io.savemat(annotations_mat_path, {"annotations": rec_array})
+
+        random_class_names = ["random_name"] * num_classes
+        io.savemat(devkit / "cars_meta.mat", {"class_names": random_class_names})
+
+        return num_examples
+
+
 class Country211TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Country211
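
For orientation (this note is not part of the diff): ImageDatasetTestCase subclasses exercise a dataset class against the fake on-disk layout that inject_fake_data builds, so the tests never download the real archives. Below is a rough sketch of the equivalent EuroSAT layout built outside the test harness; the folder and file names follow the diff above, everything else is illustrative and assumes a torchvision build containing this commit:

    # Illustrative sketch only: builds the root/eurosat/2750/<class>/ layout that the
    # new EuroSAT class checks for, then loads it without download=True.
    import pathlib

    from PIL import Image
    from torchvision import datasets

    root = pathlib.Path("fake_data")
    for cls in ("AnnualCrop", "Forest"):  # two of the ten EuroSAT classes
        folder = root / "eurosat" / "2750" / cls
        folder.mkdir(parents=True, exist_ok=True)
        for idx in range(3):
            Image.new("RGB", (64, 64)).save(folder / f"{cls}_{idx}.jpg")

    ds = datasets.EuroSAT(root=str(root))  # _check_exists() finds root/eurosat/2750
    print(len(ds), ds.classes)             # 6 ['AnnualCrop', 'Forest']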

torchvision/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@
 from .coco import CocoCaptions, CocoDetection
 from .country211 import Country211
 from .dtd import DTD
+from .eurosat import EuroSAT
 from .fakedata import FakeData
 from .fer2013 import FER2013
 from .fgvc_aircraft import FGVCAircraft
@@ -31,6 +32,7 @@
 from .sbd import SBDataset
 from .sbu import SBU
 from .semeion import SEMEION
+from .stanford_cars import StanfordCars
 from .stl10 import STL10
 from .sun397 import SUN397
 from .svhn import SVHN
@@ -55,6 +57,7 @@
     "QMNIST",
     "MNIST",
     "KMNIST",
+    "StanfordCars",
     "STL10",
     "SUN397",
     "SVHN",
@@ -98,4 +101,5 @@
     "OxfordIIITPet",
     "Country211",
     "FGVCAircraft",
+    "EuroSAT",
 )

torchvision/datasets/eurosat.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import os
+from typing import Any
+
+from .folder import ImageFolder
+from .utils import download_and_extract_archive
+
+
+class EuroSAT(ImageFolder):
+    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``root/eurosat`` exists.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
+    md5 = "c8fa014336c82ac7804f0398fcb19387"
+
+    def __init__(
+        self,
+        root: str,
+        download: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        self.root = os.path.expanduser(root)
+        self._base_folder = os.path.join(self.root, "eurosat")
+        self._data_folder = os.path.join(self._base_folder, "2750")
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(self._data_folder, **kwargs)
+        self.root = os.path.expanduser(root)
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder)
+
+    def download(self) -> None:
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._base_folder, exist_ok=True)
+        download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
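
A minimal usage sketch of the new class (not part of the commit; the root path and transform are placeholders). With download=True the RGB archive is fetched from the URL above and extracted under root/eurosat; because EuroSAT subclasses ImageFolder, any extra keyword arguments such as transform are passed straight through:

    # Illustrative only: assumes a torchvision build containing this commit.
    from torchvision import datasets, transforms

    transform = transforms.Compose([transforms.ToTensor()])
    eurosat = datasets.EuroSAT(root="data", download=True, transform=transform)

    image, label = eurosat[0]
    print(len(eurosat), image.shape, eurosat.classes[label])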

torchvision/datasets/stanford_cars.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+import pathlib
+from typing import Callable, Optional, Any, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, where each class
+    has been split roughly in a 50-50 split
+
+    .. note::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of dataset
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again."""
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        try:
+            import scipy.io as sio
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = [
+            (
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
+            )
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+        ]
+
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        """Returns pil_image and class_id for given index"""
+        image_path, target = self._samples[idx]
+        pil_image = Image.open(image_path).convert("RGB")
+
+        if self.transform is not None:
+            pil_image = self.transform(pil_image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return pil_image, target
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(
+            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+            download_root=str(self._base_folder),
+            md5="c3b158d763b6e2245038c8ad08e45376",
+        )
+        if self._split == "train":
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                download_root=str(self._base_folder),
+                md5="065e5b463ae28d29e77c1b4b166cfe61",
+            )
+        else:
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                download_root=str(self._base_folder),
+                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+            )
+            download_url(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                root=str(self._base_folder),
+                md5="b0a2b23655a3edd16d84508592a98d10",
+            )
+
+    def _check_exists(self) -> bool:
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
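
And a comparable sketch for StanfordCars (again not part of the commit; scipy must be installed, and the root path and transform are placeholders). Per the docstring, the train and test splits contain 8,144 and 8,041 images respectively, and targets are shifted to start at 0:

    # Illustrative only: requires scipy and a torchvision build containing this commit.
    from torchvision import datasets, transforms

    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    train_set = datasets.StanfordCars(root="data", split="train", download=True, transform=transform)
    test_set = datasets.StanfordCars(root="data", split="test", download=True, transform=transform)

    image, target = train_set[0]            # transformed image and an int label in [0, 195]
    print(len(train_set), len(test_set))    # 8144 8041
    print(train_set.classes[target])        # class name read from devkit/cars_meta.mat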
