Skip to content

Commit daf07e0

Browse files
authored
Merge branch 'main' into models/convnext
2 parents 9e6fda1 + c27bed4 commit daf07e0

36 files changed

+115
-126
lines changed

mypy.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,7 @@ ignore_missing_imports = True
117117
[mypy-torchdata.*]
118118

119119
ignore_missing_imports = True
120+
121+
[mypy-h5py.*]
122+
123+
ignore_missing_imports = True

test/test_datasets.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,11 +2281,6 @@ def inject_fake_data(self, tmpdir: str, config):
22812281
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
22822282
DATASET_CLASS = datasets.SUN397
22832283

2284-
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
2285-
split=("train", "test"),
2286-
partition=(1, 10, None),
2287-
)
2288-
22892284
def inject_fake_data(self, tmpdir: str, config):
22902285
data_dir = pathlib.Path(tmpdir) / "SUN397"
22912286
data_dir.mkdir()
@@ -2308,18 +2303,7 @@ def inject_fake_data(self, tmpdir: str, config):
23082303
with open(data_dir / "ClassName.txt", "w") as file:
23092304
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
23102305

2311-
if config["partition"] is not None:
2312-
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
2313-
2314-
with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
2315-
file.writelines(
2316-
"\n".join(
2317-
f"/{f_path.relative_to(data_dir).as_posix()}"
2318-
for f_path in random.choices(im_paths, k=num_samples)
2319-
)
2320-
)
2321-
else:
2322-
num_samples = len(im_paths)
2306+
num_samples = len(im_paths)
23232307

23242308
return num_samples
23252309

@@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
23972381
DATASET_CLASS = datasets.GTSRB
23982382
FEATURE_TYPES = (PIL.Image.Image, int)
23992383

2400-
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
2384+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
24012385

24022386
def inject_fake_data(self, tmpdir: str, config):
2403-
root_folder = os.path.join(tmpdir, "GTSRB")
2387+
root_folder = os.path.join(tmpdir, "gtsrb")
24042388
os.makedirs(root_folder, exist_ok=True)
24052389

24062390
# Train data
2407-
train_folder = os.path.join(root_folder, "Training")
2391+
train_folder = os.path.join(root_folder, "GTSRB", "Training")
24082392
os.makedirs(train_folder, exist_ok=True)
24092393

2410-
num_examples = 3
2394+
num_examples = 3 if config["split"] == "train" else 4
24112395
classes = ("00000", "00042", "00012")
24122396
for class_idx in classes:
24132397
datasets_utils.create_image_folder(
@@ -2419,7 +2403,7 @@ def inject_fake_data(self, tmpdir: str, config):
24192403

24202404
total_number_of_examples = num_examples * len(classes)
24212405
# Test data
2422-
test_folder = os.path.join(root_folder, "Final_Test", "Images")
2406+
test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images")
24232407
os.makedirs(test_folder, exist_ok=True)
24242408

24252409
with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:

test/test_prototype_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_naming_conventions(model_fn):
9797
)
9898
@run_if_test_with_prototype
9999
def test_schema_meta_validation(model_fn):
100-
classification_fields = ["size", "categories", "acc@1", "acc@5"]
100+
classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"]
101101
defaults = {
102102
"all": ["task", "architecture", "publication_year", "interpolation", "recipe", "num_params"],
103103
"models": classification_fields,

torchvision/datasets/clevr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
split: str = "train",
3535
transform: Optional[Callable] = None,
3636
target_transform: Optional[Callable] = None,
37-
download: bool = True,
37+
download: bool = False,
3838
) -> None:
3939
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4040
super().__init__(root, transform=transform, target_transform=target_transform)

torchvision/datasets/country211.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
split: str = "train",
3333
transform: Optional[Callable] = None,
3434
target_transform: Optional[Callable] = None,
35-
download: bool = True,
35+
download: bool = False,
3636
) -> None:
3737
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
3838

torchvision/datasets/dtd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ class DTD(VisionDataset):
2121
The partition only changes which split each image belongs to. Thus, regardless of the selected
2222
partition, combining all splits will result in all images.
2323
24-
download (bool, optional): If True, downloads the dataset from the internet and
25-
puts it in root directory. If dataset is already downloaded, it is not
26-
downloaded again.
2724
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
2825
version. E.g, ``transforms.RandomCrop``.
2926
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
27+
download (bool, optional): If True, downloads the dataset from the internet and
28+
puts it in root directory. If dataset is already downloaded, it is not
29+
downloaded again. Default is False.
3030
"""
3131

3232
_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
@@ -37,9 +37,9 @@ def __init__(
3737
root: str,
3838
split: str = "train",
3939
partition: int = 1,
40-
download: bool = True,
4140
transform: Optional[Callable] = None,
4241
target_transform: Optional[Callable] = None,
42+
download: bool = False,
4343
) -> None:
4444
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
4545
if not isinstance(partition, int) and not (1 <= partition <= 10):

torchvision/datasets/eurosat.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any
2+
from typing import Callable, Optional
33

44
from .folder import ImageFolder
55
from .utils import download_and_extract_archive
@@ -10,23 +10,21 @@ class EuroSAT(ImageFolder):
1010
1111
Args:
1212
root (string): Root directory of dataset where ``root/eurosat`` exists.
13-
download (bool, optional): If True, downloads the dataset from the internet and
14-
puts it in root directory. If dataset is already downloaded, it is not
15-
downloaded again. Default is False.
1613
transform (callable, optional): A function/transform that takes in an PIL image
1714
and returns a transformed version. E.g, ``transforms.RandomCrop``
1815
target_transform (callable, optional): A function/transform that takes in the
1916
target and transforms it.
17+
download (bool, optional): If True, downloads the dataset from the internet and
18+
puts it in root directory. If dataset is already downloaded, it is not
19+
downloaded again. Default is False.
2020
"""
2121

