Add SWAG Vision Transformer Weight #5714

Merged: 13 commits, Apr 1, 2022
7 changes: 7 additions & 0 deletions README.rst
@@ -185,3 +185,10 @@ Disclaimer on Datasets
 This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.

 If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
+
+Pre-trained Model License
+=========================
+
+The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
+
+More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.
3 changes: 2 additions & 1 deletion test/test_extended_models.py
@@ -115,7 +115,8 @@ def test_schema_meta_validation(model_fn):
                     incorrect_params.append(w)
         else:
             if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                incorrect_params.append(w)
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
         if not w.name.isupper():
             bad_names.append(w)

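In plain terms, the new branch trusts a non-default weight's declared num_params only if it either matches the DEFAULT entry or matches the real parameter count of a model instantiated with those weights. A minimal sketch of that second check outside the test harness, using the ViT-B/16 SWAG entry added in this PR (illustration only; it downloads the checkpoint and skips the surrounding test fixtures):

from torchvision.models import vit_b_16, ViT_B_16_Weights

# Illustrative check: compare the parameter count declared in the weights
# metadata against the parameters of the actual instantiated model.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
model = vit_b_16(weights=weights)

actual_params = sum(p.numel() for p in model.parameters())
assert weights.meta["num_params"] == actual_params, (weights.meta["num_params"], actual_params)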
59 changes: 56 additions & 3 deletions torchvision/models/vision_transformer.py
@@ -1,7 +1,7 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, NamedTuple, Optional
+from typing import Any, Callable, List, NamedTuple, Optional, Sequence

 import torch
 import torch.nn as nn
@@ -284,10 +284,21 @@ def _vision_transformer(
     progress: bool,
     **kwargs: Any,
 ) -> VisionTransformer:
-    image_size = kwargs.pop("image_size", 224)
-
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if isinstance(weights.meta["size"], int):
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
+        elif isinstance(weights.meta["size"], Sequence):
+            if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
+                raise ValueError(
+                    f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
+                )
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
+        else:
+            raise ValueError(
+                f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
+            )
+    image_size = kwargs.pop("image_size", 224)

     model = VisionTransformer(
         image_size=image_size,
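Read outside the diff, the new logic accepts either an int or a square two-element sequence for weights.meta["size"] and reduces it to a single side length. A standalone sketch of the same rule (resolve_image_size is a hypothetical helper written for illustration, not part of the PR):

from collections.abc import Sequence
from typing import Union

def resolve_image_size(size: Union[int, Sequence]) -> int:
    # Mirror of the validation above: an int is taken as-is, a (height, width) pair must be square.
    if isinstance(size, int):
        return size
    if isinstance(size, Sequence):
        if len(size) != 2 or size[0] != size[1]:
            raise ValueError(f"size {size} is not valid: only square (height, width) pairs are supported")
        return size[0]
    raise ValueError(f"size {size} should be an int or a Sequence[int]")

assert resolve_image_size(384) == 384          # ViT-B/16 SWAG
assert resolve_image_size((512, 512)) == 512   # ViT-L/16 SWAG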
@@ -313,6 +324,14 @@ def _vision_transformer(
     "interpolation": InterpolationMode.BILINEAR,
 }

+_COMMON_SWAG_META = {
+    **_COMMON_META,
+    "publication_year": 2022,
+    "recipe": "https://github.com/facebookresearch/SWAG",
+    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
+    "interpolation": InterpolationMode.BICUBIC,
+}
+

 class ViT_B_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
@@ -328,6 +347,23 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=384,
+            resize_size=384,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 86859496,
+            "size": (384, 384),
+            "min_size": (384, 384),
+            "acc@1": 85.304,
+            "acc@5": 97.650,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
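With these entries in place, loading a SWAG checkpoint together with its matching preprocessing looks roughly like the sketch below. This assumes the multi-weight API these enums belong to, as available on main around the time of this PR, and it downloads the checkpoint on first use:

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
model = vit_b_16(weights=weights).eval()

# transforms() builds the preprocessing declared above: bicubic resize to 384 and a 384 crop.
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 500, 600)).unsqueeze(0)  # dummy image tensor, illustration only

with torch.no_grad():
    logits = model(batch)  # shape [1, 1000]: ImageNet-1k logits
print(weights.meta["categories"][logits.argmax().item()])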


@@ -362,6 +398,23 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=512,
+            resize_size=512,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 305174504,
+            "size": (512, 512),
+            "min_size": (512, 512),
+            "acc@1": 88.064,
+            "acc@5": 98.512,
+        },
+    )
     DEFAULT = IMAGENET1K_V1

