Add float cast in GeneralizedRCNN normalize #3238

Closed
wants to merge 3 commits into from
Conversation

Wadaboa
Contributor

@Wadaboa Wadaboa commented Jan 11, 2021

This PR fixes issue #3228 by casting the input image to float32 in the normalize method of the GeneralizedRCNNTransform class.

Unit tests have been added to ensure that passing float32 or uint8 image/mean/std variables does not lead to failure.

@facebook-github-bot
Hi @Wadaboa!

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Contributor

@datumbox datumbox left a comment

@Wadaboa Thanks for the PR.

I left a couple of comments below. Could you please check them out?

@@ -117,6 +117,8 @@ def forward(self,
        return image_list, targets

    def normalize(self, image):
        if image.dtype not in (torch.float, torch.double, torch.half):
Contributor

Your code is valid but at TorchVision we use the following idiom:

Suggested change
- if image.dtype not in (torch.float, torch.double, torch.half):
+ if not image.is_floating_point():
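
As a rough illustration of the idiom (a sketch, not code from this PR): assuming the goal is to cast only when the tensor is not already floating point, the explicit dtype tuple and a negated `is_floating_point()` agree on the common dtypes. The helper names below are hypothetical. One difference worth noting is that `is_floating_point()` also returns True for `torch.bfloat16`, which the tuple does not list.

```python
import torch

# Dtypes listed explicitly in the original check.
EXPLICIT_FLOAT_DTYPES = (torch.float, torch.double, torch.half)


def needs_cast_explicit(image: torch.Tensor) -> bool:
    """Hypothetical helper: True when the dtype is not in the explicit tuple."""
    return image.dtype not in EXPLICIT_FLOAT_DTYPES


def needs_cast_idiom(image: torch.Tensor) -> bool:
    """Hypothetical helper: True when the tensor is not floating point."""
    return not image.is_floating_point()


# Both checks agree on these common dtypes.
for dtype in (torch.uint8, torch.int64, torch.float16, torch.float32, torch.float64):
    image = torch.zeros(3, 4, 4, dtype=dtype)
    assert needs_cast_explicit(image) == needs_cast_idiom(image)

# They differ on bfloat16: the tuple would trigger a cast, the idiom would not.
bf16_image = torch.zeros(3, 4, 4, dtype=torch.bfloat16)
```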

@@ -117,6 +117,8 @@ def forward(self,
        return image_list, targets

    def normalize(self, image):
        if image.dtype not in (torch.float, torch.double, torch.half):
            image = image.float()
Contributor

Again, your code is correct but same as above:

Suggested change
- image = image.float()
+ image = image.to(torch.float32)
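
For context, here is a minimal standalone sketch of what a normalize step with the cast could look like once both review suggestions are applied. It is not the actual GeneralizedRCNNTransform code: the function name `normalize_with_cast` and the mean/std values are illustrative assumptions, and the per-channel broadcasting layout is assumed to be (C, H, W). Note that the cast changes only the dtype; it does not rescale [0, 255] values to [0, 1].

```python
import torch


def normalize_with_cast(image, image_mean, image_std):
    """Hypothetical sketch: cast non-float images to float32, then normalize."""
    if not image.is_floating_point():
        # Same result as image.float(); the review prefers the explicit dtype.
        image = image.to(torch.float32)
    dtype, device = image.dtype, image.device
    mean = torch.as_tensor(image_mean, dtype=dtype, device=device)
    std = torch.as_tensor(image_std, dtype=dtype, device=device)
    # Broadcast per-channel statistics over a (C, H, W) image.
    return (image - mean[:, None, None]) / std[:, None, None]


# Illustrative input: a uint8 image and example per-channel statistics.
image = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
out = normalize_with_cast(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
```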

@Wadaboa
Contributor Author

Wadaboa commented Jan 11, 2021

I don't know if this is the right place to ask, but I would like to know why the normalize function (in the GeneralizedRCNNTransform class) is re-implemented from scratch if it is already available in torchvision.transforms.functional.normalize. Is there a reason or could it be substituted?

@datumbox
Contributor

> I don't know if this is the right place to ask, but I would like to know why the normalize function (in the GeneralizedRCNNTransform class) is re-implemented from scratch if it is already available in torchvision.transforms.functional.normalize. Is there a reason or could it be substituted?

It's because the other normalize can't handle the targets.

# check that the resulting images have float32 dtype
self.assertTrue(image_list.tensors.dtype == torch.float32)
# check that no NaN values are produced
self.assertFalse(torch.any(torch.isnan(image_list.tensors)))
Contributor

@Wadaboa I'm still unable to see this test failing on master without your mitigation. We need to reproduce the problem, so that we are certain that we provide the right fix. Let me know your thoughts.

Contributor Author

I just went through the code that initially produced the issue for me and realized the following: the error can only happen when the input image has uint8 dtype (so, values in the [0, 255] range) while the mean/std parameters are in [0, 1]. In that case, the output of the normalization is all inf (not NaN, as checked in the tests).

But in this case the image and mean/std would have different ranges. So, is this something that should still be checked? Or maybe I reported an error that should alert users to check their input, and this PR is not needed?
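
The failure mode described above can be reproduced with a small sketch. It assumes that the normalize step builds mean/std tensors with the image's own dtype (so [0, 1] values truncate to 0 under uint8) and that `/` on integer tensors performs true division, as in recent PyTorch releases. The function `normalize_no_cast` is a hypothetical stand-in, not the library code. Pixel values are drawn from [1, 255] so that every division is a nonzero value over zero.

```python
import torch


def normalize_no_cast(image, image_mean, image_std):
    """Hypothetical stand-in: mean/std are created with the image's dtype."""
    dtype, device = image.dtype, image.device
    mean = torch.as_tensor(image_mean, dtype=dtype, device=device)
    std = torch.as_tensor(image_std, dtype=dtype, device=device)
    return (image - mean[:, None, None]) / std[:, None, None]


# uint8 image with strictly positive pixels; mean/std given in [0, 1].
image = torch.randint(1, 256, (3, 8, 8), dtype=torch.uint8)
out = normalize_no_cast(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# A NaN-only assertion does not catch this case; a finiteness check does.
nan_check_passes = not bool(torch.isnan(out).any())
finite_check_passes = bool(torch.isfinite(out).all())
```

Under these assumptions, a test that asserts only the absence of NaN would pass on the broken output, while `torch.isfinite` flags both inf and NaN.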

Contributor

Thanks for investigating. I agree, there are too many corner cases to cover for uint8 inputs. This specific piece of code generally expects the input image to be of floating-point type, so if it's not, it might be worth throwing an error.

I think what we should do is close this PR to keep things clean and open a new one where you put a type check at the top of normalize and throw an exception if the input is not floating point. What do you think?
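
A minimal sketch of the proposed type check, under the assumption that a `TypeError` is the chosen exception; the function name `normalize_strict` and the error message are illustrative, not taken from the follow-up PR.

```python
import torch


def normalize_strict(image, image_mean, image_std):
    """Hypothetical sketch: reject non-float inputs instead of casting them."""
    if not image.is_floating_point():
        raise TypeError(
            f"normalize expects a floating-point image, got {image.dtype}"
        )
    dtype, device = image.dtype, image.device
    mean = torch.as_tensor(image_mean, dtype=dtype, device=device)
    std = torch.as_tensor(image_std, dtype=dtype, device=device)
    return (image - mean[:, None, None]) / std[:, None, None]


mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# A float image in [0, 1] is accepted.
float_image = torch.rand(3, 8, 8)
normalized = normalize_strict(float_image, mean, std)

# A uint8 image is rejected with an explicit error.
try:
    normalize_strict((float_image * 255).to(torch.uint8), mean, std)
    raised = False
except TypeError:
    raised = True
```

Raising early leaves the choice of rescaling (for example, dividing by 255) to the caller rather than guessing the intended range.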

Contributor Author

I agree. I'll close this PR and open a new one for the type check. Thank you.
