diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 551cd784103..548361aeebc 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo @@ -74,6 +75,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): self.stride = stride def forward(self, x): + identity = x out = self.conv1(x) @@ -98,8 +100,9 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, tranform_input=False): super(ResNet, self).__init__() + self.tranform_input = tranform_input self.inplanes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) @@ -147,6 +150,14 @@ def _make_layer(self, block, planes, blocks, stride=1): return nn.Sequential(*layers) def forward(self, x): + + # imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.conv1(x) x = self.bn1(x) x = self.relu(x)