From 5dcfeda8a32940d7b642207d61e121b2d12282ee Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Wed, 19 Jan 2022 02:54:33 +0000
Subject: [PATCH 1/4] adding vit_h_14

---
 .../ModelTester.test_vit_h_14_expect.pkl |  Bin 0 -> 939 bytes
 torchvision/models/vision_transformer.py |   21 ++++++++++++++++++
 2 files changed, 21 insertions(+)
 create mode 100644 test/expect/ModelTester.test_vit_h_14_expect.pkl

diff --git a/test/expect/ModelTester.test_vit_h_14_expect.pkl b/test/expect/ModelTester.test_vit_h_14_expect.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..1f846beb6a0bccf8b545f5a67b74482015cc878b
GIT binary patch
literal 939
[base85-encoded binary payload of the expected-output pickle omitted]

diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
     )
 
 
+def vit_h_14(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    """
+    Constructs a vit_h_14 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    NOTE: Pretrained weights are not available for this model.
+    """
+    return _vision_transformer(
+        arch="vit_h_14",
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        pretrained=pretrained,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,

From cb7e01d75c0fdfd06173799c95c96220f155f2fd Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Wed, 19 Jan 2022 03:46:36 +0000
Subject: [PATCH 2/4] prototype and docs

---
 docs/source/models.rst                             |  2 ++
 hubconf.py                                         |  1 +
 .../prototype/models/vision_transformer.py         | 22 ++++++++++++++++++++
 3 files changed, 25 insertions(+)

diff --git a/docs/source/models.rst b/docs/source/models.rst
index 62c104cf927..f1331d5baa9 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -88,6 +88,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -460,6 +461,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14
 
 Quantized Models
 ----------------
diff --git a/hubconf.py b/hubconf.py
index 2b2eeb1c166..1b3b191efa4 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -63,4 +63,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )
diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
index 1cd186a2d82..56a38cdcd77 100644
--- a/torchvision/prototype/models/vision_transformer.py
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -23,6 +23,7 @@
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]
 
 
@@ -99,6 +100,11 @@ class ViT_L_32_Weights(WeightsEnum):
     default = ImageNet1K_V1
 
 
+class ViT_H_14_Weights(WeightsEnum):
+    # Weights are not available yet.
+    pass
+
+
 def _vision_transformer(
     patch_size: int,
     num_layers: int,
@@ -192,3 +198,19 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
         progress=progress,
         **kwargs,
     )
+
+
+@handle_legacy_interface(weights=("pretrained", None))
+def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=32,
+        num_layers=24,
+        num_heads=16,
+        hidden_dim=1024,
+        mlp_dim=4096,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )

From 25c5b0e79a1d71164d384eb4462c4886cd7441cc Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Wed, 19 Jan 2022 04:03:47 +0000
Subject: [PATCH 3/4] bug fix

---
 torchvision/prototype/models/vision_transformer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
index 56a38cdcd77..ed8c53e8e27 100644
--- a/torchvision/prototype/models/vision_transformer.py
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -205,11 +205,11 @@ def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = Tru
     weights = ViT_H_14_Weights.verify(weights)
 
     return _vision_transformer(
-        patch_size=32,
-        num_layers=24,
+        patch_size=14,
+        num_layers=32,
         num_heads=16,
-        hidden_dim=1024,
-        mlp_dim=4096,
+        hidden_dim=1280,
+        mlp_dim=5120,
         weights=weights,
         progress=progress,
         **kwargs,

From e166e566d6ac51aef704ff5f71be922ec6899549 Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Wed, 19 Jan 2022 06:08:07 +0000
Subject: [PATCH 4/4] adding curl check

---
 torchvision/models/vision_transformer.py           | 2 ++
 torchvision/prototype/models/vision_transformer.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index c9690bcceaf..a64f342e1a0 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -261,6 +261,8 @@ def _vision_transformer(
     )
 
     if pretrained:
+        if arch not in model_urls:
+            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
         model.load_state_dict(state_dict)
 
diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
index ed8c53e8e27..af742c7ee01 100644
--- a/torchvision/prototype/models/vision_transformer.py
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -19,6 +19,7 @@
     "ViT_B_32_Weights",
     "ViT_L_16_Weights",
     "ViT_L_32_Weights",
+    "ViT_H_14_Weights",
     "vit_b_16",
     "vit_b_32",
     "vit_l_16",
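
For reviewers, a minimal usage sketch of the new builder (not part of the patch series; it assumes the four patches above are applied to torchvision). Because no checkpoint is registered for vit_h_14, passing pretrained=True would hit the ValueError guard added in PATCH 4/4, so the sketch constructs the model with random weights only.

    import torch
    from torchvision import models

    # ViT-H/14: patch_size=14, num_layers=32, num_heads=16, hidden_dim=1280, mlp_dim=5120.
    model = models.vit_h_14()  # random weights; pretrained weights are not available
    model.eval()

    # The default image_size is 224, so one 3x224x224 image yields 1000 ImageNet logits.
    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])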