diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index a9e59a149c0..caa32505ab4 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -79,6 +79,55 @@ def get_downsample_stride(stride): return (1, stride, stride) +class Conv3DDepthwise(nn.Conv3d): + """Depthwise version of the 3D conv, + used for implementing channel-separated networks. + """ + def __init__(self, + in_planes, + out_planes, + stride=1, + padding=1): + + assert in_planes == out_planes + super(Conv3DDepthwise, self).__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=padding, + groups=in_planes, + bias=False) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) + + +class IPConv3DDepthwise(nn.Sequential): + """Depthwise version of the 3D conv, + used for implementing interaction-preserving + channel-separated networks. + """ + def __init__(self, + in_planes, + out_planes, + stride=1, + padding=1): + + assert in_planes == out_planes + super(IPConv3DDepthwise, self).__init__( + nn.Conv3d(in_planes, out_planes, kernel_size=1, bias=False), + nn.BatchNorm3d(out_planes), + nn.ReLU(inplace=True), + Conv3DDepthwise(out_planes, out_planes, None, stride) + ) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) + + class BasicBlock(nn.Module): expansion = 1 @@ -339,3 +388,45 @@ def r2plus1d_18(pretrained=False, progress=True, **kwargs): conv_makers=[Conv2Plus1D] * 4, layers=[2, 2, 2, 2], stem=R2Plus1dStem, **kwargs) + + +def ir_csn_152(pretrained=False, progress=False, **kwargs): + """Constructor for the 152 layer deep ir-CSN network as described + in https://arxiv.org/abs/1904.02811. + Note that video model zoo (https://github.com/facebookresearch/VMZ) provides + models pretrained on large scale benchmarks such as Sports1M and URU. + + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + + Returns: + nn.Module: ir-CSN-152 network + """ + return _video_resnet('ir_csn_152', + False, False, + block=Bottleneck, + conv_makers=[Conv3DDepthwise] * 4, + layers=[3, 8, 36, 3], + stem=BasicStem, **kwargs) + + +def ip_csn_152(pretrained=False, progress=False, **kwargs): + """Constructor for the 152 layer deep ip-CSN network as described + in https://arxiv.org/abs/1904.02811. + Note that video model zoo (https://github.com/facebookresearch/VMZ) provides + models pretrained on large scale benchmarks such as Sports1M and URU. + + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + + Returns: + nn.Module: ip-CSN-152 network + """ + return _video_resnet('ip_csn_152', + False, False, + block=Bottleneck, + conv_makers=[IPConv3DDepthwise] * 4, + layers=[3, 8, 36, 3], + stem=BasicStem, **kwargs)