Skip to content

Commit aa82cf1

Browse files
committed
Refactor config/builder methods and add prototype builders
1 parent b03a7ec commit aa82cf1

File tree

5 files changed

+234
-82
lines changed

5 files changed

+234
-82
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.

torchvision/models/efficientnet.py

Lines changed: 143 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from dataclasses import dataclass
55
from functools import partial
6-
from typing import Any, Callable, Optional, List, Sequence
6+
from typing import Any, Callable, Optional, List, Sequence, Tuple, Union
77

88
import torch
99
from torch import nn, Tensor
@@ -25,6 +25,9 @@
2525
"efficientnet_b5",
2626
"efficientnet_b6",
2727
"efficientnet_b7",
28+
"efficientnet_v2_s",
29+
"efficientnet_v2_m",
30+
"efficientnet_v2_l",
2831
]
2932

3033

@@ -67,9 +70,9 @@ def __init__(
6770
input_channels: int,
6871
out_channels: int,
6972
num_layers: int,
70-
width_mult: float,
71-
depth_mult: float,
72-
block: Optional[Callable[..., nn.Module]] = None
73+
width_mult: float = 1.0,
74+
depth_mult: float = 1.0,
75+
block: Optional[Callable[..., nn.Module]] = None,
7376
) -> None:
7477
input_channels = self.adjust_channels(input_channels, width_mult)
7578
out_channels = self.adjust_channels(out_channels, width_mult)
@@ -93,7 +96,7 @@ def __init__(
9396
input_channels: int,
9497
out_channels: int,
9598
num_layers: int,
96-
block: Optional[Callable[..., nn.Module]] = None
99+
block: Optional[Callable[..., nn.Module]] = None,
97100
) -> None:
98101
if block is None:
99102
block = FusedMBConv
@@ -232,22 +235,24 @@ def forward(self, input: Tensor) -> Tensor:
232235
class EfficientNet(nn.Module):
233236
def __init__(
234237
self,
235-
inverted_residual_setting: List[MBConvConfig],
238+
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
236239
dropout: float,
237240
stochastic_depth_prob: float = 0.2,
238241
num_classes: int = 1000,
239242
norm_layer: Optional[Callable[..., nn.Module]] = None,
243+
last_channel: Optional[int] = None,
240244
**kwargs: Any,
241245
) -> None:
242246
"""
243247
EfficientNet V1 and V2 main class
244248
245249
Args:
246-
inverted_residual_setting (List[MBConvConfig]): Network structure
250+
inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
247251
dropout (float): The droupout probability
248252
stochastic_depth_prob (float): The stochastic depth probability
249253
num_classes (int): Number of classes
250254
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
255+
last_channel (int): The number of channels on the penultimate layer
251256
"""
252257
super().__init__()
253258
_log_api_usage_once(self)
@@ -307,8 +312,7 @@ def __init__(
307312

308313
# building last several layers
309314
lastconv_input_channels = inverted_residual_setting[-1].out_channels
310-
is_v2 = any([isinstance(s, FusedMBConvConfig) for s in inverted_residual_setting])
311-
lastconv_output_channels = 1280 if is_v2 else 4 * lastconv_input_channels
315+
lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
312316
layers.append(
313317
ConvNormActivation(
314318
lastconv_input_channels,
@@ -355,24 +359,14 @@ def forward(self, x: Tensor) -> Tensor:
355359

356360
def _efficientnet(
357361
arch: str,
358-
width_mult: float,
359-
depth_mult: float,
362+
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
360363
dropout: float,
364+
last_channel: Optional[int],
361365
pretrained: bool,
362366
progress: bool,
363367
**kwargs: Any,
364368
) -> EfficientNet:
365-
bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
366-
inverted_residual_setting = [
367-
bneck_conf(1, 3, 1, 32, 16, 1),
368-
bneck_conf(6, 3, 2, 16, 24, 2),
369-
bneck_conf(6, 5, 2, 24, 40, 2),
370-
bneck_conf(6, 3, 2, 40, 80, 3),
371-
bneck_conf(6, 5, 1, 80, 112, 3),
372-
bneck_conf(6, 5, 2, 112, 192, 4),
373-
bneck_conf(6, 3, 1, 192, 320, 1),
374-
]
375-
model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
369+
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
376370
if pretrained:
377371
if model_urls.get(arch, None) is None:
378372
raise ValueError(f"No checkpoint is available for model type {arch}")
@@ -381,6 +375,61 @@ def _efficientnet(
381375
return model
382376

383377

378+
def _efficientnet_conf(
379+
arch: str,
380+
**kwargs: Any,
381+
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
382+
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
383+
if arch.startswith("efficientnet_b"):
384+
bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
385+
inverted_residual_setting = [
386+
bneck_conf(1, 3, 1, 32, 16, 1),
387+
bneck_conf(6, 3, 2, 16, 24, 2),
388+
bneck_conf(6, 5, 2, 24, 40, 2),
389+
bneck_conf(6, 3, 2, 40, 80, 3),
390+
bneck_conf(6, 5, 1, 80, 112, 3),
391+
bneck_conf(6, 5, 2, 112, 192, 4),
392+
bneck_conf(6, 3, 1, 192, 320, 1),
393+
]
394+
last_channel = None
395+
elif arch.startswith("efficientnet_v2_s"):
396+
inverted_residual_setting = [
397+
FusedMBConvConfig(1, 3, 1, 24, 24, 2),
398+
FusedMBConvConfig(4, 3, 2, 24, 48, 4),
399+
FusedMBConvConfig(4, 3, 2, 48, 64, 4),
400+
MBConvConfig(4, 3, 2, 64, 128, 6),
401+
MBConvConfig(6, 3, 1, 128, 160, 9),
402+
MBConvConfig(6, 3, 2, 160, 256, 15),
403+
]
404+
last_channel = 1280
405+
elif arch.startswith("efficientnet_v2_m"):
406+
inverted_residual_setting = [
407+
FusedMBConvConfig(1, 3, 1, 24, 24, 3),
408+
FusedMBConvConfig(4, 3, 2, 24, 48, 5),
409+
FusedMBConvConfig(4, 3, 2, 48, 80, 5),
410+
MBConvConfig(4, 3, 2, 80, 160, 7),
411+
MBConvConfig(6, 3, 1, 160, 176, 14),
412+
MBConvConfig(6, 3, 2, 176, 304, 18),
413+
MBConvConfig(6, 3, 1, 304, 512, 5),
414+
]
415+
last_channel = 1280
416+
elif arch.startswith("efficientnet_v2_l"):
417+
inverted_residual_setting = [
418+
FusedMBConvConfig(1, 3, 1, 32, 32, 4),
419+
FusedMBConvConfig(4, 3, 2, 32, 64, 7),
420+
FusedMBConvConfig(4, 3, 2, 64, 96, 7),
421+
MBConvConfig(4, 3, 2, 96, 192, 10),
422+
MBConvConfig(6, 3, 1, 192, 224, 19),
423+
MBConvConfig(6, 3, 2, 224, 384, 25),
424+
MBConvConfig(6, 3, 1, 384, 640, 7),
425+
]
426+
last_channel = 1280
427+
else:
428+
raise ValueError(f"Unsupported model type {arch}")
429+
430+
return inverted_residual_setting, last_channel
431+
432+
384433
def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
385434
"""
386435
Constructs a EfficientNet B0 architecture from
@@ -390,7 +439,9 @@ def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: A
390439
pretrained (bool): If True, returns a model pre-trained on ImageNet
391440
progress (bool): If True, displays a progress bar of the download to stderr
392441
"""
393-
return _efficientnet("efficientnet_b0", 1.0, 1.0, 0.2, pretrained, progress, **kwargs)
442+
arch = "efficientnet_b0"
443+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.0)
444+
return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs)
394445

395446

396447
def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -402,7 +453,9 @@ def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: A
402453
pretrained (bool): If True, returns a model pre-trained on ImageNet
403454
progress (bool): If True, displays a progress bar of the download to stderr
404455
"""
405-
return _efficientnet("efficientnet_b1", 1.0, 1.1, 0.2, pretrained, progress, **kwargs)
456+
arch = "efficientnet_b1"
457+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.1)
458+
return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs)
406459

407460

408461
def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -414,7 +467,9 @@ def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: A
414467
pretrained (bool): If True, returns a model pre-trained on ImageNet
415468
progress (bool): If True, displays a progress bar of the download to stderr
416469
"""
417-
return _efficientnet("efficientnet_b2", 1.1, 1.2, 0.3, pretrained, progress, **kwargs)
470+
arch = "efficientnet_b2"
471+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.1, depth_mult=1.2)
472+
return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs)
418473

419474

420475
def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -426,7 +481,9 @@ def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: A
426481
pretrained (bool): If True, returns a model pre-trained on ImageNet
427482
progress (bool): If True, displays a progress bar of the download to stderr
428483
"""
429-
return _efficientnet("efficientnet_b3", 1.2, 1.4, 0.3, pretrained, progress, **kwargs)
484+
arch = "efficientnet_b3"
485+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.2, depth_mult=1.4)
486+
return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs)
430487

431488

432489
def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -438,7 +495,9 @@ def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: A
438495
pretrained (bool): If True, returns a model pre-trained on ImageNet
439496
progress (bool): If True, displays a progress bar of the download to stderr
440497
"""
441-
return _efficientnet("efficientnet_b4", 1.4, 1.8, 0.4, pretrained, progress, **kwargs)
498+
arch = "efficientnet_b4"
499+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.4, depth_mult=1.8)
500+
return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs)
442501

