Commit 6a3342c

dd factory kwargs (device, dtype) added to a number of ViT / ViT-hybrid models: cait, coat, convit, convmixer, deit, mvitv2, nest, pit, pvt_v2, tiny_vit, tnt, twins, visformer, xcit
1 parent 068e6d4 commit 6a3342c
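
The change follows the PyTorch "factory kwargs" convention: each module's __init__ gains optional device and dtype arguments, collects them into a dd dict, and forwards that dict to every call that creates parameters or buffers. A minimal sketch of the pattern, using an illustrative ToyBlock module that is not part of this commit:

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Illustrative only, not from this commit: shows the dd factory-kwargs pattern.
    def __init__(self, dim: int, device=None, dtype=None):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        # Forward the factory kwargs to everything that allocates parameters, so the
        # module is created directly on the target device / in the target dtype.
        self.norm = nn.LayerNorm(dim, **dd)
        self.proj = nn.Linear(dim, dim, **dd)
        self.gamma = nn.Parameter(torch.ones(dim, **dd))

    def forward(self, x):
        return x + self.gamma * self.proj(self.norm(x))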

File tree

15 files changed: +1343 −847 lines changed


timm/models/cait.py

Lines changed: 110 additions & 74 deletions
@@ -9,7 +9,7 @@
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Type, Any

 import torch
 import torch.nn as nn
@@ -29,18 +29,28 @@ class ClassAttn(nn.Module):
     # with slight modifications to do CA
     fused_attn: torch.jit.Final[bool]

-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
+    ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()

-        self.q = nn.Linear(dim, dim, bias=qkv_bias)
-        self.k = nn.Linear(dim, dim, bias=qkv_bias)
-        self.v = nn.Linear(dim, dim, bias=qkv_bias)
+        self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
@@ -73,39 +83,44 @@ class LayerScaleBlockClassAttn(nn.Module):
     # with slight modifications to add CA and LayerScale
     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            attn_block=ClassAttn,
-            mlp_block=Mlp,
-            init_values=1e-4,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            attn_block: Type[nn.Module] = ClassAttn,
+            mlp_block: Type[nn.Module] = Mlp,
+            init_values: float = 1e-4,
+            device=None,
+            dtype=None,
     ):
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        dd = {'device': device, 'dtype': dtype}
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = attn_block(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(
             in_features=dim,
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer,
             drop=proj_drop,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))

     def forward(self, x, x_cls):
         u = torch.cat((x_cls, x), dim=1)
@@ -117,22 +132,32 @@ def forward(self, x, x_cls):
 class TalkingHeadAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
+    ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}

         self.num_heads = num_heads

         head_dim = dim // num_heads

         self.scale = head_dim ** -0.5

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)

-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)

-        self.proj_l = nn.Linear(num_heads, num_heads)
-        self.proj_w = nn.Linear(num_heads, num_heads)
+        self.proj_l = nn.Linear(num_heads, num_heads, **dd)
+        self.proj_w = nn.Linear(num_heads, num_heads, **dd)

         self.proj_drop = nn.Dropout(proj_drop)

@@ -161,39 +186,44 @@ class LayerScaleBlock(nn.Module):
     # with slight modifications to add layerScale
     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            attn_block=TalkingHeadAttn,
-            mlp_block=Mlp,
-            init_values=1e-4,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            attn_block: Type[nn.Module] = TalkingHeadAttn,
+            mlp_block: Type[nn.Module] = Mlp,
+            init_values: float = 1e-4,
+            device=None,
+            dtype=None,
     ):
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        dd = {'device': device, 'dtype': dtype}
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = attn_block(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(
             in_features=dim,
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer,
             drop=proj_drop,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))

     def forward(self, x):
         x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
@@ -206,35 +236,38 @@ class Cait(nn.Module):
     # with slight modifications to adapt to our cait models
     def __init__(
             self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='token',
-            embed_dim=768,
-            depth=12,
-            num_heads=12,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            drop_rate=0.,
-            pos_drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            block_layers=LayerScaleBlock,
-            block_layers_token=LayerScaleBlockClassAttn,
-            patch_layer=PatchEmbed,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
-            act_layer=nn.GELU,
-            attn_block=TalkingHeadAttn,
-            mlp_block=Mlp,
-            init_values=1e-4,
-            attn_block_token_only=ClassAttn,
-            mlp_block_token_only=Mlp,
-            depth_token_only=2,
-            mlp_ratio_token_only=4.0
+            img_size: int = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'token',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            block_layers: Type[nn.Module] = LayerScaleBlock,
+            block_layers_token: Type[nn.Module] = LayerScaleBlockClassAttn,
+            patch_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+            act_layer: Type[nn.Module] = nn.GELU,
+            attn_block: Type[nn.Module] = TalkingHeadAttn,
+            mlp_block: Type[nn.Module] = Mlp,
+            init_values: float = 1e-4,
+            attn_block_token_only: Type[nn.Module] = ClassAttn,
+            mlp_block_token_only: Type[nn.Module] = Mlp,
+            depth_token_only: int = 2,
+            mlp_ratio_token_only: float = 4.0,
+            device=None,
+            dtype=None,
     ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'token', 'avg')

         self.num_classes = num_classes
@@ -247,12 +280,13 @@ def __init__(
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
         self.pos_drop = nn.Dropout(p=pos_drop_rate)

         dpr = [drop_path_rate for i in range(depth)]
@@ -269,6 +303,7 @@ def __init__(
             attn_block=attn_block,
             mlp_block=mlp_block,
             init_values=init_values,
+            **dd,
         ) for i in range(depth)])
         self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

@@ -282,12 +317,13 @@ def __init__(
             attn_block=attn_block_token_only,
             mlp_block=mlp_block_token_only,
             init_values=init_values,
+            **dd,
         ) for _ in range(depth_token_only)])

-        self.norm = norm_layer(embed_dim)
+        self.norm = norm_layer(embed_dim, **dd)

         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
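
With the factory kwargs in place, these modules can be built directly in a target dtype or on the meta device (no real storage allocated) rather than constructed in float32 on the CPU and converted afterwards. A usage sketch against the updated ClassAttn from this file, assuming a timm build that includes this commit:

import torch
from timm.models.cait import ClassAttn

# Parameters come out in bfloat16 at construction time, no post-hoc .to() needed.
attn = ClassAttn(dim=192, num_heads=4, dtype=torch.bfloat16)
print(attn.q.weight.dtype)           # torch.bfloat16

# device='meta' builds the module without allocating real weight storage,
# useful for architecture inspection or deferred / sharded initialization.
attn_meta = ClassAttn(dim=192, num_heads=4, device='meta')
print(attn_meta.proj.weight.device)  # meta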
