diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 797f459f5cb..f3f86c25cd2 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,6 +1,8 @@ import torch +from torch import Tensor import torch.nn as nn from .utils import load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', @@ -21,22 +23,31 @@ } -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) -def conv1x1(in_planes, out_planes, stride=1): +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -53,7 +64,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -79,10 +90,19 @@ class Bottleneck(nn.Module): # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -98,7 +118,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -123,9 +143,17 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -170,11 +198,12 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -198,7 +227,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) - def _forward_impl(self, x): + def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) @@ -216,11 +245,18 @@ def _forward_impl(self, x): return x - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _resnet(arch, block, layers, pretrained, progress, **kwargs): +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[arch], @@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): return model -def resnet18(pretrained=False, progress=True, **kwargs): +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ @@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet34(pretrained=False, progress=True, **kwargs): +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ @@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet50(pretrained=False, progress=True, **kwargs): +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ @@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet101(pretrained=False, progress=True, **kwargs): +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ @@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet152(pretrained=False, progress=True, **kwargs): +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_ @@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs): **kwargs) -def resnext50_32x4d(pretrained=False, progress=True, **kwargs): +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs): pretrained, progress, **kwargs) -def resnext101_32x8d(pretrained=False, progress=True, **kwargs): +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): pretrained, progress, **kwargs) -def wide_resnet50_2(pretrained=False, progress=True, **kwargs): +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_ @@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs): pretrained, progress, **kwargs) -def wide_resnet101_2(pretrained=False, progress=True, **kwargs): +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_