Skip to content

Commit 76dbebf

Browse files
authored
Merge branch 'main' into improved-affine
2 parents f75bfb1 + abdae5a commit 76dbebf

File tree

4 files changed

+36
-60
lines changed

4 files changed

+36
-60
lines changed

test/builtin_dataset_mocks.py

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tempfile
1313
import unittest.mock
1414
import xml.etree.ElementTree as ET
15-
from collections import defaultdict, Counter, UserDict
15+
from collections import defaultdict, Counter
1616

1717
import numpy as np
1818
import PIL.Image
@@ -34,35 +34,17 @@
3434
__all__ = ["DATASET_MOCKS", "parametrize_dataset_mocks"]
3535

3636

37-
class ResourceMock(datasets.utils.OnlineResource):
38-
def __init__(self, *, dataset_name, dataset_config, **kwargs):
39-
super().__init__(**kwargs)
40-
self.dataset_name = dataset_name
41-
self.dataset_config = dataset_config
42-
43-
def _download(self, _):
44-
raise pytest.UsageError(
45-
f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
46-
f"but this file does not exist."
47-
)
48-
49-
5037
class DatasetMock:
51-
def __init__(self, name, mock_data_fn, *, configs=None):
38+
def __init__(self, name, mock_data_fn):
5239
self.dataset = find(name)
40+
self.info = self.dataset.info
41+
self.name = self.info.name
42+
5343
self.root = TEST_HOME / self.dataset.name
5444
self.mock_data_fn = mock_data_fn
55-
self.configs = configs or self.info._configs
45+
self.configs = self.info._configs
5646
self._cache = {}
5747

58-
@property
59-
def info(self):
60-
return self.dataset.info
61-
62-
@property
63-
def name(self):
64-
return self.info.name
65-
6648
def _parse_mock_data(self, config, mock_infos):
6749
if mock_infos is None:
6850
raise pytest.UsageError(
@@ -79,7 +61,7 @@ def _parse_mock_data(self, config, mock_infos):
7961
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
8062
)
8163

82-
for config_, mock_info in list(mock_infos.items()):
64+
for config_, mock_info in mock_infos.items():
8365
if config_ in self._cache:
8466
raise pytest.UsageError(
8567
f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
@@ -103,7 +85,7 @@ def _parse_mock_data(self, config, mock_infos):
10385
return mock_infos
10486

10587
def _prepare_resources(self, config):
106-
with contextlib.suppress(KeyError):
88+
if config in self._cache:
10789
return self._cache[config]
10890

10991
self.root.mkdir(exist_ok=True)
@@ -146,8 +128,6 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
146128
for mock in dataset_mocks:
147129
if isinstance(mock, DatasetMock):
148130
mocks[mock.name] = mock
149-
elif isinstance(mock, collections.abc.Sequence):
150-
mocks.update({mock_.name: mock_ for mock_ in mock})
151131
elif isinstance(mock, collections.abc.Mapping):
152132
mocks.update(mock)
153133
else:
@@ -173,14 +153,13 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
173153
)
174154

175155

176-
class DatasetMocks(UserDict):
177-
def set_from_named_callable(self, fn):
178-
name = fn.__name__.replace("_", "-")
179-
self.data[name] = DatasetMock(name, fn)
180-
return fn
156+
DATASET_MOCKS = {}
181157

182158

183-
DATASET_MOCKS = DatasetMocks()
159+
def register_mock(fn):
160+
name = fn.__name__.replace("_", "-")
161+
DATASET_MOCKS[name] = DatasetMock(name, fn)
162+
return fn
184163

185164

