diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index a3e51e3b953..89316a59c3e 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -38,8 +38,9 @@ def forward(self, x): class SqueezeNet(nn.Module): - def __init__(self, version=1.0, num_classes=1000): + def __init__(self, version=1.0, num_classes=1000, transform_input=False): super(SqueezeNet, self).__init__() + self.transform_input = transform_input if version not in [1.0, 1.1]: raise ValueError("Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)) @@ -95,6 +96,14 @@ def __init__(self, version=1.0, num_classes=1000): init.constant_(m.bias, 0) def forward(self, x): + + # imagenet normalisation + if self.transform_input: + x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229 + x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224 + x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + x = self.features(x) x = self.classifier(x) return x.view(x.size(0), self.num_classes)