
Internal Imagenet normalisation for pretrained Densenet models #782


Closed · wants to merge 3 commits

Conversation

ekagra-ranjan
Contributor

Added the argument transform_input when calling DenseNet models for automatic normalisation of inputs with the ImageNet mean and std, for easier use. This is consistent with the InceptionV3 PyTorch implementation.

…3 implementation

@pmeier
Collaborator

pmeier commented Mar 9, 2019

Since normalisation corresponds more to the data than to the model itself, this functionality is already included within the predefined torchvision.datasets. For example, check out the official PyTorch beginner tutorial:

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

It uses transforms.Normalize to take care of the normalisation for every instance of the dataset. Of course, you should change the mean and std to the correct values:

transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

Thus, IMO your change is not needed. This also applies to the other PRs (#783 #784 #785 #786) implementing the same functionality for other predefined models.

@surgan12
Contributor

surgan12 commented Mar 9, 2019

@pmeier I agree with you, as normalization has to do with the dataset and shouldn't be incorporated into the model.

@ekagra-ranjan
Contributor Author

ekagra-ranjan commented Mar 9, 2019

@pmeier You are right that this can be done explicitly using the transforms.Normalize() function, but many beginners don't use the ImageNet mean and std to normalise their data when doing transfer learning from ImageNet weights. The most common mistake beginners make is to normalise the data using min-max normalisation, which puts the data between 0 and 1 rather than between -1 and 1.

Plus, it would be easier for PyTorch users not to have to remember that GoogLeNet and InceptionV3 are not to be normalised explicitly while the other ImageNet models must be. It would be better to have consistency in the implementation and use of all the ImageNet models.

@surgan12 The normalisation is not independent of the model if you are using transfer learning. Not using ImageNet normalisation (by default) while using ImageNet pretrained weights will lead to slower convergence, as the model expects the images to come from a distribution with the mean and std of ImageNet, on which it was thoroughly trained.

cc: @fmassa

@pmeier
Collaborator

pmeier commented Mar 9, 2019

@ekagra-ranjan
I think we can expect new users to read the documentation before they use the predefined models. It clearly states:

The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. (emphasis added)

Furthermore, the transfer learning tutorial (which seems to be your area of interest) also uses this way of normalising the data.


Plus, it would be easier for PyTorch users not to have to remember that GoogLeNet and InceptionV3 are not to be normalised explicitly while the other ImageNet models must be.

Do you have a source for this? The documentation states

All pre-trained models expect input images normalized in the same way [...] (emphasis added)

AFAIK all models are trained with this procedure, which uses the above normalisation.


The normalisation is not independent of the model if you are using transfer learning.

Why not? It doesn't make a difference if you first process the data and forward this to the model, or do the processing as the first step within the model. However, the former is cleaner as explained above.

If you are still eager to do this normalisation within the model, you could simply define a module

import torch
from torch import nn


class ImageNetNormalization(nn.Module):
    """Normalises a batch of images with the ImageNet mean and std."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(ImageNetNormalization, self).__init__()
        # store mean/std as non-trainable parameters with broadcastable shape (1, C, 1, 1)
        mean = torch.tensor(mean).view(1, -1, 1, 1)
        self.register_parameter('mean', nn.Parameter(mean, requires_grad=False))
        std = torch.tensor(std).view(1, -1, 1, 1)
        self.register_parameter('std', nn.Parameter(std, requires_grad=False))

    def forward(self, input):
        return (input - self.mean) / self.std

and bundle it with your model in an nn.Sequential

from torchvision.models import densenet121

model = nn.Sequential(
    ImageNetNormalization(),
    densenet121(pretrained=True),
)

@ekagra-ranjan
Contributor Author

ekagra-ranjan commented Mar 9, 2019

@pmeier

Do you have a source for this? The documentation states
All pre-trained models expect input images normalized in the same way [...] (emphasis added)
AFAIK all models are trained with this procedure, which uses the above normalisation.

Yes, I have a source. Please have a look at the InceptionV3 (here) and GoogLeNet (here) implementations.


Why not? It doesn't make a difference if you first process the data and forward this to the model, or do the processing as the first step within the model. However, the former is cleaner as explained above.

By model independence I didn't mean the above. When one uses transfer learning, one expects the model weights to be loaded by the model function on its own, without user interference. Similarly, one could also allow the model to do the preprocessing on its own, as it is an essential part of transfer learning and depends on the type of model used (different for InceptionV3 and GoogLeNet).

@TheCodez
Contributor

TheCodez commented Mar 9, 2019

@ekagra-ranjan GoogLeNet and InceptionV3 still assume the input to be normalized using the ImageNet mean and std. Those lines of code just convert our normalization to the TensorFlow normalization (the weights were copied from TensorFlow).
Hope that clears things up :)
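
A minimal numerical check of that equivalence, just as a sketch (the random input and shapes are illustrative):

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

x = torch.rand(2, 3, 8, 8)          # raw image batch in [0, 1]
x_imagenet = (x - mean) / std       # ImageNet-normalised input, as documented

# what the transform_input branch in Inception/GoogLeNet computes
x_converted = x_imagenet * (std / 0.5) + (mean - 0.5) / 0.5

# equals the "TensorFlow style" (0.5, 0.5) normalisation of the raw image
x_tf = (x - 0.5) / 0.5

print(torch.allclose(x_converted, x_tf, atol=1e-6))  # True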

@ekagra-ranjan
Contributor Author

@TheCodez Thanks for the insight :) The formulation indeed reverses the normalisation with the ImageNet mean and std. InceptionV3 and GoogLeNet normalise with a different mean and std compared with the rest of the ImageNet models. Wouldn't it be clearer and more intuitive if we normalised directly, without first reversing an unwanted normalisation? This can be done by providing the option to have the model normalise the input automatically by passing the argument transform_input=True. Then we can normalise directly with the respective mean and std, which would make it easier to relate the normalisation done in the paper to the implementation.

Included the argument `transform_input` in the docs of densenets
@codecov-io

codecov-io commented Mar 9, 2019

Codecov Report

Merging #782 into master will decrease coverage by 0.07%.
The diff coverage is 0%.


@@            Coverage Diff             @@
##           master     #782      +/-   ##
==========================================
- Coverage   38.13%   38.05%   -0.08%     
==========================================
  Files          32       32              
  Lines        3126     3132       +6     
  Branches      487      488       +1     
==========================================
  Hits         1192     1192              
- Misses       1855     1861       +6     
  Partials       79       79
Impacted Files Coverage Δ
torchvision/models/densenet.py 17.88% <0%> (-0.92%) ⬇️
torchvision/models/googlenet.py 15.87% <0%> (ø) ⬆️
torchvision/models/inception.py 14.41% <0%> (ø) ⬆️

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9d9f48a...9f3140d.

@fmassa
Member

fmassa commented Mar 9, 2019

Hi,

Thanks for the series of PRs adding support for this!

My first thoughts were aligned with @pmeier and the others: I would consider Inception and GoogleNet to be special because they haven't been trained with the same normalization as the others, and thus for consistency it makes sense to perform this conversion step inside of those models.

Thus, I'd rather not merge those changes for now. They embed inside the model implementation some magical constants that have nothing to do with the model itself, but only with the pre-trained model on a particular dataset.


That being said, here is another data point: torchvision datasets and models perform data transformation on the Dataset (inside the dataloader, and thus potentially using multiple threads). This works great in many cases, but for some other tasks (like video decoding / resizing), performing the transformations on the GPU is desired/necessary for performance. But the DataLoader with multiple threads doesn't work that well when the transforms should be on the GPU. In those cases, having all the transforms being part of the model, and running on the GPU, might be the best solution.

A solution like #782 (comment) is probably in the direction of what I'm thinking about doing. In this case, if we make the mean/std buffers and not parameters, they will be serialized together with the model parameters via state_dict, which means that we don't need to carry the ImageNet mean/std with us whenever we want to use this code.

As I still haven't figured out all the details on how this could be done, I'd rather keep such things for later.
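
For reference, a minimal sketch of what the buffer-based variant of the module above could look like (the class name is illustrative, not an existing torchvision API):

import torch
from torch import nn


class ImageNetNormalization(nn.Module):
    """Same as the module above, but mean/std are buffers: they end up in the
    state_dict and move with .to()/.cuda(), yet are never treated as trainable."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(ImageNetNormalization, self).__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, input):
        return (input - self.mean) / self.std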

@ekagra-ranjan
Contributor Author

ekagra-ranjan commented Mar 9, 2019

@fmassa With the current implementation of InceptionV3, if the input is not normalised with the ImageNet mean and std, the input range becomes [-0.188, 0.428] rather than [-1, 1], which might be unintended for some users because they are unaware of the magical constants. If we normalise directly instead of first reverting the ImageNet normalisation and then applying the InceptionV3 normalisation, we will always get the range [-1, 1].
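
That range can be checked directly from the conversion constants (a quick sketch):

import torch

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# un-normalised input in [0, 1] fed through the transform_input branch:
#   x * (std / 0.5) + (mean - 0.5) / 0.5
lo = (mean - 0.5) / 0.5                # per-channel value at x = 0
hi = std / 0.5 + (mean - 0.5) / 0.5    # per-channel value at x = 1

print(lo.min().item(), hi.max().item())  # approx. -0.188 and 0.428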

Allowing the use of transform_input will also allow future models with different normalisation techniques to be incorporated easily. This will prevent new users reading the docs from falsely assuming that all ImageNet models are normalised with the same mean and std, regardless of the different normalisation strategies mentioned in their papers. IMO this will make the code cleaner and directly relatable to the literature. Thoughts?

With this implementation we are also allowing normalisation to be done on the GPU if the model is pushed to the GPU. So wouldn't this also solve the issue of not being able to normalise on the GPU with dataloaders effectively?

@fmassa
Member

fmassa commented Mar 9, 2019

@fmassa With the current implementation of InceptionV3, if the input is not normalised with the ImageNet mean and std, the input range becomes [-0.188, 0.428] rather than [-1, 1], which might be unintended for some users because they are unaware of the magical constants. If we normalise directly instead of first reverting the ImageNet normalisation and then applying the InceptionV3 normalisation, we will always get the range [-1, 1].

Yes, but at least all models in torchvision can use the same normalization. None of the pre-trained models will work if we don't use the ImageNet mean/std.

Allowing the use of transform_input will also allow future models with different normalisation techniques to be incorporated easily. This will prevent new users reading the docs from falsely assuming that all ImageNet models are normalised with the same mean and std, regardless of the different normalisation strategies mentioned in their papers. IMO this will make the code cleaner and directly relatable to the literature. Thoughts?

I don't follow. All models in torchvision, from a user perspective, should be normalized the same way. Isn't that the case?

With this implementation we are also allowing normalisation to be done on the GPU if the model is pushed to the GPU. So wouldn't this also solve the issue of not being able to normalise on the GPU with dataloaders effectively?

Yes, but I'd like to think a bit more about how to perform this in a better way.
For example, all models have the same (or almost the same) transformations (including the normalization). This could be a layer that we could add to the model, but I'm not clear yet if this is the best way to address this.

@ekagra-ranjan
Contributor Author

ekagra-ranjan commented Mar 10, 2019

@fmassa
InceptionV3 and GoogLeNet are normalised with mean = (0.5, 0.5, 0.5) and std = (0.5, 0.5, 0.5), whereas the other models are normalised with mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225).

Let me try to explain more clearly with the new code:

DenseNet:

Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
    transform_input (bool): If True, preprocesses the input according to the method with which it
        was trained on ImageNet. Default: *False*

def forward(self, x):
    # imagenet normalisation
    if self.transform_input:
        x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.485) / 0.229
        x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.456) / 0.224
        x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.406) / 0.225
        x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

    features = self.features(x)
    out = F.relu(features, inplace=True)
    out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
    out = self.classifier(out)
    return out

GoogLeNet:

Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
    transform_input (bool): If True, preprocesses the input according to the method with which it
        was trained on ImageNet. Default: *False*

def forward(self, x):
    if self.transform_input:
        x_ch0 = (torch.unsqueeze(x[:, 0], 1) - 0.5) / 0.5
        x_ch1 = (torch.unsqueeze(x[:, 1], 1) - 0.5) / 0.5
        x_ch2 = (torch.unsqueeze(x[:, 2], 1) - 0.5) / 0.5
        x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

    # N x 3 x 224 x 224
    x = self.conv1(x)
    # N x 64 x 112 x 112
    x = self.maxpool1(x)
    # N x 64 x 56 x 56
    x = self.conv2(x)
    # N x 64 x 56 x 56
    x = self.conv3(x)
    # N x 192 x 56 x 56
    x = self.maxpool2(x)
    ...

The user calls densenet121 as:

model = torchvision.models.densenet121(pretrained=True, transform_input=True)

and googLeNet as:

model = torchvision.models.googlenet(pretrained=True, transform_input=True)

Yes, but at least all models in torchvision can use the same normalization. None of the pre-trained models will work if we don't use the ImageNet mean/std.

The new implementation does not discard the use of the ImageNet mean/std. It allows a consistent normalisation method from the user's perspective while reflecting in the code the different techniques used by the authors during training on ImageNet.

I don't follow. All models in torchvision, from a user perspective, should be normalized the same way. Isn't that the case?

Using the transform_input argument also allows consistent normalisation from the user's perspective. Not all ImageNet models use the same normalisation technique, so it would be helpful for users if this were reflected in the code. The reason we want torchvision models to share the same normalisation is to make normalisation consistent among the models for easier use. This can indeed be done with the transform_input argument, which keeps the normalisation consistent from the user's perspective while applying each model's respective technique under the hood. The current implementation can give users reading the docs the false intuition that all the ImageNet models are normalised in the same manner (which is not true). This will also make the implementation more robust to the inclusion of models with different normalisation strategies in the future.

@youkaichao
Contributor

I side with the author. Although we typically think of normalization as specific to each dataset, preprocessing should be integrated with the pretrained model.

Say a model is trained using a strange normalization: 1.0 is added to every image. Such a pretrained model would make no sense if the input is raw images.

That is, a model is a function mapping the input to the output, and a pretrained model maps a low-dimensional manifold in the input space to the corresponding output. A different normalization makes the input lie on a different manifold, causing it to deviate from the model's working manifold and yield nonsensical features.

So I urge integrating the normalization into the pretrained model. However, if a model is trained from scratch, it makes no sense to keep these magic numbers.

@ekagra-ranjan I think transform_input and pretrained should be the same. If pretrained is set to True, then we should normalize the input automatically.

@ekagra-ranjan
Contributor Author

That is, a model is a function mapping the input to the output, and a pretrained model maps a low-dimensional manifold in the input space to the corresponding output. A different normalization makes the input lie on a different manifold, causing it to deviate from the model's working manifold and yield nonsensical features.

@youkaichao Thank you for the insights!

I think transform_input and pretrained should be the same. If pretrained is set to True, then we should normalize the input automatically.

I too was in favor of merging these two arguments, but then I came across a scenario when dealing with non-natural images like medical imaging. For example, when classifying chest x-rays, one might use the pretrained weights but normalise with the mean and std of the chest x-ray dataset, because normalising chest x-rays with the mean and std of natural images might not lead to data with mean ~ 0 and std ~ 1 (although no one can tell for certain without trying both normalisation approaches :) ). So I thought it would give the user more flexibility to keep them as two separate arguments that can be set according to their experiments, although in the majority of cases both arguments will have the same value.
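
For that scenario, a rough sketch of how dataset-specific statistics could be estimated (the dataset is a placeholder; it assumes 3-channel image tensors in [0, 1] and (image, label) samples):

import torch
from torch.utils.data import DataLoader


def channel_stats(dataset, batch_size=64):
    """Estimate per-channel mean/std over a dataset of [0, 1] image tensors."""
    loader = DataLoader(dataset, batch_size=batch_size)
    n, total, total_sq = 0, torch.zeros(3), torch.zeros(3)
    for images, _ in loader:
        n += images.numel() // images.size(1)        # pixels per channel
        total += images.sum(dim=(0, 2, 3))
        total_sq += (images ** 2).sum(dim=(0, 2, 3))
    mean = total / n
    std = (total_sq / n - mean ** 2).sqrt()
    return mean, std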

@pmeier
Collaborator

pmeier commented Mar 11, 2019

@youkaichao I think making the normalisation mandatory is the wrong way. @fmassa provided a good argument for when it should be integrated into the model for performance reasons.

I think what was omitted in this discussion until now is that normalisation is not the only preprocessing step that is needed.

  1. Unless your data is natively a torch.Tensor you will need the ToTensor transformation and cannot pass your data directly into the model.
  2. Before the data is normalised it also has to be scaled into the range [0, 1], since uint8 images are usually in the range [0, 255]. I'm aware that this is automatically performed by the ToTensor transformation, but I wanted to be clear that this is a separate step.

Following your argument, these steps should also be included in the model, which splits the interface depending on the pretrained flag.

Since @ekagra-ranjan argued that the current way could confuse new users, we could bundle these two transforms together:

from torchvision import transforms
import torchvision.transforms.functional as F

class VisionPreprocessing(object):
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, pic):
        # convert to a tensor in [0, 1] and normalise in one go
        tensor = F.to_tensor(pic)
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self):
        s = "mean={mean}, std={std}".format(**self.__dict__)
        if self.inplace:
            s += ", inplace=True"
        return "{}({})".format(self.__class__.__name__, s)

With this we could simply do:

dataset = torchvision.datasets.*(*, transform=transforms.VisionPreprocessing())
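
For instance, with CIFAR10 purely as a stand-in (VisionPreprocessing is the hypothetical class sketched above, not an existing torchvision transform):

import torchvision

dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True,
    transform=VisionPreprocessing(),
)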

@fmassa
Member

fmassa commented Mar 11, 2019

@ekagra-ranjan if you have a look at the current implementation of the normalization in GoogleNet / Inception

x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5

you'll see that it is doing some internal conversion to go from the ImageNet-normalized input to the 0.5-normalized input. Plus, if you look at

if 'transform_input' not in kwargs:
    kwargs['transform_input'] = True

you'll see that if the user requests a pre-trained model, it will force transform_input to be True. This means that, from a user perspective, all pre-trained models in torchvision expect the inputs to be normalized in the same way.

I totally agree that we should embed somehow the pre-processing to be part of the model parameters. I just don't think that the current way implemented in this PR is the best way to go.

Here is one possible solution (which is not perfect, and I'll explain why in a second):

class ResNet(...):
    def __init__(self, ..., data_transform=None):
        ...
        self.data_transform = data_transform

    def forward(self, x):
        # run the bundled pre-processing first, if any
        if self.data_transform is not None:
            x = self.data_transform(x)
        ...  # as before

class DataTransform(nn.Module):
    # all the pre-processing that we want, including
    # mean / std / size as buffers
    def __init__(self, ...):
        ...
        self.register_buffer("size", torch.tensor(224))
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406])[:, None, None])
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225])[:, None, None])

    def forward(self, x):
        x = to_tensor(x)
        if self.training:
            x = random_resize_crop(x, self.size)
        else:
            x = resize(x, self.size)
        x = (x - self.mean) / self.std  # normalize
        return x

This would embed *all* the transformations into the model, making it very clear what should be done to get the model running. Plus, the mean/std/size are saved together with the model parameters, which is what we want.

Now, the problem with this approach is that loading the model requires a particular DataTransform to be present. If we don't have it, loading fails because we wouldn't be able to match the buffers for DataTransform exactly. There are ways of overcoming this, but as I said before I'd like to give this a second thought before implementing it.
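
One possible workaround, just as a sketch (not necessarily the approach hinted at above): load old checkpoints non-strictly, so missing DataTransform buffers simply keep their default values.

import torch

# 'model' is assumed to be a ResNet built with the DataTransform above;
# the checkpoint path is illustrative
state_dict = torch.load('resnet50_old_checkpoint.pth')
model.load_state_dict(state_dict, strict=False)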

@ekagra-ranjan
Contributor Author

Okay @fmassa, as you say. If you come up with anything about this in the future and don't have time to implement it, feel free to point me to it.

@fmassa
Member

fmassa commented Mar 11, 2019

@ekagra-ranjan I'm actively looking for a better approach to what I've mentioned above, so if you have another solution, please let me know! :-)

@youkaichao
Contributor

@fmassa I have an idea, but it is somewhat ground-breaking: the preprocessing should be divided into two parts, data augmentation and real preprocessing. For data augmentation, we usually perform resize, flip, and crop. For real preprocessing, we perform ToTensor, normalize, etc. Data augmentation should be implemented in the Dataset, but preprocessing should be implemented in the Module. This way, computation-bound tasks like normalize can be performed on the GPU and preprocessing parameters can be serialized.

@fmassa
Member

fmassa commented Mar 11, 2019

@youkaichao why not perform data_augmentation on the model itself, inside the pre-processing block?

@youkaichao
Contributor

@fmassa Because conceptually, a model takes a natural image as its input and outputs something. But data_augmentation is a process that takes a natural image and randomly outputs a transformed natural image.

By a transformed natural image, I mean an array that can be passed into matplotlib.pyplot.imshow with some reshape operations.

To make things clearer, data_augmentation can usually be carried out by PIL. You see, PIL works in the manifold of natural images, and there is no normalization stuff in PIL.

But maybe we'd better return a Tensor in the Dataset, so the pipeline may look like:

|-------------------------- data_augmentation --------------------------|
Dataset.__getitem__ --> PIL.Image.Image --> PIL.Image.Image --> ToTensor[1] -->

|------------------------------ the model ------------------------------|
|------ preprocessing ------|
torch.Tensor --> torch.Tensor --> forward --> output

As ToTensor lies between data_augmentation and the model, it can be assigned to either process. But to be consistent with the current code (and to make batching in the DataLoader easier), ToTensor can be integrated into data_augmentation.

[1]: colored images would correspond to a FloatTensor with shape [3, H, W] and values in [0, 1]; gray images would correspond to a FloatTensor with shape [H, W] and values in [0, 1]
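
A compact sketch of that split, with augmentation (plus ToTensor) in the Dataset and normalisation as the first module of the model (the path and the choice of backbone are illustrative):

import torch
from torch import nn
from torchvision import datasets, models, transforms

# data augmentation stays in the Dataset and ends with ToTensor ([0, 1] FloatTensor)
augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
dataset = datasets.ImageFolder('path/to/train', transform=augmentation)


class Preprocess(nn.Module):
    """The 'real preprocessing': normalisation, run inside the model (e.g. on the GPU)."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(Preprocess, self).__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


model = nn.Sequential(Preprocess(), models.resnet18(pretrained=True))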

@fmassa
Member

fmassa commented Mar 26, 2019

@youkaichao what you say makes a lot of sense.
For efficiency though, we indeed need to return torch.Tensor from the dataset, as you proposed.

I'm going to give this a second thought, but this is an interesting idea to pursue.

@rwightman
Contributor

Ran into this while looking for something else. I also agree that the norm should not be baked into the model, as the same models can be used on many different datasets, or even the same ones with different norm types (i.e. per-image) or mean/std.

In the past, when I've wanted GPU-based norm or preprocessing ops that are created with the model, I write a class like @fmassa's DataTransform, but I don't modify the model to accept it; I just wrap the model: model = Transformer(model).
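
A rough sketch of that wrapping pattern (Transformer here is just an illustrative name, not an existing class):

import torch
from torch import nn


class Transformer(nn.Module):
    """Wraps an arbitrary model and normalises on whatever device the model lives on."""

    def __init__(self, model, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(Transformer, self).__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))
        self.model = model

    def forward(self, x):
        return self.model((x - self.mean) / self.std)

# usage: model = Transformer(torchvision.models.densenet121(pretrained=True))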

Usually though, my norm transforms are done as part of the data loader/prefetch, NVIDIA style: https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py#L256

@pmeier
Collaborator

pmeier commented Apr 12, 2019

@rwightman Indeed, you are right in the general case. If you are not working with the pretrained models, it makes little sense to fix the normalisation (or preprocessing in general) to the model.
If, on the other hand, you do work with the pretrained models, you are forced to use the same normalisation that the models were trained on to obtain sensible results.


The longer I think about this, the more I tend to agree with @youkaichao. Maybe we can add a preprocessing nn.Module in case the user requests pretrained weights. Since the normalisation is a rather simple operation, we can create the module at runtime. There is no need to serialise it with the rest of the model. Thus, we keep backward compatibility with little effort.
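
A minimal sketch of that idea (the factory name and the decision to gate on pretrained are illustrative; the mean/std are kept as plain attributes rather than buffers, so nothing extra ends up in the state_dict):

import torch
from torch import nn
from torchvision import models


class Preprocessing(nn.Module):
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(Preprocessing, self).__init__()
        # plain tensors (not buffers/parameters): created at runtime, never serialised
        self.mean = torch.tensor(mean).view(1, -1, 1, 1)
        self.std = torch.tensor(std).view(1, -1, 1, 1)

    def forward(self, x):
        return (x - self.mean.to(x.device)) / self.std.to(x.device)


def densenet121(pretrained=False, **kwargs):
    """Hypothetical factory: only pretrained weights get the preprocessing module."""
    model = models.densenet121(pretrained=pretrained, **kwargs)
    if pretrained:
        model = nn.Sequential(Preprocessing(), model)
    return model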

@ain-soph
Contributor

ain-soph commented Mar 8, 2021

Is there any further work since 2019?

I have suffered from this issue for a long time (actually since my first day as a PhD student :) ). I'm doing security research and need to manipulate images (e.g. PGD adversarial attacks). But of course, the manipulated images still need to be valid images (in [0, 1], and even more: multiples of 1/255).

However, my data is already normalized (not in [0, 1]), so I have to account for this in my scripts. If the normalization were embedded into the model structure and we could feed images in [0, 1] into the forward function, it would be extremely sweet! My current code has adopted this strategy, and Normalize is the first Module of my Model.

From the data-processing perspective, we always process raw images in [0, 1] and never make changes to normalized images. It would be more convenient if the data fetched from the dataloader were in [0, 1].

In conclusion, the normalization is part of the model, not part of the data processing. We would never use the normalized data outside of the model.
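
For concreteness, a minimal PGD sketch under that convention (it assumes model already applies the normalization internally as its first module; eps/alpha/steps are illustrative):

import torch
import torch.nn.functional as F


def pgd_attack(model, images, labels, eps=8 / 255, alpha=2 / 255, steps=10):
    """All arithmetic stays in raw [0, 1] image space because the model normalizes internally."""
    adv = images.clone().detach()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)
        grad = torch.autograd.grad(loss, adv)[0]
        adv = adv.detach() + alpha * grad.sign()
        adv = torch.min(torch.max(adv, images - eps), images + eps)  # stay in the eps-ball
        adv = adv.clamp(0.0, 1.0)                                    # stay a valid image
    return adv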

@malfet malfet deleted the branch pytorch:master September 20, 2021 14:34
@malfet malfet closed this Sep 20, 2021