186165
class MNISTMockData:
@@ -258,7 +237,7 @@ def generate(
258237
return num_samples
259238

260239

261-
@DATASET_MOCKS.set_from_named_callable
240+
@register_mock
262241
def mnist(info, root, config):
263242
train = config.split == "train"
264243
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
@@ -274,7 +253,7 @@ def mnist(info, root, config):
274253
DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
275254

276255

277-
@DATASET_MOCKS.set_from_named_callable
256+
@register_mock
278257
def emnist(info, root, _):
279258
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
280259
# labels in the data files. Thus, num_categories != len(categories) there.
@@ -303,7 +282,7 @@ def emnist(info, root, _):
303282
return mock_infos
304283

305284

306-
@DATASET_MOCKS.set_from_named_callable
285+
@register_mock
307286
def qmnist(info, root, config):
308287
num_categories = len(info.categories)
309288
if config.split == "train":
@@ -382,7 +361,7 @@ def generate(
382361
make_tar(root, name, folder, compression="gz")
383362

384363

385-
@DATASET_MOCKS.set_from_named_callable
364+
@register_mock
386365
def cifar10(info, root, config):
387366
train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
388367
test_files = ["test_batch"]
@@ -400,7 +379,7 @@ def cifar10(info, root, config):
400379
return len(train_files if config.split == "train" else test_files)
401380

402381

403-
@DATASET_MOCKS.set_from_named_callable
382+
@register_mock
404383
def cifar100(info, root, config):
405384
train_files = ["train"]
406385
test_files = ["test"]
@@ -418,7 +397,7 @@ def cifar100(info, root, config):
418397
return len(train_files if config.split == "train" else test_files)
419398

420399

421-
@DATASET_MOCKS.set_from_named_callable
400+
@register_mock
422401
def caltech101(info, root, config):
423402
def create_ann_file(root, name):
424403
import scipy.io
@@ -468,7 +447,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
468447
return num_images_per_category * len(info.categories)
469448

470449

471-
@DATASET_MOCKS.set_from_named_callable
450+
@register_mock
472451
def caltech256(info, root, config):
473452
dir = root / "256_ObjectCategories"
474453
num_images_per_category = 2
@@ -488,7 +467,7 @@ def caltech256(info, root, config):
488467
return num_images_per_category * len(info.categories)
489468

490469

491-
@DATASET_MOCKS.set_from_named_callable
470+
@register_mock
492471
def imagenet(info, root, config):
493472
wnids = tuple(info.extra.wnid_to_category.keys())
494473
if config.split == "train":
@@ -643,7 +622,7 @@ def generate(
643622
return num_samples
644623

645624

646-
@DATASET_MOCKS.set_from_named_callable
625+
@register_mock
647626
def coco(info, root, config):
648627
return dict(
649628
zip(
@@ -722,13 +701,13 @@ def generate(cls, root):
722701
return num_samples_map
723702

724703

725-
@DATASET_MOCKS.set_from_named_callable
704+
@register_mock
726705
def sbd(info, root, _):
727706
num_samples_map = SBDMockData.generate(root)
728707
return {config: num_samples_map[config.split] for config in info._configs}
729708

730709

731-
@DATASET_MOCKS.set_from_named_callable
710+
@register_mock
732711
def semeion(info, root, config):
733712
num_samples = 3
734713

@@ -839,7 +818,7 @@ def generate(cls, root, *, year, trainval):
839818
return num_samples_map
840819

841820

842-
@DATASET_MOCKS.set_from_named_callable
821+
@register_mock
843822
def voc(info, root, config):
844823
trainval = config.split != "test"
845824
num_samples_map = VOCMockData.generate(root, year=config.year, trainval=trainval)
@@ -938,13 +917,13 @@ def generate(cls, root):
938917
return num_samples_map
939918

940919

941-
@DATASET_MOCKS.set_from_named_callable
920+
@register_mock
942921
def celeba(info, root, _):
943922
num_samples_map = CelebAMockData.generate(root)
944923
return {config: num_samples_map[config.split] for config in info._configs}
945924

946925

947-
@DATASET_MOCKS.set_from_named_callable
926+
@register_mock
948927
def dtd(info, root, _):
949928
data_folder = root / "dtd"
950929

@@ -992,7 +971,7 @@ def dtd(info, root, _):
992971
return num_samples_map
993972

994973

995-
@DATASET_MOCKS.set_from_named_callable
974+
@register_mock
996975
def fer2013(info, root, config):
997976
num_samples = 5 if config.split == "train" else 3
998977

@@ -1017,7 +996,7 @@ def fer2013(info, root, config):
1017996
return num_samples
1018997

1019998

1020-
@DATASET_MOCKS.set_from_named_callable
999+
@register_mock
10211000
def gtsrb(info, root, config):
10221001
num_examples_per_class = 5 if config.split == "train" else 3
10231002
classes = ("00000", "00042", "00012")
@@ -1087,7 +1066,7 @@ def _make_ann_file(path, num_examples, class_idx):
10871066
return num_examples
10881067

10891068

1090-
@DATASET_MOCKS.set_from_named_callable
1069+
@register_mock
10911070
def clevr(info, root, config):
10921071
data_folder = root / "CLEVR_v1.0"
10931072

@@ -1193,7 +1172,7 @@ def generate(self, root):
11931172
return num_samples_map
11941173

11951174

1196-
@DATASET_MOCKS.set_from_named_callable
1175+
@register_mock
11971176
def oxford_iiit_pet(info, root, config):
11981177
num_samples_map = OxfordIIITPetMockData.generate(root)
11991178
return {config_: num_samples_map[config_.split] for config_ in info._configs}
@@ -1360,13 +1339,13 @@ def generate(cls, root):
13601339
return num_samples_map
13611340

13621341

1363-
@DATASET_MOCKS.set_from_named_callable
1342+
@register_mock
13641343
def cub200(info, root, config):
13651344
num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
13661345
return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
13671346

13681347

1369-
@DATASET_MOCKS.set_from_named_callable
1348+
@register_mock
13701349
def svhn(info, root, config):
13711350
import scipy.io as sio
13721351

test/test_prototype_builtin_datasets.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_traversable(self, dataset_mock, config):
110110
)
111111
},
112112
)
113-
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
113+
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
114114
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
115115
def scan(graph):
116116
for node, sub_graph in graph.items():
@@ -120,10 +120,7 @@ def scan(graph):
120120
with dataset_mock.prepare(config):
121121
dataset = datasets.load(dataset_mock.name, **config)
122122

123-
for dp in scan(traverse(dataset)):
124-
if type(dp) is annotation_dp_type:
125-
break
126-
else:
123+
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
127124
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
128125

129126

torchvision/ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .deform_conv import deform_conv2d, DeformConv2d
1414
from .feature_pyramid_network import FeaturePyramidNetwork
1515
from .focal_loss import sigmoid_focal_loss
16-
from .generalized_box_iou_loss import generalized_box_iou_loss
16+
from .giou_loss import generalized_box_iou_loss
1717
from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation
1818
from .poolers import MultiScaleRoIAlign
1919
from .ps_roi_align import ps_roi_align, PSRoIAlign

0 commit comments

Comments
 (0)