
Commit 6840a7d

Merge branch 'main' into main
2 parents: 6bae6c1 + fe65d37 · commit 6840a7d

14 files changed: +486 −86 lines

docs/source/models.rst

Lines changed: 4 additions & 1 deletion
@@ -41,6 +41,7 @@ architectures for image classification:
 - `EfficientNet`_
 - `RegNet`_
 - `VisionTransformer`_
+- `ConvNeXt`_

 You can construct a model with random weights by calling its constructor:

@@ -88,7 +89,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
-    vit_h_14 = models.vit_h_14()
+    vit_h_14 = models.vit_h_14()

 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
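The constructor and ``pretrained=True`` conventions described in this hunk apply to the newly listed ConvNeXt model as well. A minimal sketch, assuming the prototype `convnext_tiny` builder is exposed under `torchvision.models` as the accuracy table below suggests:

```python
import torch
from torchvision import models

# Random weights: call the constructor, exactly as for the other architectures.
convnext_tiny = models.convnext_tiny()

# Pre-trained weights via the `pretrained=True` convention documented above.
# convnext_tiny is marked "(prototype)", so the flag may not exist in every release.
model = models.convnext_tiny(pretrained=True)
model.eval()

# Dummy forward pass on an ImageNet-sized input of shape (N, 3, H, W).
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])
```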
@@ -248,6 +249,7 @@ vit_b_16 81.072 95.318
 vit_b_32                         75.912        92.466
 vit_l_16                         79.662        94.638
 vit_l_32                         76.972        93.070
+convnext_tiny (prototype)        82.520        96.146
 ================================ ============= =============

@@ -266,6 +268,7 @@ vit_l_32 76.972 93.070
 .. _EfficientNet: https://arxiv.org/abs/1905.11946
 .. _RegNet: https://arxiv.org/abs/2003.13678
 .. _VisionTransformer: https://arxiv.org/abs/2010.11929
+.. _ConvNeXt: https://arxiv.org/abs/2201.03545

 .. currentmodule:: torchvision.models

references/classification/README.md

Lines changed: 14 additions & 0 deletions
@@ -197,6 +197,20 @@ Note that the above command corresponds to training on a single node with 8 GPUs
 For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs),
 and `--batch_size 64`.

+
+### ConvNeXt
+```
+torchrun --nproc_per_node=8 train.py\
+--model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
+--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
+--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
+--train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
+and `--batch_size 64`.
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for PyTorch can be enabled with [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).

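As a sanity check on the two launch setups mentioned above, the effective global batch size works out the same in both cases; a small sketch of that arithmetic, using only the flag values quoted in the diff:

```python
# Single-node command above: 8 processes x `--batch-size 128` per process.
single_node_global_batch = 8 * 128      # 1024

# Run used for the released weights: 2 nodes x 8 GPUs with `--batch_size 64`.
multi_node_global_batch = 2 * 8 * 64    # 1024

assert single_node_global_batch == multi_node_global_batch == 1024
```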
test/builtin_dataset_mocks.py

Lines changed: 80 additions & 64 deletions
@@ -10,6 +10,7 @@
 import pickle
 import random
 import tempfile
+import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter, UserDict

@@ -21,7 +22,8 @@
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
 from torchvision.prototype import datasets
-from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
+from torchvision.prototype.datasets._api import find
+from torchvision.prototype.utils._internal import sequence_to_str

 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
@@ -49,7 +51,7 @@ class DatasetMock:
     def __init__(self, name, mock_data_fn, *, configs=None):
         self.dataset = find(name)
         self.root = TEST_HOME / self.dataset.name
-        self.mock_data_fn = self._parse_mock_data(mock_data_fn)
+        self.mock_data_fn = mock_data_fn
         self.configs = configs or self.info._configs
         self._cache = {}

@@ -61,77 +63,71 @@ def info(self):
     def name(self):
         return self.info.name

-    def _parse_mock_data(self, mock_data_fn):
-        def wrapper(info, root, config):
-            mock_infos = mock_data_fn(info, root, config)
+    def _parse_mock_data(self, config, mock_infos):
+        if mock_infos is None:
+            raise pytest.UsageError(
+                f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
+                f"integer indicating the number of samples for the current `config`."
+            )
+
+        key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
+        if datasets.utils.DatasetConfig not in key_types:
+            mock_infos = {config: mock_infos}
+        elif len(key_types) > 1:
+            raise pytest.UsageError(
+                f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
+                f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
+            )

-            if mock_infos is None:
+        for config_, mock_info in list(mock_infos.items()):
+            if config_ in self._cache:
                 raise pytest.UsageError(
-                    f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
-                    f"integer indicating the number of samples for the current `config`."
+                    f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
+                    f"already exists in the cache."
                 )
-
-            key_types = set(type(key) for key in mock_infos) if isinstance(mock_infos, dict) else {}
-            if datasets.utils.DatasetConfig not in key_types:
-                mock_infos = {config: mock_infos}
-            elif len(key_types) > 1:
+            if isinstance(mock_info, int):
+                mock_infos[config_] = dict(num_samples=mock_info)
+            elif not isinstance(mock_info, dict):
                 raise pytest.UsageError(
-                    f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
-                    f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
+                    f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
+                    f"{config_}. The returned object should be a dictionary containing at least the number of "
+                    f"samples for the key `'num_samples'`. If no additional information is required for specific "
+                    f"tests, the number of samples can also be returned as an integer."
+                )
+            elif "num_samples" not in mock_info:
+                raise pytest.UsageError(
+                    f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
+                    f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
                 )

-            for config_, mock_info in list(mock_infos.items()):
-                if config_ in self._cache:
-                    raise pytest.UsageError(
-                        f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
-                        f"already exists in the cache."
-                    )
-                if isinstance(mock_info, int):
-                    mock_infos[config_] = dict(num_samples=mock_info)
-                elif not isinstance(mock_info, dict):
-                    raise pytest.UsageError(
-                        f"The mock data function for dataset '{self.name}' returned a {type(mock_infos)} for `config` "
-                        f"{config_}. The returned object should be a dictionary containing at least the number of "
-                        f"samples for the key `'num_samples'`. If no additional information is required for specific "
-                        f"tests, the number of samples can also be returned as an integer."
-                    )
-                elif "num_samples" not in mock_info:
-                    raise pytest.UsageError(
-                        f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
-                        f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
-                    )
-
-            return mock_infos
-
-        return wrapper
+        return mock_infos

-    def _load_mock(self, config):
+    def _prepare_resources(self, config):
         with contextlib.suppress(KeyError):
             return self._cache[config]

         self.root.mkdir(exist_ok=True)
-        for config_, mock_info in self.mock_data_fn(self.info, self.root, config).items():
-            mock_resources = [
-                ResourceMock(dataset_name=self.name, dataset_config=config_, file_name=resource.file_name)
-                for resource in self.dataset.resources(config_)
-            ]
-            self._cache[config_] = (mock_resources, mock_info)
+        mock_infos = self._parse_mock_data(config, self.mock_data_fn(self.info, self.root, config))
+
+        available_file_names = {path.name for path in self.root.glob("*")}
+        for config_, mock_info in mock_infos.items():
+            required_file_names = {resource.file_name for resource in self.dataset.resources(config_)}
+            missing_file_names = required_file_names - available_file_names
+            if missing_file_names:
+                raise pytest.UsageError(
+                    f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
+                    f"for {config_}, but they were not created by the mock data function."
+                )
+
+            self._cache[config_] = mock_info

         return self._cache[config]

-    def load(self, config, *, decoder=DEFAULT_DECODER):
-        try:
-            self.info.check_dependencies()
-        except ModuleNotFoundError as error:
-            pytest.skip(str(error))
-
-        mock_resources, mock_info = self._load_mock(config)
-        datapipe = self.dataset._make_datapipe(
-            [resource.load(self.root) for resource in mock_resources],
-            config=config,
-            decoder=DEFAULT_DECODER_MAP.get(self.info.type) if decoder is DEFAULT_DECODER else decoder,
-        )
-        return datapipe, mock_info
+    @contextlib.contextmanager
+    def prepare(self, config):
+        mock_info = self._prepare_resources(config)
+        with unittest.mock.patch("torchvision.prototype.datasets._api.home", return_value=str(TEST_HOME)):
+            yield mock_info


 def config_id(name, config):
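The hunk above replaces `DatasetMock.load()`, which built the datapipe itself, with a `prepare()` context manager that writes the mock files, checks them against the files declared by `dataset.resources(config)`, and patches the prototype datasets' home directory to point at `TEST_HOME`. A hedged sketch of how a test might use the new interface; the `dataset_mock` fixture and the `datasets.load(...)` call are assumptions for illustration and are not part of this diff:

```python
from torchvision.prototype import datasets


def test_smoke(dataset_mock, config):
    # `prepare` generates the mock data for `config`, validates that every file
    # required by the dataset's resources exists, and patches the datasets home
    # directory so that loading resolves to the mock files.
    with dataset_mock.prepare(config) as mock_info:
        dataset = datasets.load(dataset_mock.name, **config)

        # `mock_info` is the dictionary produced by the mock data function and
        # contains at least the number of generated samples.
        assert len(list(dataset)) == mock_info["num_samples"]
```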
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
 def fer2013(info, root, config):
     num_samples = 5 if config.split == "train" else 3

-    path = root / f"{config.split}.txt"
+    path = root / f"{config.split}.csv"
     with open(path, "w", newline="") as file:
         field_names = ["emotion"] if config.split == "train" else []
         field_names.append("pixels")
@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
            file,
        )

-    make_zip(root, f"{data_folder.name}.zip")
+    make_zip(root, f"{data_folder.name}.zip", data_folder)

    return {config_: num_samples_map[config_.split] for config_ in info._configs}

@@ -1121,8 +1117,8 @@ def generate(self, root):
        for path in segmentation_files:
            path.with_name(f".{path.name}").touch()

-        make_tar(root, "images.tar")
-        make_tar(root, anns_folder.with_suffix(".tar").name)
+        make_tar(root, "images.tar.gz", compression="gz")
+        make_tar(root, anns_folder.with_suffix(".tar.gz").name, compression="gz")

        return num_samples_map

@@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
                size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
            )

-        make_tar(root, segmentations_folder.with_suffix(".tgz").name)
+        make_tar(root, segmentations_folder.with_suffix(".tgz").name, compression="gz")

    @classmethod
    def generate(cls, root):
@@ -1298,3 +1294,23 @@ def generate(cls, root):
 def cub200(info, root, config):
     num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
     return {config_: num_samples_map[config_.split] for config_ in info._configs if config_.year == config.year}
+
+
+@DATASET_MOCKS.set_from_named_callable
+def svhn(info, root, config):
+    import scipy.io as sio
+
+    num_samples = {
+        "train": 2,
+        "test": 3,
+        "extra": 4,
+    }[config.split]
+
+    sio.savemat(
+        root / f"{config.split}_32x32.mat",
+        {
+            "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
+            "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
+        },
+    )
+    return num_samples
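The new `svhn` mock writes one minimal `.mat` file per split with `scipy.io.savemat`. A small self-contained sketch of writing and reading such a file back, independent of the dataset implementation, to show the layout the mock produces:

```python
import tempfile
from pathlib import Path

import numpy as np
import scipy.io as sio

num_samples = 2
with tempfile.TemporaryDirectory() as root:
    path = Path(root) / "train_32x32.mat"
    sio.savemat(
        str(path),
        {
            "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8),
            "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8),
        },
    )

    data = sio.loadmat(str(path))
    # Images are stored channel-last with the sample axis last, matching the real SVHN files.
    assert data["X"].shape == (32, 32, 3, num_samples)
    assert data["y"].size == num_samples
```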

test/datasets_utils.py

Lines changed: 7 additions & 3 deletions
@@ -868,9 +868,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
 def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
     archive = pathlib.Path(root) / name
     if not files_or_dirs:
-        dir = archive.with_suffix("")
-        if dir.exists() and dir.is_dir():
-            files_or_dirs = (dir,)
+        # We need to invoke `Path.with_suffix("")` repeatedly, since each call only applies to the last suffix if
+        # multiple suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
+        file_or_dir = archive
+        for _ in range(len(archive.suffixes)):
+            file_or_dir = file_or_dir.with_suffix("")
+        if file_or_dir.exists():
+            files_or_dirs = (file_or_dir,)
     else:
         raise ValueError("No file or dir provided.")

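The loop introduced above exists because `pathlib.Path.with_suffix("")` strips only the last suffix. A quick illustration of that behaviour and of the repeated-stripping idiom used in the fix:

```python
import pathlib

archive = pathlib.Path("images.tar.gz")

# A single call removes only ".gz" ...
assert archive.with_suffix("") == pathlib.Path("images.tar")

# ... so the helper strips once per suffix to recover the bare name.
file_or_dir = archive
for _ in range(len(archive.suffixes)):  # archive.suffixes == [".tar", ".gz"]
    file_or_dir = file_or_dir.with_suffix("")
assert file_or_dir == pathlib.Path("images")
```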
Binary file not shown.

test/test_datasets_utils.py

Lines changed: 27 additions & 1 deletion
@@ -1,15 +1,17 @@
 import contextlib
 import gzip
 import os
+import pathlib
+import re
 import tarfile
 import zipfile

 import pytest
 import torchvision.datasets.utils as utils
 from torch._utils_internal import get_file_path_2
+from torchvision.datasets.folder import make_dataset
 from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS

-
 TEST_FILE = get_file_path_2(
     os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
 )
@@ -214,5 +216,29 @@ def test_verify_str_arg(self):
         pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")


+@pytest.mark.parametrize(
+    ("kwargs", "expected_error_msg"),
+    [
+        (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
+        (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
+        (dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
+    ],
+)
+def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
+    tmpdir = pathlib.Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "a" / "a.png").touch()
+
+    (tmpdir / "b").mkdir()
+    (tmpdir / "b" / "b.jpeg").touch()
+
+    (tmpdir / "c").mkdir()
+    (tmpdir / "c" / "c.unknown").touch()
+
+    with pytest.raises(FileNotFoundError, match=expected_error_msg):
+        make_dataset(str(tmpdir), **kwargs)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
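The new test covers the error path of `make_dataset` when a class folder contains no file that passes the filter. For contrast, a sketch of the non-error path using the same call style as the test; the directory layout and class names are hypothetical:

```python
import pathlib
import tempfile

from torchvision.datasets.folder import make_dataset

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    (root / "cat").mkdir()
    (root / "cat" / "0.png").touch()
    (root / "dog").mkdir()
    (root / "dog" / "0.png").touch()

    # Every class folder has at least one matching file, so no error is raised.
    samples = make_dataset(str(root), extensions=".png")

    # `samples` is a list of (path, class_index) pairs; classes are indexed in sorted order.
    assert [class_index for _, class_index in samples] == [0, 1]
```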
