Commit ef8b48d

Vision transformer with tests enabled.
1 parent 0caf745 commit ef8b48d

File tree

3 files changed: +337, -214 lines


torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 from .shufflenetv2 import *
 from .efficientnet import *
 from .regnet import *
+from .vision_transformer import *
 from . import detection
 from . import feature_extraction
 from . import quantization
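With this re-export in place, the new ViT builders become part of the public torchvision.models namespace. A minimal usage sketch (not part of the diff; assumes torchvision is installed from this commit):

import torchvision.models as models

# vit_b_16 is pulled in via `from .vision_transformer import *`
model = models.vit_b_16()  # random initialization; pretrained weights are not wired up yet
print(type(model).__name__)  # VisionTransformer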
torchvision/models/vision_transformer.py

Lines changed: 333 additions & 0 deletions

@@ -0,0 +1,333 @@
# References:
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py


import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor


__all__ = [
    "VisionTransformer",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
]

class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we set batch_first=True in nn.MultiheadAttention()
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        input_channels = 3

        # The conv_proj is a more efficient version of reshaping, permuting
        # and projecting the input
        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)
        self._init_weights()

    def _init_weights(self):
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        nn.init.zeros_(self.conv_proj.bias)

        if hasattr(self.heads, "pre_logits"):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        nn.init.zeros_(self.heads.head.weight)
        nn.init.zeros_(self.heads.head.bias)

    def forward(self, x: torch.Tensor):
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, "Wrong image height!")
        torch._assert(w == self.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        # Expand the class token to the full batch.
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x

def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if pretrained:
        raise Exception("Weights not available")  # TODO: Adding pre-trained models

    return model

def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_b_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
    """
    return _vision_transformer(
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )

def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_b_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
    """
    return _vision_transformer(
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )

def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_l_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
    """
    return _vision_transformer(
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )

def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_l_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool, optional): If True, returns a model pre-trained on ImageNet. Default: False.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
    """
    return _vision_transformer(
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
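For reference, a short smoke-test sketch of the new model (not part of the diff; assumes torchvision built from this commit, and the shapes follow from the code above):

import torch
from torchvision.models import vit_b_16

model = vit_b_16()  # patch_size=16, 12 layers, 12 heads, hidden_dim=768, mlp_dim=3072
model.eval()

# image_size defaults to 224, so seq_length = (224 // 16) ** 2 + 1 = 197 (196 patches + 1 class token)
x = torch.rand(2, 3, 224, 224)
with torch.no_grad():
    out = model(x)

print(out.shape)  # torch.Size([2, 1000]); num_classes defaults to 1000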
