3
3
import warnings
4
4
from dataclasses import dataclass
5
5
from functools import partial
6
- from typing import Any , Callable , Optional , List , Sequence
6
+ from typing import Any , Callable , Optional , List , Sequence , Tuple , Union
7
7
8
8
import torch
9
9
from torch import nn , Tensor
25
25
"efficientnet_b5" ,
26
26
"efficientnet_b6" ,
27
27
"efficientnet_b7" ,
28
+ "efficientnet_v2_s" ,
29
+ "efficientnet_v2_m" ,
30
+ "efficientnet_v2_l" ,
28
31
]
29
32
30
33
@@ -67,9 +70,9 @@ def __init__(
67
70
input_channels : int ,
68
71
out_channels : int ,
69
72
num_layers : int ,
70
- width_mult : float ,
71
- depth_mult : float ,
72
- block : Optional [Callable [..., nn .Module ]] = None
73
+ width_mult : float = 1.0 ,
74
+ depth_mult : float = 1.0 ,
75
+ block : Optional [Callable [..., nn .Module ]] = None ,
73
76
) -> None :
74
77
input_channels = self .adjust_channels (input_channels , width_mult )
75
78
out_channels = self .adjust_channels (out_channels , width_mult )
@@ -93,7 +96,7 @@ def __init__(
93
96
input_channels : int ,
94
97
out_channels : int ,
95
98
num_layers : int ,
96
- block : Optional [Callable [..., nn .Module ]] = None
99
+ block : Optional [Callable [..., nn .Module ]] = None ,
97
100
) -> None :
98
101
if block is None :
99
102
block = FusedMBConv
@@ -232,22 +235,24 @@ def forward(self, input: Tensor) -> Tensor:
232
235
class EfficientNet (nn .Module ):
233
236
def __init__ (
234
237
self ,
235
- inverted_residual_setting : List [ MBConvConfig ],
238
+ inverted_residual_setting : Sequence [ Union [ MBConvConfig , FusedMBConvConfig ] ],
236
239
dropout : float ,
237
240
stochastic_depth_prob : float = 0.2 ,
238
241
num_classes : int = 1000 ,
239
242
norm_layer : Optional [Callable [..., nn .Module ]] = None ,
243
+ last_channel : Optional [int ] = None ,
240
244
** kwargs : Any ,
241
245
) -> None :
242
246
"""
243
247
EfficientNet V1 and V2 main class
244
248
245
249
Args:
246
- inverted_residual_setting (List[ MBConvConfig]): Network structure
250
+ inverted_residual_setting (Sequence[Union[ MBConvConfig, FusedMBConvConfig] ]): Network structure
247
251
dropout (float): The droupout probability
248
252
stochastic_depth_prob (float): The stochastic depth probability
249
253
num_classes (int): Number of classes
250
254
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
255
+ last_channel (int): The number of channels on the penultimate layer
251
256
"""
252
257
super ().__init__ ()
253
258
_log_api_usage_once (self )
@@ -307,8 +312,7 @@ def __init__(
307
312
308
313
# building last several layers
309
314
lastconv_input_channels = inverted_residual_setting [- 1 ].out_channels
310
- is_v2 = any ([isinstance (s , FusedMBConvConfig ) for s in inverted_residual_setting ])
311
- lastconv_output_channels = 1280 if is_v2 else 4 * lastconv_input_channels
315
+ lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
312
316
layers .append (
313
317
ConvNormActivation (
314
318
lastconv_input_channels ,
@@ -355,24 +359,14 @@ def forward(self, x: Tensor) -> Tensor:
355
359
356
360
def _efficientnet (
357
361
arch : str ,
358
- width_mult : float ,
359
- depth_mult : float ,
362
+ inverted_residual_setting : Sequence [Union [MBConvConfig , FusedMBConvConfig ]],
360
363
dropout : float ,
364
+ last_channel : Optional [int ],
361
365
pretrained : bool ,
362
366
progress : bool ,
363
367
** kwargs : Any ,
364
368
) -> EfficientNet :
365
- bneck_conf = partial (MBConvConfig , width_mult = width_mult , depth_mult = depth_mult )
366
- inverted_residual_setting = [
367
- bneck_conf (1 , 3 , 1 , 32 , 16 , 1 ),
368
- bneck_conf (6 , 3 , 2 , 16 , 24 , 2 ),
369
- bneck_conf (6 , 5 , 2 , 24 , 40 , 2 ),
370
- bneck_conf (6 , 3 , 2 , 40 , 80 , 3 ),
371
- bneck_conf (6 , 5 , 1 , 80 , 112 , 3 ),
372
- bneck_conf (6 , 5 , 2 , 112 , 192 , 4 ),
373
- bneck_conf (6 , 3 , 1 , 192 , 320 , 1 ),
374
- ]
375
- model = EfficientNet (inverted_residual_setting , dropout , ** kwargs )
369
+ model = EfficientNet (inverted_residual_setting , dropout , last_channel = last_channel , ** kwargs )
376
370
if pretrained :
377
371
if model_urls .get (arch , None ) is None :
378
372
raise ValueError (f"No checkpoint is available for model type { arch } " )
@@ -381,6 +375,61 @@ def _efficientnet(
381
375
return model
382
376
383
377
378
+ def _efficientnet_conf (
379
+ arch : str ,
380
+ ** kwargs : Any ,
381
+ ) -> Tuple [Sequence [Union [MBConvConfig , FusedMBConvConfig ]], Optional [int ]]:
382
+ inverted_residual_setting : Sequence [Union [MBConvConfig , FusedMBConvConfig ]]
383
+ if arch .startswith ("efficientnet_b" ):
384
+ bneck_conf = partial (MBConvConfig , width_mult = kwargs .pop ("width_mult" ), depth_mult = kwargs .pop ("depth_mult" ))
385
+ inverted_residual_setting = [
386
+ bneck_conf (1 , 3 , 1 , 32 , 16 , 1 ),
387
+ bneck_conf (6 , 3 , 2 , 16 , 24 , 2 ),
388
+ bneck_conf (6 , 5 , 2 , 24 , 40 , 2 ),
389
+ bneck_conf (6 , 3 , 2 , 40 , 80 , 3 ),
390
+ bneck_conf (6 , 5 , 1 , 80 , 112 , 3 ),
391
+ bneck_conf (6 , 5 , 2 , 112 , 192 , 4 ),
392
+ bneck_conf (6 , 3 , 1 , 192 , 320 , 1 ),
393
+ ]
394
+ last_channel = None
395
+ elif arch .startswith ("efficientnet_v2_s" ):
396
+ inverted_residual_setting = [
397
+ FusedMBConvConfig (1 , 3 , 1 , 24 , 24 , 2 ),
398
+ FusedMBConvConfig (4 , 3 , 2 , 24 , 48 , 4 ),
399
+ FusedMBConvConfig (4 , 3 , 2 , 48 , 64 , 4 ),
400
+ MBConvConfig (4 , 3 , 2 , 64 , 128 , 6 ),
401
+ MBConvConfig (6 , 3 , 1 , 128 , 160 , 9 ),
402
+ MBConvConfig (6 , 3 , 2 , 160 , 256 , 15 ),
403
+ ]
404
+ last_channel = 1280
405
+ elif arch .startswith ("efficientnet_v2_m" ):
406
+ inverted_residual_setting = [
407
+ FusedMBConvConfig (1 , 3 , 1 , 24 , 24 , 3 ),
408
+ FusedMBConvConfig (4 , 3 , 2 , 24 , 48 , 5 ),
409
+ FusedMBConvConfig (4 , 3 , 2 , 48 , 80 , 5 ),
410
+ MBConvConfig (4 , 3 , 2 , 80 , 160 , 7 ),
411
+ MBConvConfig (6 , 3 , 1 , 160 , 176 , 14 ),
412
+ MBConvConfig (6 , 3 , 2 , 176 , 304 , 18 ),
413
+ MBConvConfig (6 , 3 , 1 , 304 , 512 , 5 ),
414
+ ]
415
+ last_channel = 1280
416
+ elif arch .startswith ("efficientnet_v2_l" ):
417
+ inverted_residual_setting = [
418
+ FusedMBConvConfig (1 , 3 , 1 , 32 , 32 , 4 ),
419
+ FusedMBConvConfig (4 , 3 , 2 , 32 , 64 , 7 ),
420
+ FusedMBConvConfig (4 , 3 , 2 , 64 , 96 , 7 ),
421
+ MBConvConfig (4 , 3 , 2 , 96 , 192 , 10 ),
422
+ MBConvConfig (6 , 3 , 1 , 192 , 224 , 19 ),
423
+ MBConvConfig (6 , 3 , 2 , 224 , 384 , 25 ),
424
+ MBConvConfig (6 , 3 , 1 , 384 , 640 , 7 ),
425
+ ]
426
+ last_channel = 1280
427
+ else :
428
+ raise ValueError (f"Unsupported model type { arch } " )
429
+
430
+ return inverted_residual_setting , last_channel
431
+
432
+
384
433
def efficientnet_b0 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
385
434
"""
386
435
Constructs a EfficientNet B0 architecture from
@@ -390,7 +439,9 @@ def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: A
390
439
pretrained (bool): If True, returns a model pre-trained on ImageNet
391
440
progress (bool): If True, displays a progress bar of the download to stderr
392
441
"""
393
- return _efficientnet ("efficientnet_b0" , 1.0 , 1.0 , 0.2 , pretrained , progress , ** kwargs )
442
+ arch = "efficientnet_b0"
443
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.0 , depth_mult = 1.0 )
444
+ return _efficientnet (arch , inverted_residual_setting , 0.2 , last_channel , pretrained , progress , ** kwargs )
394
445
395
446
396
447
def efficientnet_b1 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
@@ -402,7 +453,9 @@ def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: A
402
453
pretrained (bool): If True, returns a model pre-trained on ImageNet
403
454
progress (bool): If True, displays a progress bar of the download to stderr
404
455
"""
405
- return _efficientnet ("efficientnet_b1" , 1.0 , 1.1 , 0.2 , pretrained , progress , ** kwargs )
456
+ arch = "efficientnet_b1"
457
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.0 , depth_mult = 1.1 )
458
+ return _efficientnet (arch , inverted_residual_setting , 0.2 , last_channel , pretrained , progress , ** kwargs )
406
459
407
460
408
461
def efficientnet_b2 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
@@ -414,7 +467,9 @@ def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: A
414
467
pretrained (bool): If True, returns a model pre-trained on ImageNet
415
468
progress (bool): If True, displays a progress bar of the download to stderr
416
469
"""
417
- return _efficientnet ("efficientnet_b2" , 1.1 , 1.2 , 0.3 , pretrained , progress , ** kwargs )
470
+ arch = "efficientnet_b2"
471
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.1 , depth_mult = 1.2 )
472
+ return _efficientnet (arch , inverted_residual_setting , 0.3 , last_channel , pretrained , progress , ** kwargs )
418
473
419
474
420
475
def efficientnet_b3 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
@@ -426,7 +481,9 @@ def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: A
426
481
pretrained (bool): If True, returns a model pre-trained on ImageNet
427
482
progress (bool): If True, displays a progress bar of the download to stderr
428
483
"""
429
- return _efficientnet ("efficientnet_b3" , 1.2 , 1.4 , 0.3 , pretrained , progress , ** kwargs )
484
+ arch = "efficientnet_b3"
485
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.2 , depth_mult = 1.4 )
486
+ return _efficientnet (arch , inverted_residual_setting , 0.3 , last_channel , pretrained , progress , ** kwargs )
430
487
431
488
432
489
def efficientnet_b4 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
@@ -438,7 +495,9 @@ def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: A
438
495
pretrained (bool): If True, returns a model pre-trained on ImageNet
439
496
progress (bool): If True, displays a progress bar of the download to stderr
440
497
"""
441
- return _efficientnet ("efficientnet_b4" , 1.4 , 1.8 , 0.4 , pretrained , progress , ** kwargs )
498
+ arch = "efficientnet_b4"
499
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.4 , depth_mult = 1.8 )
500
+ return _efficientnet (arch , inverted_residual_setting , 0.4 , last_channel , pretrained , progress , ** kwargs )
442
501
443
502
444
503
def efficientnet_b5 (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
@@ -450,11 +509,13 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A
450
509
pretrained (bool): If True, returns a model pre-trained on ImageNet
451
510
progress (bool): If True, displays a progress bar of the download to stderr
452
511
"""
512
+ arch = "efficientnet_b5"
513
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.6 , depth_mult = 2.2 )
453
514
return _efficientnet (
454
- "efficientnet_b5" ,
455
- 1.6 ,
456
- 2.2 ,
515
+ arch ,
516
+ inverted_residual_setting ,
457
517
0.4 ,
518
+ last_channel ,
458
519
pretrained ,
459
520
progress ,
460
521
norm_layer = partial (nn .BatchNorm2d , eps = 0.001 , momentum = 0.01 ),
@@ -471,11 +532,13 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A
471
532
pretrained (bool): If True, returns a model pre-trained on ImageNet
472
533
progress (bool): If True, displays a progress bar of the download to stderr
473
534
"""
535
+ arch = "efficientnet_b6"
536
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 1.8 , depth_mult = 2.6 )
474
537
return _efficientnet (
475
- "efficientnet_b6" ,
476
- 1.8 ,
477
- 2.6 ,
538
+ arch ,
539
+ inverted_residual_setting ,
478
540
0.5 ,
541
+ last_channel ,
479
542
pretrained ,
480
543
progress ,
481
544
norm_layer = partial (nn .BatchNorm2d , eps = 0.001 , momentum = 0.01 ),
@@ -492,13 +555,57 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A
492
555
pretrained (bool): If True, returns a model pre-trained on ImageNet
493
556
progress (bool): If True, displays a progress bar of the download to stderr
494
557
"""
558
+ arch = "efficientnet_b7"
559
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch , width_mult = 2.0 , depth_mult = 3.1 )
495
560
return _efficientnet (
496
- "efficientnet_b7" ,
497
- 2.0 ,
498
- 3.1 ,
561
+ arch ,
562
+ inverted_residual_setting ,
499
563
0.5 ,
564
+ last_channel ,
500
565
pretrained ,
501
566
progress ,
502
567
norm_layer = partial (nn .BatchNorm2d , eps = 0.001 , momentum = 0.01 ),
503
568
** kwargs ,
504
569
)
570
+
571
+
572
+ def efficientnet_v2_s (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
573
+ """
574
+ Constructs an EfficientNetV2-S architecture from
575
+ `"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
576
+
577
+ Args:
578
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
579
+ progress (bool): If True, displays a progress bar of the download to stderr
580
+ """
581
+ arch = "efficientnet_v2_s"
582
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch )
583
+ return _efficientnet (arch , inverted_residual_setting , 0.3 , last_channel , pretrained , progress , ** kwargs )
584
+
585
+
586
+ def efficientnet_v2_m (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
587
+ """
588
+ Constructs an EfficientNetV2-M architecture from
589
+ `"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
590
+
591
+ Args:
592
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
593
+ progress (bool): If True, displays a progress bar of the download to stderr
594
+ """
595
+ arch = "efficientnet_v2_m"
596
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch )
597
+ return _efficientnet (arch , inverted_residual_setting , 0.4 , last_channel , pretrained , progress , ** kwargs )
598
+
599
+
600
+ def efficientnet_v2_l (pretrained : bool = False , progress : bool = True , ** kwargs : Any ) -> EfficientNet :
601
+ """
602
+ Constructs an EfficientNetV2-L architecture from
603
+ `"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
604
+
605
+ Args:
606
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
607
+ progress (bool): If True, displays a progress bar of the download to stderr
608
+ """
609
+ arch = "efficientnet_v2_l"
610
+ inverted_residual_setting , last_channel = _efficientnet_conf (arch )
611
+ return _efficientnet (arch , inverted_residual_setting , 0.5 , last_channel , pretrained , progress , ** kwargs )
0 commit comments