diff --git a/docs/source/models.rst b/docs/source/models.rst
index 7fbae2a55d1..f4188a5ad1f 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -271,7 +271,8 @@ The models subpackage contains definitions for the following model
 architectures for semantic segmentation:
 
 - `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
-- `DeepLabV3 ResNet50, ResNet101 <https://arxiv.org/abs/1706.05587>`_
+- `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
+- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_
 
 As with image classification models, all pre-trained models expect input images normalized in the same way.
 The images have to be loaded into a range of ``[0, 1]`` and then normalized using
@@ -298,6 +299,8 @@ FCN ResNet50                     60.5          91.4
 FCN ResNet101                    63.7          91.9
 DeepLabV3 ResNet50               66.4          92.4
 DeepLabV3 ResNet101              67.4          92.4
+DeepLabV3 MobileNetV3-Large      60.3          91.2
+LR-ASPP MobileNetV3-Large        57.9          91.2
 ================================ ============= ====================
@@ -313,6 +316,13 @@ DeepLabV3
 
 .. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
 .. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
+.. autofunction:: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large
+
+
+LR-ASPP
+-------
+
+.. autofunction:: torchvision.models.segmentation.lraspp_mobilenet_v3_large
 
 
 Object Detection, Instance Segmentation and Person Keypoint Detection
diff --git a/hubconf.py b/hubconf.py
index dec4a7fb196..097759bdd89 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -18,4 +18,4 @@
 # segmentation
 from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
-    deeplabv3_resnet50, deeplabv3_resnet101
+    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
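With the `hubconf.py` change, the new builders should also resolve through torch.hub. A minimal smoke test (a sketch; it assumes the `pytorch/vision` ref being loaded already contains this change):

```
import torch

# Hypothetical quick check that the new hub entry points resolve.
model = torch.hub.load('pytorch/vision', 'lraspp_mobilenet_v3_large', pretrained=False)
model.eval()
```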
diff --git a/references/segmentation/README.md b/references/segmentation/README.md
index 34db88c7a3a..6e24f836624 100644
--- a/references/segmentation/README.md
+++ b/references/segmentation/README.md
@@ -31,3 +31,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
 ```
+
+## deeplabv3_mobilenet_v3_large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
+```
+
+## lraspp_mobilenet_v3_large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
+```
diff --git a/test/expect/ModelTester.test_deeplabv3_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_deeplabv3_mobilenet_v3_large_expect.pkl
new file mode 100644
index 00000000000..58d6da6c721
Binary files /dev/null and b/test/expect/ModelTester.test_deeplabv3_mobilenet_v3_large_expect.pkl differ
diff --git a/test/expect/ModelTester.test_lraspp_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_lraspp_mobilenet_v3_large_expect.pkl
new file mode 100644
index 00000000000..b2aa2ca89a9
Binary files /dev/null and b/test/expect/ModelTester.test_lraspp_mobilenet_v3_large_expect.pkl differ
diff --git a/test/test_models.py b/test/test_models.py
index 14880425aed..9b26839fa0b 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -61,8 +61,10 @@ def get_available_video_models():
     "wide_resnet101_2",
     "deeplabv3_resnet50",
     "deeplabv3_resnet101",
+    "deeplabv3_mobilenet_v3_large",
     "fcn_resnet50",
     "fcn_resnet101",
+    "lraspp_mobilenet_v3_large",
 )
diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index e9a4a7104cf..45f311d160c 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -136,9 +136,9 @@ def mobilenet_backbone(
 ):
     backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
 
-    # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
-    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
     num_stages = len(stage_indices)
 
     # find the index of the layer from which we won't freeze
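The renamed `_is_cn` flag is what the `stage_indices` one-liner above keys on. A toy illustration of the gathering logic (the block count and strided positions here are made up):

```
# Blocks that start a new C-stage expose `_is_cn = True`; the first and last
# modules are always kept, as C0 and Cn. Identity modules are stand-ins.
from torch import nn

backbone = nn.ModuleList([nn.Identity() for _ in range(6)])
for i in (2, 4):  # pretend blocks 2 and 4 are strided
    backbone[i]._is_cn = True

stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
print(stage_indices)  # [0, 2, 4, 5]
```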
diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py
index 93474ae7396..c4e83fa364f 100644
--- a/torchvision/models/mobilenetv2.py
+++ b/torchvision/models/mobilenetv2.py
@@ -38,14 +38,16 @@ def __init__(
         groups: int = 1,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
+        dilation: int = 1,
     ) -> None:
-        padding = (kernel_size - 1) // 2
+        padding = (kernel_size - 1) // 2 * dilation
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
         if activation_layer is None:
             activation_layer = nn.ReLU6
         super(ConvBNReLU, self).__init__(
-            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
+                      bias=False),
             norm_layer(out_planes),
             activation_layer(inplace=True)
         )
@@ -88,7 +90,7 @@ def __init__(
         ])
         self.conv = nn.Sequential(*layers)
         self.out_channels = oup
-        self.is_strided = stride > 1
+        self._is_cn = stride > 1
 
     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
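The new padding rule keeps stride-1 convolutions size-preserving: a dilated kernel spans `dilation * (kernel_size - 1) + 1` pixels, so half of that, `(kernel_size - 1) // 2 * dilation`, is the right pad. A quick check (a sketch, not part of the patch):

```
import torch
from torch import nn

x = torch.rand(1, 8, 32, 32)
for kernel_size in (3, 5):
    for dilation in (1, 2, 4):
        padding = (kernel_size - 1) // 2 * dilation
        conv = nn.Conv2d(8, 8, kernel_size, stride=1, padding=padding, dilation=dilation)
        assert conv(x).shape == x.shape  # spatial size is preserved
```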
diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py
index eba4823277b..a7d45264dc5 100644
--- a/torchvision/models/mobilenetv3.py
+++ b/torchvision/models/mobilenetv3.py
@@ -38,7 +38,7 @@ def forward(self, input: Tensor) -> Tensor:
 class InvertedResidualConfig:
     def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
-                 activation: str, stride: int, width_mult: float):
+                 activation: str, stride: int, dilation: int, width_mult: float):
         self.input_channels = self.adjust_channels(input_channels, width_mult)
         self.kernel = kernel
         self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
@@ -46,6 +46,7 @@ def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out
         self.use_se = use_se
         self.use_hs = activation == "HS"
         self.stride = stride
+        self.dilation = dilation
 
     @staticmethod
     def adjust_channels(channels: int, width_mult: float):
@@ -70,9 +71,10 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                            norm_layer=norm_layer, activation_layer=activation_layer))
 
         # depthwise
+        stride = 1 if cnf.dilation > 1 else cnf.stride
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
-                                       stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer,
-                                       activation_layer=activation_layer))
+                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
+                                       norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
             layers.append(SqueezeExcitation(cnf.expanded_channels))
 
@@ -82,7 +84,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
 
         self.block = nn.Sequential(*layers)
         self.out_channels = cnf.out_channels
-        self.is_strided = cnf.stride > 1
+        self._is_cn = cnf.stride > 1
 
     def forward(self, input: Tensor) -> Tensor:
         result = self.block(input)
@@ -194,8 +196,7 @@ def _mobilenet_v3(
     return model
 
 
-def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
-                       **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
     """
     Constructs a large MobileNetV3 architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
@@ -203,40 +204,38 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
-        reduced_tail (bool): If True, reduces the channel counts of all feature layers
-            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
-            backbone for Detection and Segmentation.
     """
+    # non-public config parameters
+    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
+    dilation = 2 if kwargs.pop('_dilated', False) else 1
     width_mult = 1.0
+
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
 
-    reduce_divider = 2 if reduced_tail else 1
-
     inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, False, "RE", 1),
-        bneck_conf(16, 3, 64, 24, False, "RE", 2),  # C1
-        bneck_conf(24, 3, 72, 24, False, "RE", 1),
-        bneck_conf(24, 5, 72, 40, True, "RE", 2),  # C2
-        bneck_conf(40, 5, 120, 40, True, "RE", 1),
-        bneck_conf(40, 5, 120, 40, True, "RE", 1),
-        bneck_conf(40, 3, 240, 80, False, "HS", 2),  # C3
-        bneck_conf(80, 3, 200, 80, False, "HS", 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1),
-        bneck_conf(80, 3, 480, 112, True, "HS", 1),
-        bneck_conf(112, 3, 672, 112, True, "HS", 1),
-        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2),  # C4
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+        bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
+        bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
+        bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
+        bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
+        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+        bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
+        bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+        bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
+        bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
+        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
     ]
     last_channel = adjust_channels(1280 // reduce_divider)  # C5
 
     return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
 
 
-def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
-                       **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
     """
     Constructs a small MobileNetV3 architecture from
     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
@@ -244,28 +243,27 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
-        reduced_tail (bool): If True, reduces the channel counts of all feature layers
-            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
-            backbone for Detection and Segmentation.
     """
+    # non-public config parameters
+    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
+    dilation = 2 if kwargs.pop('_dilated', False) else 1
     width_mult = 1.0
+
     bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
     adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
 
-    reduce_divider = 2 if reduced_tail else 1
-
     inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, True, "RE", 2),  # C1
-        bneck_conf(16, 3, 72, 24, False, "RE", 2),  # C2
-        bneck_conf(24, 3, 88, 24, False, "RE", 1),
-        bneck_conf(24, 5, 96, 40, True, "HS", 2),  # C3
-        bneck_conf(40, 5, 240, 40, True, "HS", 1),
-        bneck_conf(40, 5, 240, 40, True, "HS", 1),
-        bneck_conf(40, 5, 120, 48, True, "HS", 1),
-        bneck_conf(48, 5, 144, 48, True, "HS", 1),
-        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2),  # C4
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
+        bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
+        bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
+        bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
+        bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
+        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+        bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
+        bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
+        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
     ]
     last_channel = adjust_channels(1024 // reduce_divider)  # C5
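The `_reduced_tail` and `_dilated` switches are read out of `**kwargs` precisely so they stay out of the public signature. A sketch of what `_dilated` changes (internal flag, subject to change; shapes assume a 224x224 input):

```
# With _dilated=True the C4/C5 blocks keep stride 1 and use dilation 2, so the
# deepest feature map has output_stride 16 instead of 32.
import torch
from torchvision.models import mobilenet_v3_large

x = torch.rand(1, 3, 224, 224)
plain = mobilenet_v3_large().features
dilated = mobilenet_v3_large(_dilated=True).features

print(plain(x).shape)    # torch.Size([1, 960, 7, 7])   -> output_stride 32
print(dilated(x).shape)  # torch.Size([1, 960, 14, 14]) -> output_stride 16
```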
diff --git a/torchvision/models/segmentation/__init__.py b/torchvision/models/segmentation/__init__.py
index 43c80c355ad..fb6633d7fb5 100644
--- a/torchvision/models/segmentation/__init__.py
+++ b/torchvision/models/segmentation/__init__.py
@@ -1,3 +1,4 @@
 from .segmentation import *
 from .fcn import *
 from .deeplabv3 import *
+from .lraspp import *
diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py
index c5a7ae99e43..176b7490038 100644
--- a/torchvision/models/segmentation/_utils.py
+++ b/torchvision/models/segmentation/_utils.py
@@ -1,6 +1,5 @@
 from collections import OrderedDict
 
-import torch
 from torch import nn
 from torch.nn import functional as F
diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py
new file mode 100644
index 00000000000..44cd9b1e773
--- /dev/null
+++ b/torchvision/models/segmentation/lraspp.py
@@ -0,0 +1,69 @@
+from collections import OrderedDict
+
+from torch import nn, Tensor
+from torch.nn import functional as F
+from typing import Dict
+
+
+__all__ = ["LRASPP"]
+
+
+class LRASPP(nn.Module):
+    """
+    Implements a Lite R-ASPP Network for semantic segmentation from
+    `"Searching for MobileNetV3"
+    <https://arxiv.org/abs/1905.02244>`_.
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            The backbone should return an OrderedDict[Tensor], with the keys being
+            "high" for the high level feature map and "low" for the low level feature map.
+        low_channels (int): the number of channels of the low level features.
+        high_channels (int): the number of channels of the high level features.
+        num_classes (int): number of output classes of the model (including the background).
+        inter_channels (int, optional): the number of channels for intermediate computations.
+    """
+
+    def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
+        super().__init__()
+        self.backbone = backbone
+        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
+
+    def forward(self, input):
+        features = self.backbone(input)
+        out = self.classifier(features)
+        out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
+
+        result = OrderedDict()
+        result["out"] = out
+
+        return result
+
+
+class LRASPPHead(nn.Module):
+
+    def __init__(self, low_channels, high_channels, num_classes, inter_channels):
+        super().__init__()
+        self.cbr = nn.Sequential(
+            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
+            nn.BatchNorm2d(inter_channels),
+            nn.ReLU(inplace=True)
+        )
+        self.scale = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
+            nn.Sigmoid(),
+        )
+        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
+        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)
+
+    def forward(self, input: Dict[str, Tensor]) -> Tensor:
+        low = input["low"]
+        high = input["high"]
+
+        x = self.cbr(high)
+        s = self.scale(high)
+        x = x * s
+        x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)
+
+        return self.low_classifier(low) + self.high_classifier(x)
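A shape walk-through of the fusion in `LRASPPHead` (a sketch; the 40/960 channel counts are what the dilated MobileNetV3-Large backbone used below produces at C2 and C5):

```
# "high" (C5, output_stride 16) is squeezed to inter_channels, gated by a
# globally pooled sigmoid branch, upsampled to the "low" (C2, output_stride 8)
# resolution, and the two 1x1 classifier outputs are summed.
import torch
from torchvision.models.segmentation.lraspp import LRASPPHead

low = torch.rand(1, 40, 64, 64)    # C2 features of a 512x512 image
high = torch.rand(1, 960, 32, 32)  # C5 features of the dilated backbone
head = LRASPPHead(low_channels=40, high_channels=960, num_classes=21, inter_channels=128)

out = head({"low": low, "high": high})
print(out.shape)  # torch.Size([1, 21, 64, 64])
```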
diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py
index 158ba5e3d0e..371be9b97da 100644
--- a/torchvision/models/segmentation/segmentation.py
+++ b/torchvision/models/segmentation/segmentation.py
@@ -1,11 +1,14 @@
 from .._utils import IntermediateLayerGetter
 from ..utils import load_state_dict_from_url
+from .. import mobilenetv3
 from .. import resnet
 from .deeplabv3 import DeepLabHead, DeepLabV3
 from .fcn import FCN, FCNHead
+from .lraspp import LRASPP
 
 
-__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
+__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
+           'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large']
 
 
 model_urls = {
@@ -13,30 +16,50 @@
     'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
     'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
     'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
+    'deeplabv3_mobilenet_v3_large_coco':
+        'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth',
+    'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth',
 }
 
 
-def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
-    backbone = resnet.__dict__[backbone_name](
-        pretrained=pretrained_backbone,
-        replace_stride_with_dilation=[False, True, True])
-
-    return_layers = {'layer4': 'out'}
+def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
+    if 'resnet' in backbone_name:
+        backbone = resnet.__dict__[backbone_name](
+            pretrained=pretrained_backbone,
+            replace_stride_with_dilation=[False, True, True])
+        out_layer = 'layer4'
+        out_inplanes = 2048
+        aux_layer = 'layer3'
+        aux_inplanes = 1024
+    elif 'mobilenet_v3' in backbone_name:
+        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
+
+        # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+        # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+        stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+        out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
+        out_layer = str(out_pos)
+        out_inplanes = backbone[out_pos].out_channels
+        aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
+        aux_layer = str(aux_pos)
+        aux_inplanes = backbone[aux_pos].out_channels
+    else:
+        raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
+
+    return_layers = {out_layer: 'out'}
     if aux:
-        return_layers['layer3'] = 'aux'
+        return_layers[aux_layer] = 'aux'
     backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
 
     aux_classifier = None
     if aux:
-        inplanes = 1024
-        aux_classifier = FCNHead(inplanes, num_classes)
+        aux_classifier = FCNHead(aux_inplanes, num_classes)
 
     model_map = {
         'deeplabv3': (DeepLabHead, DeepLabV3),
         'fcn': (FCNHead, FCN),
     }
-    inplanes = 2048
-    classifier = model_map[name][0](inplanes, num_classes)
+    classifier = model_map[name][0](out_inplanes, num_classes)
     base_model = model_map[name][1]
 
     model = base_model(backbone, classifier, aux_classifier)
@@ -46,15 +69,36 @@
 def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
     if pretrained:
         aux_loss = True
-    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
+    model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
     if pretrained:
-        arch = arch_type + '_' + backbone + '_coco'
-        model_url = model_urls[arch]
-        if model_url is None:
-            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
-        else:
-            state_dict = load_state_dict_from_url(model_url, progress=progress)
-            model.load_state_dict(state_dict)
+        _load_weights(model, arch_type, backbone, progress)
+    return model
+
+
+def _load_weights(model, arch_type, backbone, progress):
+    arch = arch_type + '_' + backbone + '_coco'
+    model_url = model_urls.get(arch, None)
+    if model_url is None:
+        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
+    else:
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+        model.load_state_dict(state_dict)
+
+
+def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
+    backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
+
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
+    high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
+    low_channels = backbone[low_pos].out_channels
+    high_channels = backbone[high_pos].out_channels
+
+    backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'})
+
+    model = LRASPP(backbone, low_channels, high_channels, num_classes)
     return model
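`_segm_lraspp_mobilenetv3` addresses the two feature maps purely by their positional names inside `features`. A sketch of the resulting backbone contract (indices 4 and 16 are what the index gathering yields for MobileNetV3-Large; treat them as illustrative):

```
# The IntermediateLayerGetter wrapper returns exactly the {"low", "high"}
# dict that LRASPP.forward expects.
import torch
from torchvision.models import mobilenet_v3_large
from torchvision.models._utils import IntermediateLayerGetter

features = mobilenet_v3_large(_dilated=True).features
getter = IntermediateLayerGetter(features, return_layers={'4': 'low', '16': 'high'})

out = getter(torch.rand(1, 3, 224, 224))
print(out['low'].shape, out['high'].shape)
# torch.Size([1, 40, 28, 28]) torch.Size([1, 960, 14, 14])
```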
@@ -66,6 +110,8 @@ def fcn_resnet50(pretrained=False, progress=True,
         pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
             contains the same classes as Pascal VOC
         progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        aux_loss (bool): If True, it uses an auxiliary loss
     """
     return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
@@ -78,6 +124,8 @@ def fcn_resnet101(pretrained=False, progress=True,
         pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
             contains the same classes as Pascal VOC
         progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        aux_loss (bool): If True, it uses an auxiliary loss
     """
     return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
@@ -90,6 +138,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
         pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
             contains the same classes as Pascal VOC
         progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        aux_loss (bool): If True, it uses an auxiliary loss
     """
     return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
@@ -102,5 +152,42 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
         pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
             contains the same classes as Pascal VOC
         progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        aux_loss (bool): If True, it uses an auxiliary loss
     """
     return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
+
+
+def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
+                                 num_classes=21, aux_loss=None, **kwargs):
+    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+            contains the same classes as Pascal VOC
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        aux_loss (bool): If True, it uses an auxiliary loss
+    """
+    return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
+
+
+def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
+    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+            contains the same classes as Pascal VOC
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+    """
+    if kwargs.pop("aux_loss", False):
+        raise NotImplementedError('This model does not use auxiliary loss')
+
+    backbone_name = 'mobilenet_v3_large'
+    model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
+
+    if pretrained:
+        _load_weights(model, 'lraspp', backbone_name, progress)
+
+    return model
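End to end, the two new builders follow the same contract as the existing segmentation models: a dict whose "out" entry holds per-class logits at input resolution. A usage sketch (the normalization constants are the standard ImageNet values the pre-trained weights assume):

```
import torch
from torchvision.models.segmentation import lraspp_mobilenet_v3_large

model = lraspp_mobilenet_v3_large(pretrained=False, num_classes=21)
model.eval()

# Inputs in [0, 1], normalized with the ImageNet statistics.
x = torch.rand(1, 3, 520, 520)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
x = (x - mean) / std

with torch.no_grad():
    out = model(x)["out"]  # logits at input resolution
print(out.shape)  # torch.Size([1, 21, 520, 520])
```

`deeplabv3_mobilenet_v3_large` can be swapped in unchanged; it additionally accepts `aux_loss=True` for the auxiliary FCN head.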