443502

444503
def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -450,11 +509,13 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A
450509
pretrained (bool): If True, returns a model pre-trained on ImageNet
451510
progress (bool): If True, displays a progress bar of the download to stderr
452511
"""
512+
arch = "efficientnet_b5"
513+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.6, depth_mult=2.2)
453514
return _efficientnet(
454-
"efficientnet_b5",
455-
1.6,
456-
2.2,
515+
arch,
516+
inverted_residual_setting,
457517
0.4,
518+
last_channel,
458519
pretrained,
459520
progress,
460521
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
@@ -471,11 +532,13 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A
471532
pretrained (bool): If True, returns a model pre-trained on ImageNet
472533
progress (bool): If True, displays a progress bar of the download to stderr
473534
"""
535+
arch = "efficientnet_b6"
536+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.8, depth_mult=2.6)
474537
return _efficientnet(
475-
"efficientnet_b6",
476-
1.8,
477-
2.6,
538+
arch,
539+
inverted_residual_setting,
478540
0.5,
541+
last_channel,
479542
pretrained,
480543
progress,
481544
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
@@ -492,13 +555,57 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A
492555
pretrained (bool): If True, returns a model pre-trained on ImageNet
493556
progress (bool): If True, displays a progress bar of the download to stderr
494557
"""
558+
arch = "efficientnet_b7"
559+
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=2.0, depth_mult=3.1)
495560
return _efficientnet(
496-
"efficientnet_b7",
497-
2.0,
498-
3.1,
561+
arch,
562+
inverted_residual_setting,
499563
0.5,
564+
last_channel,
500565
pretrained,
501566
progress,
502567
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
503568
**kwargs,
504569
)
570+
571+
572+
def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
573+
"""
574+
Constructs an EfficientNetV2-S architecture from
575+
`"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
576+
577+
Args:
578+
pretrained (bool): If True, returns a model pre-trained on ImageNet
579+
progress (bool): If True, displays a progress bar of the download to stderr
580+
"""
581+
arch = "efficientnet_v2_s"
582+
inverted_residual_setting, last_channel = _efficientnet_conf(arch)
583+
return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs)
584+
585+
586+
def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
587+
"""
588+
Constructs an EfficientNetV2-M architecture from
589+
`"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
590+
591+
Args:
592+
pretrained (bool): If True, returns a model pre-trained on ImageNet
593+
progress (bool): If True, displays a progress bar of the download to stderr
594+
"""
595+
arch = "efficientnet_v2_m"
596+
inverted_residual_setting, last_channel = _efficientnet_conf(arch)
597+
return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs)
598+
599+
600+
def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
601+
"""
602+
Constructs an EfficientNetV2-L architecture from
603+
`"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
604+
605+
Args:
606+
pretrained (bool): If True, returns a model pre-trained on ImageNet
607+
progress (bool): If True, displays a progress bar of the download to stderr
608+
"""
609+
arch = "efficientnet_v2_l"
610+
inverted_residual_setting, last_channel = _efficientnet_conf(arch)
611+
return _efficientnet(arch, inverted_residual_setting, 0.5, last_channel, pretrained, progress, **kwargs)

0 commit comments

Comments
 (0)