diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index fc69188ef7a..faf3b3bc4a8 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,4 +1,6 @@ +from torch import Tensor import torch.nn as nn +from typing import Tuple, Optional, Callable, List, Type, Any, Union from ..._internally_replaced_utils import load_state_dict_from_url @@ -13,12 +15,14 @@ class Conv3DSimple(nn.Conv3d): - def __init__(self, - in_planes, - out_planes, - midplanes=None, - stride=1, - padding=1): + def __init__( + self, + in_planes: int, + out_planes: int, + midplanes: Optional[int] = None, + stride: int = 1, + padding: int = 1 + ) -> None: super(Conv3DSimple, self).__init__( in_channels=in_planes, @@ -29,18 +33,20 @@ def __init__(self, bias=False) @staticmethod - def get_downsample_stride(stride): + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: return stride, stride, stride class Conv2Plus1D(nn.Sequential): - def __init__(self, - in_planes, - out_planes, - midplanes, - stride=1, - padding=1): + def __init__( + self, + in_planes: int, + out_planes: int, + midplanes: int, + stride: int = 1, + padding: int = 1 + ) -> None: super(Conv2Plus1D, self).__init__( nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, padding, padding), @@ -52,18 +58,20 @@ def __init__(self, bias=False)) @staticmethod - def get_downsample_stride(stride): + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: return stride, stride, stride class Conv3DNoTemporal(nn.Conv3d): - def __init__(self, - in_planes, - out_planes, - midplanes=None, - stride=1, - padding=1): + def __init__( + self, + in_planes: int, + out_planes: int, + midplanes: Optional[int] = None, + stride: int = 1, + padding: int = 1 + ) -> None: super(Conv3DNoTemporal, self).__init__( in_channels=in_planes, @@ -74,7 +82,7 @@ def __init__(self, bias=False) @staticmethod - def get_downsample_stride(stride): + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: return 1, stride, stride @@ -82,7 +90,14 @@ class BasicBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + def __init__( + self, + inplanes: int, + planes: int, + conv_builder: Callable[..., nn.Module], + stride: int = 1, + downsample: Optional[nn.Module] = None, + ) -> None: midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) super(BasicBlock, self).__init__() @@ -99,7 +114,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: residual = x out = self.conv1(x) @@ -116,7 +131,14 @@ def forward(self, x): class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + def __init__( + self, + inplanes: int, + planes: int, + conv_builder: Callable[..., nn.Module], + stride: int = 1, + downsample: Optional[nn.Module] = None, + ) -> None: super(Bottleneck, self).__init__() midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) @@ -143,7 +165,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: residual = x out = self.conv1(x) @@ -162,7 +184,7 @@ def forward(self, x): class BasicStem(nn.Sequential): """The default conv-batchnorm-relu stem """ - def __init__(self): + def __init__(self) -> None: super(BasicStem, self).__init__( nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False), @@ -173,7 +195,7 @@ def __init__(self): class R2Plus1dStem(nn.Sequential): """R(2+1)D stem is different than the default one as it uses separated 3D convolution """ - def __init__(self): + def __init__(self) -> None: super(R2Plus1dStem, self).__init__( nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), @@ -189,16 +211,23 @@ def __init__(self): class VideoResNet(nn.Module): - def __init__(self, block, conv_makers, layers, - stem, num_classes=400, - zero_init_residual=False): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + num_classes: int = 400, + zero_init_residual: bool = False, + ) -> None: """Generic resnet video generator. Args: - block (nn.Module): resnet building block - conv_makers (list(functions)): generator function for each layer + block (Type[Union[BasicBlock, Bottleneck]]): resnet building block + conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator + function for each layer layers (List[int]): number of blocks per layer - stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. + stem (Callable[..., nn.Module]): module specifying the ResNet stem. num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. """ @@ -221,9 +250,9 @@ def __init__(self, block, conv_makers, layers, if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) + nn.init.constant_(m.bn3.weight, 0) # type: ignore[union-attr, arg-type] - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.stem(x) x = self.layer1(x) @@ -238,7 +267,14 @@ def forward(self, x): return x - def _make_layer(self, block, conv_builder, planes, blocks, stride=1): + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]], + planes: int, + blocks: int, + stride: int = 1 + ) -> nn.Sequential: downsample = None if stride != 1 or self.inplanes != planes * block.expansion: @@ -257,7 +293,7 @@ def _make_layer(self, block, conv_builder, planes, blocks, stride=1): return nn.Sequential(*layers) - def _initialize_weights(self): + def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv3d): nn.init.kaiming_normal_(m.weight, mode='fan_out', @@ -272,7 +308,7 @@ def _initialize_weights(self): nn.init.constant_(m.bias, 0) -def _video_resnet(arch, pretrained=False, progress=True, **kwargs): +def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: model = VideoResNet(**kwargs) if pretrained: @@ -282,7 +318,7 @@ def _video_resnet(arch, pretrained=False, progress=True, **kwargs): return model -def r3d_18(pretrained=False, progress=True, **kwargs): +def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer Resnet3D model as in https://arxiv.org/abs/1711.11248 @@ -302,7 +338,7 @@ def r3d_18(pretrained=False, progress=True, **kwargs): stem=BasicStem, **kwargs) -def mc3_18(pretrained=False, progress=True, **kwargs): +def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for 18 layer Mixed Convolution network as in https://arxiv.org/abs/1711.11248 @@ -316,12 +352,12 @@ def mc3_18(pretrained=False, progress=True, **kwargs): return _video_resnet('mc3_18', pretrained, progress, block=BasicBlock, - conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, + conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] layers=[2, 2, 2, 2], stem=BasicStem, **kwargs) -def r2plus1d_18(pretrained=False, progress=True, **kwargs): +def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for the 18 layer deep R(2+1)D network as in https://arxiv.org/abs/1711.11248