
Added typing annotations to models/video #4229


Merged · 15 commits · Aug 23, 2021
120 changes: 78 additions & 42 deletions torchvision/models/video/resnet.py
@@ -1,4 +1,6 @@
+from torch import Tensor
 import torch.nn as nn
+from typing import Tuple, Optional, Callable, List, Type, Any, Union
 
 from ..._internally_replaced_utils import load_state_dict_from_url
 
@@ -13,12 +15,14 @@


 class Conv3DSimple(nn.Conv3d):
-    def __init__(self,
-                 in_planes,
-                 out_planes,
-                 midplanes=None,
-                 stride=1,
-                 padding=1):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        midplanes: Optional[int] = None,
+        stride: int = 1,
+        padding: int = 1
+    ) -> None:
 
         super(Conv3DSimple, self).__init__(
             in_channels=in_planes,
@@ -29,18 +33,20 @@ def __init__(self,
             bias=False)

     @staticmethod
-    def get_downsample_stride(stride):
+    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
         return stride, stride, stride


 class Conv2Plus1D(nn.Sequential):
 
-    def __init__(self,
-                 in_planes,
-                 out_planes,
-                 midplanes,
-                 stride=1,
-                 padding=1):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        midplanes: int,
+        stride: int = 1,
+        padding: int = 1
+    ) -> None:
         super(Conv2Plus1D, self).__init__(
             nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                       stride=(1, stride, stride), padding=(0, padding, padding),
@@ -52,18 +58,20 @@ def __init__(self,
                       bias=False))

     @staticmethod
-    def get_downsample_stride(stride):
+    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
         return stride, stride, stride


 class Conv3DNoTemporal(nn.Conv3d):
 
-    def __init__(self,
-                 in_planes,
-                 out_planes,
-                 midplanes=None,
-                 stride=1,
-                 padding=1):
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        midplanes: Optional[int] = None,
+        stride: int = 1,
+        padding: int = 1
+    ) -> None:
 
         super(Conv3DNoTemporal, self).__init__(
             in_channels=in_planes,
@@ -74,15 +82,22 @@ def __init__(self,
             bias=False)

     @staticmethod
-    def get_downsample_stride(stride):
+    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
         return 1, stride, stride


 class BasicBlock(nn.Module):
 
     expansion = 1
 
-    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        conv_builder: Callable[..., nn.Module],
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+    ) -> None:
         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
 
         super(BasicBlock, self).__init__()
@@ -99,7 +114,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
         self.downsample = downsample
         self.stride = stride
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         residual = x
 
         out = self.conv1(x)
@@ -116,7 +131,7 @@ def forward(self, x):
 class Bottleneck(nn.Module):
     expansion = 4

-    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        conv_builder: Callable[..., nn.Module],
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+    ) -> None:
 
         super(Bottleneck, self).__init__()
         midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
@@ -143,7 +165,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         residual = x
 
         out = self.conv1(x)
@@ -162,7 +184,7 @@ def forward(self, x):
 class BasicStem(nn.Sequential):
     """The default conv-batchnorm-relu stem
     """
-    def __init__(self):
+    def __init__(self) -> None:
         super(BasicStem, self).__init__(
             nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                       padding=(1, 3, 3), bias=False),
@@ -173,7 +195,7 @@ def __init__(self):
 class R2Plus1dStem(nn.Sequential):
     """R(2+1)D stem is different than the default one as it uses separated 3D convolution
     """
-    def __init__(self):
+    def __init__(self) -> None:
         super(R2Plus1dStem, self).__init__(
             nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                       stride=(1, 2, 2), padding=(0, 3, 3),
@@ -189,16 +211,23 @@ def __init__(self):
 
 class VideoResNet(nn.Module):
 
-    def __init__(self, block, conv_makers, layers,
-                 stem, num_classes=400,
-                 zero_init_residual=False):
+    def __init__(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
Review thread on the block annotation:

Contributor: I believe block should be an nn.Module. The intention is to allow users to pass their own custom blocks.

Contributor (author): Alright, but apart from subclassing nn.Module, a custom block also needs an expansion attribute.

Contributor: You are right, leave as is. :)

A second reviewer left a suggested change:
-        block: Type[Union[BasicBlock, Bottleneck]],
+        block: Callable[..., Union[BasicBlock, Bottleneck]],

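To make the reviewers' point concrete, here is a minimal hypothetical custom block (not part of the PR): _make_layer evaluates planes * block.expansion on the class itself, before any instance exists, so a user-supplied block must subclass nn.Module and also expose an expansion class attribute.

# Hypothetical sketch, assuming the same conv_builder signature as
# Conv3DSimple: (in_planes, out_planes, midplanes, stride).
class MyBasicBlock(nn.Module):
    expansion = 1  # read by VideoResNet._make_layer before instantiation

    def __init__(self, inplanes: int, planes: int,
                 conv_builder: Callable[..., nn.Module],
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None) -> None:
        super().__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
        self.conv = nn.Sequential(conv_builder(inplanes, planes, midplanes, stride),
                                  nn.BatchNorm3d(planes), nn.ReLU(inplace=True))
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        # Fall back to the raw input when no downsample projection is given.
        identity = x if self.downsample is None else self.downsample(x)
        return self.conv(x) + identity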
+        conv_makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
+        layers: List[int],
+        stem: Callable[..., nn.Module],
+        num_classes: int = 400,
+        zero_init_residual: bool = False,
+    ) -> None:
         """Generic resnet video generator.
 
         Args:
-            block (nn.Module): resnet building block
-            conv_makers (list(functions)): generator function for each layer
+            block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
+            conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
+                function for each layer
             layers (List[int]): number of blocks per layer
-            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
+            stem (Callable[..., nn.Module]): module specifying the ResNet stem.
             num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
             zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
         """
@@ -221,9 +250,9 @@ def __init__(self, block, conv_makers, layers,
         if zero_init_residual:
             for m in self.modules():
                 if isinstance(m, Bottleneck):
-                    nn.init.constant_(m.bn3.weight, 0)
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[union-attr, arg-type]
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.stem(x)
 
         x = self.layer1(x)
@@ -238,7 +267,14 @@ def forward(self, x):
 
         return x

-    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
+    def _make_layer(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
+        planes: int,
+        blocks: int,
+        stride: int = 1
+    ) -> nn.Sequential:
         downsample = None
 
         if stride != 1 or self.inplanes != planes * block.expansion:
@@ -257,7 +293,7 @@ def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
 
         return nn.Sequential(*layers)

-    def _initialize_weights(self):
+    def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv3d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out',
@@ -272,7 +308,7 @@ def _initialize_weights(self):
                 nn.init.constant_(m.bias, 0)


-def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
+def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     model = VideoResNet(**kwargs)

     if pretrained:
@@ -282,7 +318,7 @@ def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
     return model


-def r3d_18(pretrained=False, progress=True, **kwargs):
+def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     """Construct 18 layer Resnet3D model as in
     https://arxiv.org/abs/1711.11248

@@ -302,7 +338,7 @@ def r3d_18(pretrained=False, progress=True, **kwargs):
                          stem=BasicStem, **kwargs)


-def mc3_18(pretrained=False, progress=True, **kwargs):
+def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     """Constructor for 18 layer Mixed Convolution network as in
     https://arxiv.org/abs/1711.11248

@@ -316,12 +352,12 @@ def mc3_18(pretrained=False, progress=True, **kwargs):
     return _video_resnet('mc3_18',
                          pretrained, progress,
                          block=BasicBlock,
-                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
+                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
                          layers=[2, 2, 2, 2],
                          stem=BasicStem, **kwargs)
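Why the suppression above is needed (a hedged reading, not stated in the PR): mypy infers [Conv3DSimple] as List[Type[Conv3DSimple]], and List is invariant, so the concatenation with [Conv3DNoTemporal] * 3 does not unify with the declared List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]. One hypothetical alternative is an explicitly annotated variable, which lets mypy check each element against the union directly:

# Hypothetical alternative that avoids the ignore: annotate the list up front.
makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]] = [
    Conv3DSimple, Conv3DNoTemporal, Conv3DNoTemporal, Conv3DNoTemporal,
]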


-def r2plus1d_18(pretrained=False, progress=True, **kwargs):
+def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     """Constructor for the 18 layer deep R(2+1)D network as in
     https://arxiv.org/abs/1711.11248
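For context, a small usage sketch (not part of the diff): the annotations change no runtime behavior, so the factories are called exactly as before, and static checkers now see the return type as VideoResNet.

import torch
from torchvision.models.video import r3d_18

model = r3d_18(pretrained=False)        # typed as -> VideoResNet
clip = torch.randn(1, 3, 16, 112, 112)  # (batch, channels, frames, height, width)
logits = model(clip)                    # shape (1, 400) with the default head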
