@@ -37,29 +37,35 @@ def forward(self, x: Tensor) -> Tensor:
37
37
return x
38
38
39
39
40
+ class Permute (nn .Module ):
41
+ def __init__ (self , dims : List [int ]):
42
+ super ().__init__ ()
43
+ self .dims = dims
44
+
45
+ def forward (self , x ):
46
+ return torch .permute (x , self .dims )
47
+
48
+
40
49
class CNBlock (nn .Module ):
41
50
def __init__ (
42
- self , dim , layer_scale : float , stochastic_depth_prob : float , norm_layer : Callable [..., nn .Module ]
51
+ self ,
52
+ dim ,
53
+ layer_scale : float ,
54
+ stochastic_depth_prob : float ,
55
+ norm_layer : Optional [Callable [..., nn .Module ]] = None ,
43
56
) -> None :
44
57
super ().__init__ ()
58
+ if norm_layer is None :
59
+ norm_layer = partial (nn .LayerNorm , eps = 1e-6 )
60
+
45
61
self .block = nn .Sequential (
46
- ConvNormActivation (
47
- dim ,
48
- dim ,
49
- kernel_size = 7 ,
50
- groups = dim ,
51
- norm_layer = norm_layer ,
52
- activation_layer = None ,
53
- bias = True ,
54
- ),
55
- ConvNormActivation (dim , 4 * dim , kernel_size = 1 , norm_layer = None , activation_layer = nn .GELU , inplace = None ),
56
- ConvNormActivation (
57
- 4 * dim ,
58
- dim ,
59
- kernel_size = 1 ,
60
- norm_layer = None ,
61
- activation_layer = None ,
62
- ),
62
+ nn .Conv2d (dim , dim , kernel_size = 7 , padding = 3 , groups = dim , bias = True ),
63
+ Permute ([0 , 2 , 3 , 1 ]),
64
+ norm_layer (dim ),
65
+ nn .Linear (in_features = dim , out_features = 4 * dim , bias = True ),
66
+ nn .GELU (),
67
+ nn .Linear (in_features = 4 * dim , out_features = dim , bias = True ),
68
+ Permute ([0 , 3 , 1 , 2 ]),
63
69
)
64
70
self .layer_scale = nn .Parameter (torch .ones (dim , 1 , 1 ) * layer_scale )
65
71
self .stochastic_depth = StochasticDepth (stochastic_depth_prob , "row" )
@@ -142,7 +148,7 @@ def __init__(
142
148
for _ in range (cnf .num_layers ):
143
149
# adjust stochastic depth probability based on the depth of the stage block
144
150
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0 )
145
- stage .append (block (cnf .input_channels , layer_scale , sd_prob , norm_layer ))
151
+ stage .append (block (cnf .input_channels , layer_scale , sd_prob ))
146
152
stage_block_id += 1
147
153
layers .append (nn .Sequential (* stage ))
148
154
if cnf .out_channels is not None :
@@ -213,7 +219,7 @@ def _convnext(
213
219
214
220
class ConvNeXt_Tiny_Weights (WeightsEnum ):
215
221
IMAGENET1K_V1 = Weights (
216
- url = "https://download.pytorch.org/models/convnext_tiny-47b116bd.pth" ,
222
+ url = "https://download.pytorch.org/models/convnext_tiny-47b116bd.pth" , # TODO: repackage
217
223
transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 236 ),
218
224
meta = {
219
225
** _COMMON_META ,
@@ -227,7 +233,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
227
233
228
234
class ConvNeXt_Small_Weights (WeightsEnum ):
229
235
IMAGENET1K_V1 = Weights (
230
- url = "https://download.pytorch.org/models/convnext_small-9aa23d28.pth" ,
236
+ url = "https://download.pytorch.org/models/convnext_small-9aa23d28.pth" , # TODO: repackage
231
237
transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 230 ),
232
238
meta = {
233
239
** _COMMON_META ,
@@ -241,7 +247,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
241
247
242
248
class ConvNeXt_Base_Weights (WeightsEnum ):
243
249
IMAGENET1K_V1 = Weights (
244
- url = "https://download.pytorch.org/models/convnext_base-3b9f985d.pth" ,
250
+ url = "https://download.pytorch.org/models/convnext_base-3b9f985d.pth" , # TODO: repackage
245
251
transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 232 ),
246
252
meta = {
247
253
** _COMMON_META ,
@@ -255,7 +261,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
255
261
256
262
class ConvNeXt_Large_Weights (WeightsEnum ):
257
263
IMAGENET1K_V1 = Weights (
258
- url = "https://download.pytorch.org/models/convnext_large-d73f62ac.pth" ,
264
+ url = "https://download.pytorch.org/models/convnext_large-d73f62ac.pth" , # TODO: repackage
259
265
transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 232 ),
260
266
meta = {
261
267
** _COMMON_META ,
0 commit comments