22-
url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
23-
md5 = "c8fa014336c82ac7804f0398fcb19387"
24-
2522
def __init__(
2623
self,
2724
root: str,
25+
transform: Optional[Callable] = None,
26+
target_transform: Optional[Callable] = None,
2827
download: bool = False,
29-
**kwargs: Any,
3028
) -> None:
3129
self.root = os.path.expanduser(root)
3230
self._base_folder = os.path.join(self.root, "eurosat")
@@ -38,7 +36,7 @@ def __init__(
3836
if not self._check_exists():
3937
raise RuntimeError("Dataset not found. You can use download=True to download it")
4038

41-
super().__init__(self._data_folder, **kwargs)
39+
super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
4240
self.root = os.path.expanduser(root)
4341

4442
def __len__(self) -> int:
@@ -53,4 +51,8 @@ def download(self) -> None:
5351
return
5452

5553
os.makedirs(self._base_folder, exist_ok=True)
56-
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
54+
download_and_extract_archive(
55+
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
56+
download_root=self._base_folder,
57+
md5="c8fa014336c82ac7804f0398fcb19387",
58+
)

torchvision/datasets/fgvc_aircraft.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset):
2626
root (string): Root directory of the FGVC Aircraft dataset.
2727
split (string, optional): The dataset split, supports ``train``, ``val``,
2828
``trainval`` and ``test``.
29-
download (bool, optional): If True, downloads the dataset from the internet and
30-
puts it in root directory. If dataset is already downloaded, it is not
31-
downloaded again.
3229
annotation_level (str, optional): The annotation level, supports ``variant``,
3330
``family`` and ``manufacturer``.
3431
transform (callable, optional): A function/transform that takes in an PIL image
3532
and returns a transformed version. E.g, ``transforms.RandomCrop``
3633
target_transform (callable, optional): A function/transform that takes in the
3734
target and transforms it.
35+
download (bool, optional): If True, downloads the dataset from the internet and
36+
puts it in root directory. If dataset is already downloaded, it is not
37+
downloaded again.
3838
"""
3939

4040
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
@@ -43,10 +43,10 @@ def __init__(
4343
self,
4444
root: str,
4545
split: str = "trainval",
46-
download: bool = False,
4746
annotation_level: str = "variant",
4847
transform: Optional[Callable] = None,
4948
target_transform: Optional[Callable] = None,
49+
download: bool = False,
5050
) -> None:
5151
super().__init__(root, transform=transform, target_transform=target_transform)
5252
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))

torchvision/datasets/flowers102.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ class Flowers102(VisionDataset):
2424
Args:
2525
root (string): Root directory of the dataset.
2626
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
27-
download (bool, optional): If true, downloads the dataset from the internet and
28-
puts it in root directory. If dataset is already downloaded, it is not
29-
downloaded again.
3027
transform (callable, optional): A function/transform that takes in an PIL image and returns a
3128
transformed version. E.g, ``transforms.RandomCrop``.
3229
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30+
download (bool, optional): If true, downloads the dataset from the internet and
31+
puts it in root directory. If dataset is already downloaded, it is not
32+
downloaded again.
3333
"""
3434

3535
_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
@@ -44,9 +44,9 @@ def __init__(
4444
self,
4545
root: str,
4646
split: str = "train",
47-
download: bool = True,
4847
transform: Optional[Callable] = None,
4948
target_transform: Optional[Callable] = None,
49+
download: bool = False,
5050
) -> None:
5151
super().__init__(root, transform=transform, target_transform=target_transform)
5252
self._split = verify_str_arg(split, "split", ("train", "val", "test"))

torchvision/datasets/food101.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class Food101(VisionDataset):
2424
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
2525
version. E.g, ``transforms.RandomCrop``.
2626
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
27+
download (bool, optional): If True, downloads the dataset from the internet and
28+
puts it in root directory. If dataset is already downloaded, it is not
29+
downloaded again. Default is False.
2730
"""
2831

2932
_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
@@ -33,9 +36,9 @@ def __init__(
3336
self,
3437
root: str,
3538
split: str = "train",
36-
download: bool = True,
3739
transform: Optional[Callable] = None,
3840
target_transform: Optional[Callable] = None,
41+
download: bool = False,
3942
) -> None:
4043
super().__init__(root, transform=transform, target_transform=target_transform)
4144
self._split = verify_str_arg(split, "split", ("train", "test"))

0 commit comments

Comments
 (0)