15
15
from ._utils import handle_legacy_interface , _ovewrite_named_param
16
16
17
17
18
# Public API of this module: the model class, then one weights enum and one
# builder function per ConvNeXt variant (tiny / small / base / large).
__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]
19
29
20
30
21
31
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes dimension 1 of a 4-D tensor.

    ``F.layer_norm`` normalizes over the trailing dimension(s), so the input
    is permuted with ``(0, 2, 3, 1)`` before normalization and permuted back
    with ``(0, 3, 1, 2)`` afterwards. The output therefore has the same
    dimension order as the input.

    NOTE(review): intended for channels-first feature maps, presumably
    ``(N, C, H, W)`` -- confirm against the callers.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over its dimension 1.

        Args:
            x: 4-D tensor whose dimension 1 matches ``self.normalized_shape``.

        Returns:
            Tensor with the same dimension order as ``x``.
        """
        # Move dimension 1 to the end so layer_norm normalizes over it.
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # Restore the original dimension order.
        x = x.permute(0, 3, 1, 2)
        return x
34
37
35
38
39
class Permute(nn.Module):
    """Module wrapper around ``torch.permute`` so it can be used in ``nn.Sequential``.

    Args:
        dims: The desired ordering of dimensions, passed unchanged to
            ``torch.permute`` on every forward call.
    """

    def __init__(self, dims: List[int]) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with its dimensions reordered according to ``self.dims``."""
        return torch.permute(x, self.dims)
