
Commit d14c03b

YosuaMichael authored and datumbox committed
[fbsync] Adding the huge vision transformer from SWAG (#5721)
Summary:
* Add vit_b_16_swag
* Better idiom for handling image_size; edit test_extended_model to handle the case where the number of parameters differs from the default due to a different image-size input
* Update the accuracy to the experiment result on the torchvision model
* Fix typo (missing underscore)
* Raise an exception instead of torch._assert; add back the publication year (accidentally deleted)
* Add license information to the meta and README
* Improve wording and fix a typo for the pretrained-model license in the README
* Add vit_l_16 weights
* Update README.rst
* Update the accuracy meta on the vit_l_16_swag model to the result from our experiment
* Add the vit_h_14_swag model
* Add accuracy from experiments
* Add the vit_h_14 model to hubconf.py
* Add docs and the expected pkl file for the test
* Remove legacy compatibility for the ViT_H_14 model
* Test vit_h_14 with a smaller image_size to speed up the test

(Note: this ignores all push blocking failures!)

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095649

fbshipit-source-id: 639dab0577088e18e1bcfa06fd1f01be20c3fd44

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent da3f89f commit d14c03b

File tree

5 files changed: +54, -0 lines changed


docs/source/models.rst

Lines changed: 3 additions & 0 deletions
@@ -92,6 +92,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()
     convnext_tiny = models.convnext_tiny()
     convnext_small = models.convnext_small()
     convnext_base = models.convnext_base()
@@ -213,6 +214,7 @@ vit_b_16 81.072 95.318
 vit_b_32 75.912 92.466
 vit_l_16 79.662 94.638
 vit_l_32 76.972 93.070
+vit_h_14 88.552 98.694
 convnext_tiny 82.520 96.146
 convnext_small 83.616 96.650
 convnext_base 84.062 96.870
@@ -434,6 +436,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14

 ConvNeXt
 --------
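
For orientation, a minimal sketch of the constructor usage documented in models.rst. Note that the "num_params" value in the new weights meta (633,470,440) is measured at the 518x518 SWAG resolution, so the default constructor's count comes out slightly lower; the summary above mentions exactly this image-size-dependent difference.

    import torchvision.models as models

    # Build the new huge ViT with random weights, as documented in models.rst.
    model = models.vit_h_14()

    # The positional embedding depends on image_size, so this count is slightly
    # below the 633,470,440 reported for the 518x518 SWAG checkpoint.
    print(sum(p.numel() for p in model.parameters()))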

hubconf.py

Lines changed: 1 addition & 0 deletions
@@ -67,4 +67,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )
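
Because hubconf.py now re-exports vit_h_14, the builder is also reachable through torch.hub. A hedged sketch: the "pytorch/vision:main" reference is an assumption, and any checkout containing this commit works.

    import torch

    # Load the builder exposed via hubconf.py; the ":main" ref is an assumption.
    model = torch.hub.load("pytorch/vision:main", "vit_h_14", weights=None)
    model.eval()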
Binary file not shown (939 Bytes).

test/test_models.py

Lines changed: 4 additions & 0 deletions
@@ -280,6 +280,10 @@ def _check_input_backprop(model, inputs):
         "rpn_pre_nms_top_n_test": 1000,
         "rpn_post_nms_top_n_test": 1000,
     },
+    "vit_h_14": {
+        "image_size": 56,
+        "input_shape": (1, 3, 56, 56),
+    },
 }
 # speeding up slow models:
 slow_models = [
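
The image_size override exists purely to keep the unit test fast: with 14x14 patches the token count grows quadratically with the input side, so a 56x56 input gives the attention layers a far shorter sequence than the 518x518 resolution the SWAG weights use. A quick back-of-the-envelope check (plain arithmetic, no torchvision required):

    PATCH_SIZE = 14  # patch_size used by the vit_h_14 builder below

    def num_tokens(image_size: int) -> int:
        # one token per non-overlapping patch, plus the class token
        return (image_size // PATCH_SIZE) ** 2 + 1

    print(num_tokens(56))   # 17 tokens for the 56x56 test input
    print(num_tokens(518))  # 1370 tokens at the 518x518 SWAG resolution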

torchvision/models/vision_transformer.py

Lines changed: 46 additions & 0 deletions
@@ -20,10 +20,12 @@
     "ViT_B_32_Weights",
     "ViT_L_16_Weights",
     "ViT_L_32_Weights",
+    "ViT_H_14_Weights",
     "vit_b_16",
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]


@@ -435,6 +437,27 @@ class ViT_L_32_Weights(WeightsEnum):
     DEFAULT = IMAGENET1K_V1


+class ViT_H_14_Weights(WeightsEnum):
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=518,
+            resize_size=518,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 633470440,
+            "size": (518, 518),
+            "min_size": (518, 518),
+            "acc@1": 88.552,
+            "acc@5": 98.694,
+        },
+    )
+    DEFAULT = IMAGENET1K_SWAG_V1
+
+
 @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
 def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
@@ -531,6 +554,29 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     )


+def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    """
+    Constructs a vit_h_14 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    Args:
+        weights (ViT_H_14_Weights, optional): The pretrained weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,
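
Putting the new pieces together, a minimal inference sketch against the SWAG checkpoint. It assumes the multi-weight API's weights.transforms() helper applies the ImageClassification preset configured above, and it uses a random tensor as a stand-in for a real image; the first call downloads a multi-gigabyte weight file.

    import torch
    from torchvision.models import vit_h_14, ViT_H_14_Weights

    # SWAG checkpoint added in this commit; the model then expects 518x518 inputs.
    weights = ViT_H_14_Weights.IMAGENET1K_SWAG_V1
    model = vit_h_14(weights=weights).eval()

    # Resize/center-crop to 518 with bicubic interpolation, then normalize,
    # per the transforms configured in ViT_H_14_Weights above.
    preprocess = weights.transforms()

    # Stand-in for a real image: a random uint8 CHW tensor.
    img = torch.randint(0, 256, (3, 600, 600), dtype=torch.uint8)

    with torch.no_grad():
        batch = preprocess(img).unsqueeze(0)  # (1, 3, 518, 518)
        logits = model(batch)                 # (1, 1000) ImageNet-1k logits
        top5 = logits.softmax(dim=-1).topk(5)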
