Error in mean and std computation in torchvision.models #3657

Closed
krishnap25 opened this issue Apr 10, 2021 · 3 comments

@krishnap25

krishnap25 commented Apr 10, 2021

📚 Documentation

Hello all,

Thanks for the fantastic library. I would like to point out an error in the mean and std computation on the torchvision.models documentation page. In particular, I'm referring to the code following the line
"The process for obtaining the values of mean and std is roughly equivalent to:"

import torch
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)

means = []
stds = []
for img in subset(dataset):
    means.append(torch.mean(img))
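    stds.append(torch.std(img)) # Bug: standard deviations are averaged below instead of variances

mean = torch.mean(torch.tensor(means)) # Error: torch.tensor() on a list of tensors
std = torch.mean(torch.tensor(stds)) # Error: same conversion, and it averages the stds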

There are two issues here:

  • torch.tensor(means) throws an error: ValueError: only one element tensors can be converted to Python scalars. It should be torch.stack(means).
  • Averaging the standard deviations is incorrect: the variances should be averaged instead, and the square root taken at the end.

Here is a version that fixes both:
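import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)

# Random sample of the training set; the sample size is illustrative.
indices = torch.randperm(len(dataset))[:1000].tolist()

means = []
variances = []
for img, _ in Subset(dataset, indices):  # the dataset yields (image, label) pairs
    # Per-channel statistics of one image, each of shape (3,)
    means.append(torch.mean(img, dim=(1, 2)))
    variances.append(torch.var(img, dim=(1, 2)))

# Stack into shape (num_images, 3) and reduce over the images
mean = torch.mean(torch.stack(means), dim=0)
# Average the variances first, take the square root at the end
std = torch.sqrt(torch.mean(torch.stack(variances), dim=0))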
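A small numeric sketch of the second point, using synthetic single-channel tensors as stand-ins for images (nothing here is ImageNet-specific; the sizes and scales are arbitrary assumptions). Because the square root is concave, the mean of the per-image standard deviations can never exceed the square root of the mean of the variances (Jensen's inequality), so averaging stds systematically yields a smaller value whenever the images' spreads differ.

```python
import torch

torch.manual_seed(0)

# Synthetic "images": 100 single-channel 32x32 tensors whose spread varies
# from image to image (a stand-in for real data).
scales = torch.linspace(0.05, 0.5, 100)
images = [torch.randn(32, 32) * s for s in scales]

stds = torch.stack([img.std() for img in images])
variances = torch.stack([img.var() for img in images])

mean_of_stds = stds.mean().item()               # what the docs snippet computes
root_mean_var = variances.mean().sqrt().item()  # average variances, then sqrt

print(f"mean of stds:       {mean_of_stds:.4f}")
print(f"sqrt(mean of vars): {root_mean_var:.4f}")
```

The two numbers coincide only when every image has the same standard deviation.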

I would argue further that, since we are interested in the per-channel mean and std over the whole dataset, we should first compute the per-channel mean across all images and then compute the std around this "batch mean" (whereas the current version measures each image's spread around that image's own mean). I have not run this on a large dataset, so I do not know how much of a difference it would make.
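A minimal sketch of the two-pass approach described above, assuming every image is a (C, H, W) tensor of the same size; the data here is synthetic and the function name `dataset_mean_std` is hypothetical, not a torchvision API. Pass one computes the per-channel mean over all pixels of all images; pass two accumulates squared deviations from that dataset-level mean. The result matches `torch.std` over all pixels concatenated, and it exceeds the mean-of-per-image-variances estimate by roughly the variance of the per-image means (law of total variance).

```python
import torch

def dataset_mean_std(images):
    """Per-channel mean and std over all pixels of all images (two passes).

    `images` is any re-iterable collection of (C, H, W) tensors.
    Accumulates in float64 to limit rounding error on large datasets.
    """
    # Pass 1: dataset-level ("batch") mean per channel.
    total = None
    count = 0
    for img in images:
        x = img.double().flatten(1)          # shape (C, H*W)
        s = x.sum(dim=1)
        total = s if total is None else total + s
        count += x.shape[1]
    mean = total / count

    # Pass 2: squared deviations from the dataset mean, not the image mean.
    sq = torch.zeros_like(mean)
    for img in images:
        x = img.double().flatten(1)
        sq += ((x - mean[:, None]) ** 2).sum(dim=1)
    std = torch.sqrt(sq / (count - 1))       # unbiased, like torch.std
    return mean, std


torch.manual_seed(0)
# Synthetic 3-channel images whose brightness differs from image to image.
images = [torch.rand(3, 16, 16) + 0.5 * torch.rand(3, 1, 1) for _ in range(50)]

mean, std = dataset_mean_std(images)

# Reference: concatenate every pixel of every image and reduce per channel.
pixels = torch.cat([img.double().flatten(1) for img in images], dim=1)
assert torch.allclose(mean, pixels.mean(dim=1))
assert torch.allclose(std, pixels.std(dim=1))

# Per-image approach (average of per-image variances, then sqrt).
per_image_var = torch.stack([img.double().flatten(1).var(dim=1) for img in images])
print("dataset-level std:     ", std)
print("sqrt(mean image vars): ", per_image_var.mean(dim=0).sqrt())
```

The second pass doubles the I/O cost; a one-pass alternative accumulates per-channel sums and sums of squares, at the price of worse numerical behavior in low precision.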

Thank you!

@pmeier
Collaborator

pmeier commented Apr 11, 2021

Hey @krishnap25,

torch.tensor(means) throws an error

This is why I used the word "roughly" in the description: it is meant as pseudo-code for how the values were computed. I'm okay with adding a torch.stack to make it executable.
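A quick sketch of the torch.tensor versus torch.stack difference, assuming per-channel image means of shape (3,) (random placeholder values here, not real image statistics). torch.tensor cannot build one tensor from a list of multi-element tensors and raises the ValueError quoted in the issue, while torch.stack joins them along a new leading dimension. With scalar (0-d) per-image means, torch.tensor converts the list without error.

```python
import torch

# Hypothetical per-channel means of three images, each of shape (3,).
means = [torch.rand(3) for _ in range(3)]

try:
    torch.tensor(means)                  # list of multi-element tensors
    raised = False
except ValueError as err:                # "only one element tensors can be converted ..."
    raised = True
    print("torch.tensor failed:", err)

stacked = torch.stack(means)             # shape (3 images, 3 channels)
per_channel_mean = stacked.mean(dim=0)   # shape (3,)

# Scalar (0-d) per-image means convert without error.
scalars = torch.tensor([m.mean() for m in means])
```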

the mean of the standard deviations should rather be a mean of the variances, followed by square rooting it at the end.

True, but there is nothing we can do about it. Unless you have the resources and are willing to offer them for free, retraining all the models to change this is not an option. Believe me, I tried to argue the same thing in #1439. You can find the script I used for figuring out which approach was most likely used here. I encourage you to try your (objectively better) approach and see if the numbers change all that much. If that is the case, which I don't believe, maybe we can reopen the discussion.

Still, the issue is not actionable. Thus, I'm closing it. Let me know if you have other questions about this.

@krishnap25
Author

Hey Philip, thanks for your response! Your investigations are interesting.

I came across this issue because I tried computing the mean/std for a different dataset. Would it help to add a note to the documentation for future users?

@pmeier
Collaborator

pmeier commented Apr 12, 2021

Sure, do you want to send a PR?