46
+
47
+
36
48
class CNBlock (nn .Module ):
37
49
def __init__ (
38
- self , dim , layer_scale : float , stochastic_depth_prob : float , norm_layer : Callable [..., nn .Module ]
50
+ self ,
51
+ dim ,
52
+ layer_scale : float ,
53
+ stochastic_depth_prob : float ,
54
+ norm_layer : Optional [Callable [..., nn .Module ]] = None ,
39
55
) -> None :
40
56
super ().__init__ ()
57
+ if norm_layer is None :
58
+ norm_layer = partial (nn .LayerNorm , eps = 1e-6 )
59
+
41
60
self .block = nn .Sequential (
42
- ConvNormActivation (
43
- dim ,
44
- dim ,
45
- kernel_size = 7 ,
46
- groups = dim ,
47
- norm_layer = norm_layer ,
48
- activation_layer = None ,
49
- bias = True ,
50
- ),
51
- ConvNormActivation (dim , 4 * dim , kernel_size = 1 , norm_layer = None , activation_layer = nn .GELU , inplace = None ),
52
- ConvNormActivation (
53
- 4 * dim ,
54
- dim ,
55
- kernel_size = 1 ,
56
- norm_layer = None ,
57
- activation_layer = None ,
58
- ),
61
+ nn .Conv2d (dim , dim , kernel_size = 7 , padding = 3 , groups = dim , bias = True ),
62
+ Permute ([0 , 2 , 3 , 1 ]),
63
+ norm_layer (dim ),
64
+ nn .Linear (in_features = dim , out_features = 4 * dim , bias = True ),
65
+ nn .GELU (),
66
+ nn .Linear (in_features = 4 * dim , out_features = dim , bias = True ),
67
+ Permute ([0 , 3 , 1 , 2 ]),
59
68
)
60
69
self .layer_scale = nn .Parameter (torch .ones (dim , 1 , 1 ) * layer_scale )
61
70
self .stochastic_depth = StochasticDepth (stochastic_depth_prob , "row" )
@@ -138,7 +147,7 @@ def __init__(
138
147
for _ in range (cnf .num_layers ):
139
148
# adjust stochastic depth probability based on the depth of the stage block
140
149
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0 )
141
- stage .append (block (cnf .input_channels , layer_scale , sd_prob , norm_layer ))
150
+ stage .append (block (cnf .input_channels , layer_scale , sd_prob ))
142
151
stage_block_id += 1
143
152
layers .append (nn .Sequential (* stage ))
144
153
if cnf .out_channels is not None :
@@ -177,30 +186,95 @@ def forward(self, x: Tensor) -> Tensor:
177
186
return self ._forward_impl (x )
178
187
179
188
189
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Build a ``ConvNeXt`` and optionally load pretrained weights into it.

    Shared implementation behind the public ``convnext_*`` builders.

    Args:
        block_setting: Per-stage block configuration passed to ``ConvNeXt``.
        stochastic_depth_prob: Forwarded to ``ConvNeXt`` as
            ``stochastic_depth_prob``.
        weights: Pretrained weights to load, or ``None`` to leave the model
            with whatever initialization ``ConvNeXt`` performs.
        progress: Forwarded to ``weights.get_state_dict`` when weights are
            loaded.
        **kwargs: Extra keyword arguments forwarded to the ``ConvNeXt``
            constructor.

    Returns:
        The constructed model.
    """
    if weights is not None:
        # Tie "num_classes" to the number of categories recorded in the
        # weights' metadata before the model is built.
        # NOTE(review): presumably the helper sets/validates
        # kwargs["num_classes"] -- confirm in ._utils.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
205
+
206
+
207
# Metadata shared by every ConvNeXt weights entry in this module; each entry
# spreads this dict and adds its own "num_params", "acc@1" and "acc@5".
_COMMON_META = {
    "task": "image_classification",
    "architecture": "ConvNeXt",
    "publication_year": 2022,
    "size": (224, 224),
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
}
217
+
218
+
180
219
class ConvNeXt_Tiny_Weights(WeightsEnum):
    """Pretrained weights available for the ConvNeXt Tiny variant."""

    # Checkpoint URL, evaluation preprocessing (ImageNetEval with
    # crop_size=224, resize_size=236) and the metadata attached to the weights.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "acc@1": 82.520,
            "acc@5": 96.146,
        },
    )
    # Alias for the entry used by default.
    DEFAULT = IMAGENET1K_V1
199
231
200
232
233
class ConvNeXt_Small_Weights(WeightsEnum):
    """Pretrained weights available for the ConvNeXt Small variant."""

    # Checkpoint URL, evaluation preprocessing (ImageNetEval with
    # crop_size=224, resize_size=230) and the metadata attached to the weights.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "acc@1": 83.616,
            "acc@5": 96.650,
        },
    )
    # Alias for the entry used by default.
    DEFAULT = IMAGENET1K_V1
245
+
246
+
247
class ConvNeXt_Base_Weights(WeightsEnum):
    """Pretrained weights available for the ConvNeXt Base variant."""

    # Checkpoint URL, evaluation preprocessing (ImageNetEval with
    # crop_size=224, resize_size=232) and the metadata attached to the weights.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "acc@1": 84.062,
            "acc@5": 96.870,
        },
    )
    # Alias for the entry used by default.
    DEFAULT = IMAGENET1K_V1
259
+
260
+
261
class ConvNeXt_Large_Weights(WeightsEnum):
    """Pretrained weights available for the ConvNeXt Large variant."""

    # Checkpoint URL, evaluation preprocessing (ImageNetEval with
    # crop_size=224, resize_size=232) and the metadata attached to the weights.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "acc@1": 84.414,
            "acc@5": 96.976,
        },
    )
    # Alias for the entry used by default.
    DEFAULT = IMAGENET1K_V1
273
+
274
+
201
275
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Tiny model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (ConvNeXt_Tiny_Weights, optional): Pretrained weights to load.
            When ``None`` no weights are loaded.
        progress (bool): Forwarded to ``weights.get_state_dict`` when weights
            are loaded.
        **kwargs: Forwarded to the ``ConvNeXt`` constructor. An optional
            ``stochastic_depth_prob`` entry overrides the default of ``0.1``.
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    # One CNBlockConfig per stage.
    # NOTE(review): arguments are presumably (input_channels, out_channels,
    # num_layers) -- confirm against CNBlockConfig.
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
296
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    r"""ConvNeXt Small model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (ConvNeXt_Small_Weights, optional): Pretrained weights to load.
            When ``None`` no weights are loaded.
        progress (bool): Forwarded to ``weights.get_state_dict`` when weights
            are loaded.
        **kwargs: Forwarded to the ``ConvNeXt`` constructor. An optional
            ``stochastic_depth_prob`` entry overrides the default of ``0.4``.
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    # One CNBlockConfig per stage.
    # NOTE(review): arguments are presumably (input_channels, out_channels,
    # num_layers) -- confirm against CNBlockConfig.
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
310
+
311
+
312
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Base model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (ConvNeXt_Base_Weights, optional): Pretrained weights to load.
            When ``None`` no weights are loaded.
        progress (bool): Forwarded to ``weights.get_state_dict`` when weights
            are loaded.
        **kwargs: Forwarded to the ``ConvNeXt`` constructor. An optional
            ``stochastic_depth_prob`` entry overrides the default of ``0.5``.
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    # One CNBlockConfig per stage.
    # NOTE(review): arguments are presumably (input_channels, out_channels,
    # num_layers) -- confirm against CNBlockConfig.
    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
324
+
325
+
326
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    r"""ConvNeXt Large model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (ConvNeXt_Large_Weights, optional): Pretrained weights to load.
            When ``None`` no weights are loaded.
        progress (bool): Forwarded to ``weights.get_state_dict`` when weights
            are loaded.
        **kwargs: Forwarded to the ``ConvNeXt`` constructor. An optional
            ``stochastic_depth_prob`` entry overrides the default of ``0.5``.
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    # One CNBlockConfig per stage.
    # NOTE(review): arguments are presumably (input_channels, out_channels,
    # num_layers) -- confirm against CNBlockConfig.
    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
0 commit comments