From 7259d915fb5c29f143f9534a28a0a4d65df603f0 Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 01:00:31 +0530 Subject: [PATCH 1/3] Internal Imagenet normalisation for pretrained resnet models Consistent with inceptionV3 and googleNet pytorch implementation. --- torchvision/models/resnet.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 551cd784103..5506bf3bdc8 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -74,8 +74,17 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): self.stride = stride def forward(self, x): + + #imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + identity = x + out = self.conv1(x) out = self.bn1(out) out = self.relu(out) From 2fba34e18d8f355de4f191438aaa735eeeefcace Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 01:04:28 +0530 Subject: [PATCH 2/3] Update resnet.py --- torchvision/models/resnet.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 5506bf3bdc8..4f3ff33993d 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -75,16 +75,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): def forward(self, x): - #imagenet normalisation - if self.transform_input: - x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 - x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 - x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 - x = torch.cat((x_ch0, x_ch1, x_ch2), 1) - identity = x - out = self.conv1(x) out = self.bn1(out) out = self.relu(out) @@ -107,8 +99,9 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, tranform_input=False): super(ResNet, self).__init__() + self.tranform_input = tranform_input self.inplanes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) @@ -156,6 +149,14 @@ def _make_layer(self, block, planes, blocks, stride=1): return nn.Sequential(*layers) def forward(self, x): + + #imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.conv1(x) x = self.bn1(x) x = self.relu(x) From 23e0618809bc3ac43fa0ffed6dd80e8fb99ad8f6 Mon Sep 17 00:00:00 2001 From: ekka Date: Sat, 9 Mar 2019 01:31:15 +0530 Subject: [PATCH 3/3] fixed F821 --- torchvision/models/resnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 4f3ff33993d..548361aeebc 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo @@ -150,7 +151,7 @@ def _make_layer(self, block, planes, blocks, stride=1): def forward(self, x): - #imagenet normalisation + # imagenet normalisation if self.transform_input: x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224