Add float cast in GeneralizedRCNN normalize #3238
Hi @Wadaboa! Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
@Wadaboa Thanks for the PR.
I left a couple of comments below. Could you please check them out?
@@ -117,6 +117,8 @@ def forward(self,
        return image_list, targets

    def normalize(self, image):
        if image.dtype not in (torch.float, torch.double, torch.half):
Your code is valid, but at TorchVision we use the following idiom:
- if image.dtype not in (torch.float, torch.double, torch.half):
+ if not image.is_floating_point():
@@ -117,6 +117,8 @@ def forward(self,
        return image_list, targets

    def normalize(self, image):
        if image.dtype not in (torch.float, torch.double, torch.half):
            image = image.float()
Again, your code is correct, but the same applies as above:
- image = image.float()
+ image = image.to(torch.float32)
I don't know if this is the right place to ask, but I would like to know why the
It's because the other normalize can't handle the targets.
# check that the resulting images have float32 dtype
self.assertTrue(image_list.tensors.dtype == torch.float32)
# check that no NaN values are produced
self.assertFalse(torch.any(torch.isnan(image_list.tensors)))
@Wadaboa I'm still unable to see this test failing on master without your mitigation. We need to reproduce the problem, so that we are certain that we provide the right fix. Let me know your thoughts.
I just went through the code that initially produced the issue for me and realized the following: the error can only happen when the input image has `uint8` dtype (so, values in the [0, 255] range) while the mean/std parameters are in [0, 1]. In that case, the output of the normalization is all `inf` (not NaN, as checked in the tests).
But in this case the image and the mean/std would have different ranges. So, is it something that should still be checked? Or did I report an error that should instead alert users to check their input, meaning this PR is not needed?
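The failure mode described above can be reproduced in a few lines. This is a sketch under two assumptions that are not stated in the thread: that `normalize` casts mean/std to the image's dtype before applying them, and that a recent PyTorch version is used, where `/` on integer tensors performs true division. The mean/std values are illustrative.

```python
import torch

# uint8 image in [0, 255]; every pixel is nonzero in this example
image = torch.full((3, 2, 2), 128, dtype=torch.uint8)

# Illustrative mean/std expressed in the [0, 1] range
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]

# Assumption: mean/std are cast to the image dtype, as a normalize
# implementation might do. Casting values below 1 to uint8 truncates to 0.
mean = torch.as_tensor(image_mean, dtype=image.dtype)
std = torch.as_tensor(image_std, dtype=image.dtype)

# (image - 0) / 0 with true division yields inf for nonzero pixels
normalized = (image - mean[:, None, None]) / std[:, None, None]
```

Pixels that are exactly 0 would evaluate 0 / 0 instead, which gives NaN rather than `inf`.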
Thanks for investigating. I agree, there are too many corner cases to cover for uint8 inputs. This specific piece of code generally expects the input image to be of floating type, so if it's not, it might be worth throwing an error.
I think what we should do is close this PR to keep things clean and open a new one where you put a type check at the top of normalize and throw an exception if the input is not floating point. What do you think?
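A type check of the kind proposed here might look like the following. This is only a sketch written as a standalone function: the signature, the choice of `TypeError`, and the error message are assumptions for illustration, not the code that was eventually merged.

```python
import torch


def normalize(image: torch.Tensor, image_mean, image_std) -> torch.Tensor:
    """Hypothetical standalone normalize with a dtype guard at the top."""
    if not image.is_floating_point():
        # Fail early instead of silently producing inf/NaN values
        raise TypeError(
            f"Expected a floating point image, but got dtype {image.dtype}"
        )
    mean = torch.as_tensor(image_mean, dtype=image.dtype, device=image.device)
    std = torch.as_tensor(image_std, dtype=image.dtype, device=image.device)
    # Broadcast per-channel statistics over height and width
    return (image - mean[:, None, None]) / std[:, None, None]


MEAN = [0.485, 0.456, 0.406]  # illustrative values
STD = [0.229, 0.224, 0.225]

ok = normalize(torch.rand(3, 4, 4), MEAN, STD)

try:
    normalize(torch.zeros(3, 4, 4, dtype=torch.uint8), MEAN, STD)
    rejected = False
except TypeError:
    rejected = True
```

Raising an exception leaves the decision about how to rescale integer images with the caller, which avoids guessing whether the input is in [0, 255] or [0, 1].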
I agree. I'll close this PR and open a new one for the type check. Thank you.
This PR fixes issue #3228 by casting to float32 the input image of the `normalize` method in the `GeneralizedRCNNTransform` class.
Unit tests have been added to ensure that passing `float32` or `uint8` image/mean/std variables does not lead to failure.