# Licensed under the MIT License.

from functools import partial
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn


class Partial_conv3(nn.Module):
-    def __init__(self, dim: int, n_div: int, forward: str):
+    def __init__(self, dim: int, n_div: int, forward: str, device=None, dtype=None):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
-        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
+        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False, **dd)

        if forward == 'slicing':
            self.forward = self.forward_slicing
@@ -68,25 +69,28 @@ def __init__(
            mlp_ratio: float,
            drop_path: float,
            layer_scale_init_value: float,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            pconv_fw_type: str = 'split_cat',
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = nn.Sequential(*[
-            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
-            norm_layer(mlp_hidden_dim),
+            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False, **dd),
+            norm_layer(mlp_hidden_dim, **dd),
            act_layer(),
-            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False),
+            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False, **dd),
        ])

-        self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type)
+        self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type, **dd)

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+                layer_scale_init_value * torch.ones((dim), **dd), requires_grad=True)
        else:
            self.layer_scale = None

@@ -112,12 +116,15 @@ def __init__(
            mlp_ratio: float,
            drop_path: float,
            layer_scale_init_value: float,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            pconv_fw_type: str = 'split_cat',
            use_merge: bool = True,
            merge_size: Union[int, Tuple[int, int]] = 2,
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        self.blocks = nn.Sequential(*[
@@ -130,13 +137,15 @@ def __init__(
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type,
+                **dd,
            )
            for i in range(depth)
        ])
        self.downsample = PatchMerging(
            dim=dim // 2,
            patch_size=merge_size,
            norm_layer=norm_layer,
+            **dd,
        ) if use_merge else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -154,11 +163,14 @@ def __init__(
            in_chans: int,
            embed_dim: int,
            patch_size: Union[int, Tuple[int, int]] = 4,
-            norm_layer: LayerType = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
-        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False)
-        self.norm = norm_layer(embed_dim)
+        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False, **dd)
+        self.norm = norm_layer(embed_dim, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.proj(x))
@@ -169,11 +181,14 @@ def __init__(
            self,
            dim: int,
            patch_size: Union[int, Tuple[int, int]] = 2,
-            norm_layer: LayerType = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
-        self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False)
-        self.norm = norm_layer(2 * dim)
+        self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False, **dd)
+        self.norm = norm_layer(2 * dim, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.reduction(x))
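Every constructor touched above follows the same factory-kwargs pattern: collect device/dtype once into dd and splat it into each layer so parameters are created directly on the target device and in the target dtype. A minimal standalone sketch of that pattern (illustrative class, not part of this file):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Hypothetical module mirroring the dd = {'device': ..., 'dtype': ...} pattern above.
    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 1, 1, bias=False, **dd)  # weight created with **dd
        self.norm = nn.BatchNorm2d(dim, **dd)                       # affine params and buffers likewise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.conv(x))

# e.g. TinyBlock(64, device='meta') constructs the module without allocating real storage.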
@@ -196,11 +211,14 @@ def __init__(
            drop_rate: float = 0.,
            drop_path_rate: float = 0.1,
            layer_scale_init_value: float = 0.,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            pconv_fw_type: str = 'split_cat',
+            device=None,
+            dtype=None,
    ):
        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
        assert pconv_fw_type in ('split_cat', 'slicing',)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
@@ -214,9 +232,10 @@ def __init__(
            embed_dim=embed_dim,
            patch_size=patch_size,
            norm_layer=norm_layer if patch_norm else nn.Identity,
+            **dd,
        )
        # stochastic depth decay rule
-        dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))
+        dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)

        # build layers
        stages_list = []
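A note on the dpr change just above: the previous code built one flat per-block drop-path list and sliced it by stage, while the new call asks the helper for a stagewise (nested) schedule so each stage simply takes dpr[i]. A rough sketch of an equivalent stagewise schedule, assuming the usual linear ramp (illustrative values, not the timm implementation):

import torch

depths = (1, 2, 8, 2)      # assumed example stage depths
drop_path_rate = 0.1

flat = torch.linspace(0, drop_path_rate, sum(depths)).tolist()                  # per-block ramp
dpr = [flat[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]   # one sub-list per stage
# dpr[i] then holds the drop-path rate for each block of stage i.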
@@ -227,13 +246,14 @@ def __init__(
                depth=depths[i],
                n_div=n_div,
                mlp_ratio=mlp_ratio,
-                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+                drop_path=dpr[i],
                layer_scale_init_value=layer_scale_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type,
                use_merge=False if i == 0 else True,
                merge_size=merge_size,
+                **dd,
            )
            stages_list.append(stage)
            self.feature_info += [dict(num_chs=dim, reduction=2 ** (i + 2), module=f'stages.{i}')]
@@ -243,10 +263,10 @@ def __init__(
        self.num_features = prev_chs = int(embed_dim * 2 ** (self.num_stages - 1))
        self.head_hidden_size = out_chs = feature_dim  # 1280
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False)
+        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False, **dd)
        self.act = act_layer()
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes, bias=True) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(out_chs, num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity()
        self._initialize_weights()

    def _initialize_weights(self):
@@ -285,12 +305,13 @@ def set_grad_checkpointing(self, enable=True):
    def get_classifier(self) -> nn.Module:
        return self.classifier

-    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
+        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
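A possible usage sketch for the new constructor kwargs (assumed import path and default arguments; not part of this diff):

import torch
from timm.models.fasternet import FasterNet  # assumed location of the file changed above

# Parameters are created directly in bfloat16; passing device='meta' instead would
# defer real allocation (useful for sharded or memory-constrained initialization).
model = FasterNet(num_classes=1000, dtype=torch.bfloat16)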