Skip to content

Added typing annotations to models/segmentation #4227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions torchvision/models/segmentation/_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
from collections import OrderedDict
from typing import Optional, Dict

from torch import nn
from torch import nn, Tensor
from torch.nn import functional as F


class _SimpleSegmentationModel(nn.Module):
__constants__ = ['aux_classifier']

def __init__(self, backbone, classifier, aux_classifier=None):
def __init__(
self,
backbone: nn.Module,
classifier: nn.Module,
aux_classifier: Optional[nn.Module] = None
) -> None:
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone
self.classifier = classifier
self.aux_classifier = aux_classifier

def forward(self, x):
def forward(self, x: Tensor) -> Dict[str, Tensor]:
input_shape = x.shape[-2:]
# contract: features is a dict of tensors
features = self.backbone(x)
Expand Down
19 changes: 10 additions & 9 deletions torchvision/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch import nn
from torch.nn import functional as F
from typing import List

from ._utils import _SimpleSegmentationModel

Expand All @@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel):


class DeepLabHead(nn.Sequential):
def __init__(self, in_channels, num_classes):
def __init__(self, in_channels: int, num_classes: int) -> None:
super(DeepLabHead, self).__init__(
ASPP(in_channels, [12, 24, 36]),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
Expand All @@ -38,7 +39,7 @@ def __init__(self, in_channels, num_classes):


class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
Expand All @@ -48,22 +49,22 @@ def __init__(self, in_channels, out_channels, dilation):


class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int) -> None:
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU())

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels=256):
def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
super(ASPP, self).__init__()
modules = []
modules.append(nn.Sequential(
Expand All @@ -85,9 +86,9 @@ def __init__(self, in_channels, atrous_rates, out_channels=256):
nn.ReLU(),
nn.Dropout(0.5))

def forward(self, x):
res = []
def forward(self, x: torch.Tensor) -> torch.Tensor:
_res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
_res.append(conv(x))
res = torch.cat(_res, dim=1)
return self.project(res)
2 changes: 1 addition & 1 deletion torchvision/models/segmentation/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel):


class FCNHead(nn.Sequential):
def __init__(self, in_channels, channels):
def __init__(self, in_channels: int, channels: int) -> None:
inter_channels = in_channels // 4
layers = [
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
Expand Down
19 changes: 16 additions & 3 deletions torchvision/models/segmentation/lraspp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,19 @@ class LRASPP(nn.Module):
inter_channels (int, optional): the number of channels for intermediate computations.
"""

def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
def __init__(
self,
backbone: nn.Module,
low_channels: int,
high_channels: int,
num_classes: int,
inter_channels: int = 128
) -> None:
super().__init__()
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

def forward(self, input):
def forward(self, input: Tensor) -> Dict[str, Tensor]:
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
Expand All @@ -42,7 +49,13 @@ def forward(self, input):

class LRASPPHead(nn.Module):

def __init__(self, low_channels, high_channels, num_classes, inter_channels):
def __init__(
self,
low_channels: int,
high_channels: int,
num_classes: int,
inter_channels: int
) -> None:
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
Expand Down
76 changes: 61 additions & 15 deletions torchvision/models/segmentation/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from torch import nn
from typing import Any, Optional
from .._utils import IntermediateLayerGetter
from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3
Expand All @@ -22,7 +24,13 @@
}


def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
def _segm_model(
name: str,
backbone_name: str,
num_classes: int,
aux: Optional[bool],
pretrained_backbone: bool = True
) -> nn.Module:
if 'resnet' in backbone_name:
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone,
Expand Down Expand Up @@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
return model


def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
def _load_model(
arch_type: str,
backbone: str,
pretrained: bool,
progress: bool,
num_classes: int,
aux_loss: Optional[bool],
**kwargs: Any
) -> nn.Module:
if pretrained:
aux_loss = True
kwargs["pretrained_backbone"] = False
Expand All @@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
return model


def _load_weights(model, arch_type, backbone, progress):
def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls.get(arch, None)
if model_url is None:
Expand All @@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress):
model.load_state_dict(state_dict)


def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Expand All @@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru
return model


def fcn_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def fcn_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

Args:
Expand All @@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True,
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)


def fcn_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def fcn_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

Args:
Expand All @@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True,
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)


def deeplabv3_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def deeplabv3_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.

Args:
Expand All @@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)


def deeplabv3_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def deeplabv3_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.

Args:
Expand All @@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)


def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def deeplabv3_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

Args:
Expand All @@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)


def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
def lraspp_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
**kwargs: Any
) -> nn.Module:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

Args:
Expand Down