From 48989fc74d5dfcd598029d317f5373430839250b Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Tue, 8 Oct 2019 08:09:47 -0700 Subject: [PATCH 1/4] [model update] initial commit --- torchvision/models/video/resnet.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index a9e59a149c0..1c837758dfe 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,3 +1,5 @@ +import warnings + import torch import torch.nn as nn @@ -17,7 +19,6 @@ class Conv3DSimple(nn.Conv3d): def __init__(self, in_planes, out_planes, - midplanes=None, stride=1, padding=1): @@ -39,9 +40,11 @@ class Conv2Plus1D(nn.Sequential): def __init__(self, in_planes, out_planes, - midplanes, stride=1, padding=1): + + midplanes = (in_planes * out_planes * 3 * 3 * 3) // ( + in_planes * 3 * 3 + 3 * out_planes) super(Conv2Plus1D, self).__init__( nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, padding, padding), @@ -62,7 +65,6 @@ class Conv3DNoTemporal(nn.Conv3d): def __init__(self, in_planes, out_planes, - midplanes=None, stride=1, padding=1): @@ -84,16 +86,15 @@ class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): - midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) super(BasicBlock, self).__init__() self.conv1 = nn.Sequential( - conv_builder(inplanes, planes, midplanes, stride), + conv_builder(inplanes, planes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes), + conv_builder(planes, planes), nn.BatchNorm3d(planes) ) self.relu = nn.ReLU(inplace=True) @@ -120,7 +121,6 @@ class Bottleneck(nn.Module): def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): super(Bottleneck, self).__init__() - midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) # 1x1x1 self.conv1 = nn.Sequential( @@ -130,7 +130,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): ) # Second kernel self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes, stride), + conv_builder(planes, planes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) @@ -190,6 +190,10 @@ def __init__(self): class VideoResNet(nn.Module): + # Version 2 adds updated BN params, and + # solves midplane computation + _version = 2 + def __init__(self, block, conv_makers, layers, stem, num_classes=400, zero_init_residual=False): @@ -268,6 +272,9 @@ def _initialize_weights(self): elif isinstance(m, nn.BatchNorm3d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) + # following are v2 updates for maximum reproducibility + m.eps = 1e-3 + m.momentum = 0.9 elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) @@ -333,6 +340,12 @@ def r2plus1d_18(pretrained=False, progress=True, **kwargs): Returns: nn.Module: R(2+1)D-18 network """ + warnings.warn( + "This is an updated vesrion of the R(2+1D) model that was " + "updated following discussion in #1265. The performance " + "deviations are minimal, but this might cause some BW compatibility " + "issues, depending on the models.", + UserWarning) return _video_resnet('r2plus1d_18', pretrained, progress, block=BasicBlock, From e897ffae223c7bc0b2fb2bde9928a286e398ae66 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Thu, 24 Oct 2019 12:04:11 -0700 Subject: [PATCH 2/4] new load_state_dict --- torchvision/models/video/resnet.py | 43 +++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 1c837758dfe..1733f9da3d2 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -43,15 +43,15 @@ def __init__(self, stride=1, padding=1): - midplanes = (in_planes * out_planes * 3 * 3 * 3) // ( + self.midplanes = (in_planes * out_planes * 3 * 3 * 3) // ( in_planes * 3 * 3 + 3 * out_planes) super(Conv2Plus1D, self).__init__( - nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), + nn.Conv3d(in_planes, self.midplanes, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, padding, padding), bias=False), nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), - nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), + nn.Conv3d(self.midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False)) @@ -278,6 +278,37 @@ def _initialize_weights(self): elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get("version", None) + assert version in [1, 2] + + # the new changes only apply to the R2+1D models + if version == 1 and isinstance(self.layer2[0].conv2[0], Conv2Plus1D): + # V1 of the models had midplanes hard coded into the blocks + # and default BN parameters as in Pytorch. + # All other layer configurations were the same. + self.layer2[0].conv2[0] = Conv2Plus1D(128, 128, 230) + self.layer3[0].conv2[0] = Conv2Plus1D(256, 256, 460) + self.layer4[0].conv2[0] = Conv2Plus1D(512, 512, 921) + for m in self.modules(): + if isinstance(m, nn.BatchNorm3d): + m.eps = 1e-5 + m.momentum = 0.1 + + # The model is now identical to v1, and must be saved as such. + self._version = 1 + warnings.warn( + "This is an updated vesrion of the R(2+1D) model that was " + "updated following discussion in #1265. The performance " + "deviations are minimal, but this might cause some BW compatibility " + "issues, depending on the models.", + UserWarning) + + super(VideoResNet, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs) def _video_resnet(arch, pretrained=False, progress=True, **kwargs): @@ -340,12 +371,6 @@ def r2plus1d_18(pretrained=False, progress=True, **kwargs): Returns: nn.Module: R(2+1)D-18 network """ - warnings.warn( - "This is an updated vesrion of the R(2+1D) model that was " - "updated following discussion in #1265. The performance " - "deviations are minimal, but this might cause some BW compatibility " - "issues, depending on the models.", - UserWarning) return _video_resnet('r2plus1d_18', pretrained, progress, block=BasicBlock, From df8316802e6f9b8c8a7afbadfaeb0f241b847892 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Thu, 24 Oct 2019 12:31:21 -0700 Subject: [PATCH 3/4] Testing everything out --- torchvision/models/video/resnet.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 1733f9da3d2..319ea954f81 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -41,17 +41,19 @@ def __init__(self, in_planes, out_planes, stride=1, - padding=1): + padding=1, + midplanes=None): - self.midplanes = (in_planes * out_planes * 3 * 3 * 3) // ( + if midplanes is None: + midplanes = (in_planes * out_planes * 3 * 3 * 3) // ( in_planes * 3 * 3 + 3 * out_planes) super(Conv2Plus1D, self).__init__( - nn.Conv3d(in_planes, self.midplanes, kernel_size=(1, 3, 3), + nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, padding, padding), bias=False), nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), - nn.Conv3d(self.midplanes, out_planes, kernel_size=(3, 1, 1), + nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False)) @@ -283,15 +285,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get("version", None) assert version in [1, 2] - # the new changes only apply to the R2+1D models if version == 1 and isinstance(self.layer2[0].conv2[0], Conv2Plus1D): # V1 of the models had midplanes hard coded into the blocks # and default BN parameters as in Pytorch. # All other layer configurations were the same. - self.layer2[0].conv2[0] = Conv2Plus1D(128, 128, 230) - self.layer3[0].conv2[0] = Conv2Plus1D(256, 256, 460) - self.layer4[0].conv2[0] = Conv2Plus1D(512, 512, 921) + self.layer2[0].conv2[0] = Conv2Plus1D(128, 128, midplanes=230) + self.layer3[0].conv2[0] = Conv2Plus1D(256, 256, midplanes=460) + self.layer4[0].conv2[0] = Conv2Plus1D(512, 512, midplanes=921) + for m in self.modules(): if isinstance(m, nn.BatchNorm3d): m.eps = 1e-5 @@ -303,8 +305,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, "This is an updated vesrion of the R(2+1D) model that was " "updated following discussion in #1265. The performance " "deviations are minimal, but this might cause some BW compatibility " - "issues, depending on the models.", - UserWarning) + "issues, depending on the models.", UserWarning) super(VideoResNet, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, From 71525c6dcb0a7554ed52e0c17026bced5c56fc48 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Tue, 31 Dec 2019 16:17:28 +0000 Subject: [PATCH 4/4] typo fix --- torchvision/models/video/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 319ea954f81..40875a8c86f 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -302,7 +302,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # The model is now identical to v1, and must be saved as such. self._version = 1 warnings.warn( - "This is an updated vesrion of the R(2+1D) model that was " + "This is an updated version of the R(2+1D) model that was " "updated following discussion in #1265. The performance " "deviations are minimal, but this might cause some BW compatibility " "issues, depending on the models.", UserWarning)