 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Type, Any

 import torch
 import torch.nn as nn
@@ -29,18 +29,28 @@ class ClassAttn(nn.Module):
     # with slight modifications to do CA
     fused_attn: torch.jit.Final[bool]

-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
+    ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()

-        self.q = nn.Linear(dim, dim, bias=qkv_bias)
-        self.k = nn.Linear(dim, dim, bias=qkv_bias)
-        self.v = nn.Linear(dim, dim, bias=qkv_bias)
+        self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
@@ -73,39 +83,44 @@ class LayerScaleBlockClassAttn(nn.Module):
     # with slight modifications to add CA and LayerScale
     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            attn_block=ClassAttn,
-            mlp_block=Mlp,
-            init_values=1e-4,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            attn_block: Type[nn.Module] = ClassAttn,
+            mlp_block: Type[nn.Module] = Mlp,
+            init_values: float = 1e-4,
+            device=None,
+            dtype=None,
     ):
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        dd = {'device': device, 'dtype': dtype}
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = attn_block(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(
             in_features=dim,
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer,
             drop=proj_drop,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))

     def forward(self, x, x_cls):
         u = torch.cat((x_cls, x), dim=1)
@@ -117,22 +132,32 @@ def forward(self, x, x_cls):
 class TalkingHeadAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
+    ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}

         self.num_heads = num_heads

         head_dim = dim // num_heads

         self.scale = head_dim ** -0.5

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)

-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)

-        self.proj_l = nn.Linear(num_heads, num_heads)
-        self.proj_w = nn.Linear(num_heads, num_heads)
+        self.proj_l = nn.Linear(num_heads, num_heads, **dd)
+        self.proj_w = nn.Linear(num_heads, num_heads, **dd)

         self.proj_drop = nn.Dropout(proj_drop)

@@ -161,39 +186,44 @@ class LayerScaleBlock(nn.Module):
     # with slight modifications to add layerScale
     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            attn_block=TalkingHeadAttn,
-            mlp_block=Mlp,
-            init_values=1e-4,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            attn_block: Type[nn.Module] = TalkingHeadAttn,
+            mlp_block: Type[nn.Module] = Mlp,
+            init_values: float = 1e-4,
+            device=None,
+            dtype=None,
     ):
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        dd = {'device': device, 'dtype': dtype}
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = attn_block(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(
             in_features=dim,
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer,
             drop=proj_drop,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))

     def forward(self, x):
         x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
@@ -206,35 +236,38 @@ class Cait(nn.Module):
     # with slight modifications to adapt to our cait models
     def __init__(
             self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='token',
-            embed_dim=768,
-            depth=12,
-            num_heads=12,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            drop_rate=0.,
-            pos_drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            block_layers=LayerScaleBlock,
-            block_layers_token=LayerScaleBlockClassAttn,
-            patch_layer=PatchEmbed,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
-            act_layer=nn.GELU,
-            attn_block=TalkingHeadAttn,
-            mlp_block=Mlp,
-            init_values=1e-4,
-            attn_block_token_only=ClassAttn,
-            mlp_block_token_only=Mlp,
-            depth_token_only=2,
-            mlp_ratio_token_only=4.0
+            img_size: int = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'token',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            block_layers: Type[nn.Module] = LayerScaleBlock,
+            block_layers_token: Type[nn.Module] = LayerScaleBlockClassAttn,
+            patch_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+            act_layer: Type[nn.Module] = nn.GELU,
+            attn_block: Type[nn.Module] = TalkingHeadAttn,
+            mlp_block: Type[nn.Module] = Mlp,
+            init_values: float = 1e-4,
+            attn_block_token_only: Type[nn.Module] = ClassAttn,
+            mlp_block_token_only: Type[nn.Module] = Mlp,
+            depth_token_only: int = 2,
+            mlp_ratio_token_only: float = 4.0,
+            device=None,
+            dtype=None,
     ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'token', 'avg')

         self.num_classes = num_classes
@@ -247,12 +280,13 @@ def __init__(
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
         self.pos_drop = nn.Dropout(p=pos_drop_rate)

         dpr = [drop_path_rate for i in range(depth)]
@@ -269,6 +303,7 @@ def __init__(
             attn_block=attn_block,
             mlp_block=mlp_block,
             init_values=init_values,
+            **dd,
         ) for i in range(depth)])
         self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

@@ -282,12 +317,13 @@ def __init__(
             attn_block=attn_block_token_only,
             mlp_block=mlp_block_token_only,
             init_values=init_values,
+            **dd,
         ) for _ in range(depth_token_only)])

-        self.norm = norm_layer(embed_dim)
+        self.norm = norm_layer(embed_dim, **dd)

         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
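
For context, here is a minimal standalone sketch of the factory-kwargs pattern this diff threads through the CaiT modules: the constructor's device/dtype arguments are collected into a dd dict and forwarded to every weight-creating call (nn.Linear, nn.LayerNorm, torch.ones, torch.zeros), so parameters are materialized directly where requested instead of being moved afterwards. TinyBlock below is a made-up illustration, not part of the file.

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Illustrative only: mirrors the dd forwarding used in the modules above."""

    def __init__(self, dim: int, device=None, dtype=None):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}  # forwarded to every weight-creating call
        self.norm = nn.LayerNorm(dim, **dd)
        self.proj = nn.Linear(dim, dim, **dd)
        self.gamma = nn.Parameter(1e-4 * torch.ones(dim, **dd))

    def forward(self, x):
        return x + self.gamma * self.proj(self.norm(x))


# Parameters land directly on the requested device/dtype; no post-hoc .to() copy needed.
block = TinyBlock(64, device='meta')
print(block.proj.weight.device)  # meta

Because nn.Linear, nn.LayerNorm and the tensor factory functions already accept device and dtype keyword arguments, the diff only needs to append **dd to the existing calls; passing device='meta' as above yields shape-only parameters, which is useful for deferred or sharded initialization.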