Commit 781b0f9

Add SWAG Vision Transformer Weight (#5714)
* Add vit_b_16_swag * Better handling idiom for image_size, edit test_extended_model to handle case where number of param differ from default due to different image size input * Update the accuracy to the experiment result on torchvision model * Fix typo missing underscore * raise exception instead of torch._assert, add back publication year (accidentally deleted) * Add license information on meta and readme * Improve wording and fix typo for pretrained model license in readme * Add vit_l_16 weight * Update README.rst Co-authored-by: Vasilis Vryniotis <[email protected]> * Update the accuracy meta on vit_l_16_swag model to result from our experiment Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 3925946 commit 781b0f9

File tree: 3 files changed, +65 −4 lines

* README.rst
* test/test_extended_models.py
* torchvision/models/vision_transformer.py

README.rst

Lines changed: 7 additions & 0 deletions
```diff
@@ -185,3 +185,10 @@ Disclaimer on Datasets
 This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
 
 If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
+
+Pre-trained Model License
+=========================
+
+The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
+
+More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.
```

test/test_extended_models.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -115,7 +115,8 @@ def test_schema_meta_validation(model_fn):
                 incorrect_params.append(w)
         else:
             if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                incorrect_params.append(w)
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
         if not w.name.isupper():
             bad_names.append(w)
 
```
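The new inner check accepts a non-default weight whose `num_params` legitimately differs from `DEFAULT` (here because the SWAG checkpoints use larger input resolutions, which enlarges the position embedding), as long as the advertised count matches the instantiated model. A minimal sketch of the same comparison outside the test suite, assuming a torchvision build that includes this commit; note that constructing the model with `weights=w` downloads the checkpoint:

```python
import torchvision.models as models

# Sketch: compare the advertised parameter count in a weight entry's meta
# against the actual parameter count of the model built from that entry.
w = models.ViT_B_16_Weights.IMAGENET1K_SWAG_V1
model = models.vit_b_16(weights=w)  # builds at image_size=384 and loads weights

actual = sum(p.numel() for p in model.parameters())
assert actual == w.meta["num_params"], (actual, w.meta["num_params"])
```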

torchvision/models/vision_transformer.py

Lines changed: 56 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, NamedTuple, Optional
+from typing import Any, Callable, List, NamedTuple, Optional, Sequence
 
 import torch
 import torch.nn as nn
```
```diff
@@ -284,10 +284,21 @@ def _vision_transformer(
     progress: bool,
     **kwargs: Any,
 ) -> VisionTransformer:
-    image_size = kwargs.pop("image_size", 224)
-
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if isinstance(weights.meta["size"], int):
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
+        elif isinstance(weights.meta["size"], Sequence):
+            if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
+                raise ValueError(
+                    f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
+                )
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
+        else:
+            raise ValueError(
+                f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
+            )
+    image_size = kwargs.pop("image_size", 224)
 
     model = VisionTransformer(
         image_size=image_size,
```
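The hunk above normalizes `weights.meta["size"]` to a single int before `image_size` is read: a plain int passes through, a square two-element sequence collapses to its side length, and anything else raises. A standalone sketch of that normalization, with a hypothetical helper name that is not part of the commit:

```python
from typing import Sequence, Union

def _normalize_image_size(size: Union[int, Sequence[int]]) -> int:
    # Hypothetical helper mirroring the validation in _vision_transformer:
    # accept an int, or a 2-element sequence with width == height.
    if isinstance(size, int):
        return size
    if isinstance(size, Sequence):
        if len(size) != 2 or size[0] != size[1]:
            raise ValueError(f"size: {size} is not valid! Only square 2D sizes are supported")
        return size[0]
    raise ValueError(f"size: {size} should be either an int or a Sequence[int]")

assert _normalize_image_size(384) == 384
assert _normalize_image_size((512, 512)) == 512
```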
```diff
@@ -313,6 +324,14 @@ def _vision_transformer(
     "interpolation": InterpolationMode.BILINEAR,
 }
 
+_COMMON_SWAG_META = {
+    **_COMMON_META,
+    "publication_year": 2022,
+    "recipe": "https://github.com/facebookresearch/SWAG",
+    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
+    "interpolation": InterpolationMode.BICUBIC,
+}
+
 
 class ViT_B_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
```
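Because later keys in a dict literal override earlier ones, the `**_COMMON_META` spread above inherits the shared fields while letting the SWAG variant replace `interpolation` with BICUBIC. A tiny illustration with hypothetical stand-in values:

```python
# Hypothetical stand-ins illustrating dict-spread override behavior:
common = {"task": "image_classification", "interpolation": "bilinear"}
swag = {
    **common,                    # inherit the shared fields
    "interpolation": "bicubic",  # a later key overrides the inherited value
}
assert swag["task"] == "image_classification"
assert swag["interpolation"] == "bicubic"
```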
```diff
@@ -328,6 +347,23 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=384,
+            resize_size=384,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 86859496,
+            "size": (384, 384),
+            "min_size": (384, 384),
+            "acc@1": 85.304,
+            "acc@5": 97.650,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
```
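A short usage sketch for the new ViT-B/16 SWAG entry, assuming a torchvision build containing this commit; `weights.transforms()` instantiates the 384×384 BICUBIC preprocessing bundled with the entry:

```python
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Sketch: run inference with the SWAG weights and their bundled transforms.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
model = vit_b_16(weights=weights).eval()   # downloads the checkpoint
preprocess = weights.transforms()          # resize + center-crop to 384, bicubic

img = torch.rand(3, 500, 500)              # stand-in for a real image tensor
batch = preprocess(img).unsqueeze(0)
with torch.no_grad():
    logits = model(batch)                  # shape: [1, 1000]
```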

```diff
@@ -362,6 +398,23 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=512,
+            resize_size=512,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 305174504,
+            "size": (512, 512),
+            "min_size": (512, 512),
+            "acc@1": 88.064,
+            "acc@5": 98.512,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
```
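The ViT-L/16 entry follows the same pattern at 512×512. A small sketch for inspecting the entries of a weight enum and the meta fields touched by this commit (enum iteration is plain `Enum` behavior; `license` only exists on the SWAG entries, hence `.get`):

```python
from torchvision.models import ViT_L_16_Weights

# Sketch: list each weight entry with its size, accuracy, and license meta.
for w in ViT_L_16_Weights:
    print(
        f"{w.name}: size={w.meta.get('size')}, acc@1={w.meta.get('acc@1')}, "
        f"license={w.meta.get('license', 'n/a')}"
    )
```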
