-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add float cast in GeneralizedRCNN normalize #3238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,34 @@ def test_transform_copy_targets(self): | |
self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes'])) | ||
self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes'])) | ||
|
||
def test_normalize_integer(self): | ||
transform = GeneralizedRCNNTransform( | ||
300, 500, | ||
torch.randint(0, 255, (3,)), | ||
torch.randint(0, 255, (3,)) | ||
) | ||
image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)] | ||
targets = [{'boxes': torch.randint(0, 255, (3, 4))}] | ||
image_list, _ = transform(image, targets) # noqa: F841 | ||
# check that original images still have uint8 dtype | ||
for img in image: | ||
self.assertTrue(img.dtype == torch.uint8) | ||
# check that the resulting images have float32 dtype | ||
self.assertTrue(image_list.tensors.dtype == torch.float32) | ||
# check that no NaN values are produced | ||
self.assertFalse(torch.any(torch.isnan(image_list.tensors))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Wadaboa I'm still unable to see this test failing on master without your mitigation. We need to reproduce the problem, so that we are certain that we provide the right fix. Let me know your thoughts. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just went through the code that initially produced the issue for me and I realized the following: the mentioned error can only happen when the input image has But in this case image and mean/std would have different ranges. So, is it something that should still be checked? Or maybe I reported an error which should alert the user to check its input and this PR is not needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for investigating. I agree, there are too many corner-cases to cover for uint8 inputs. This specific piece of code generally expects the input image to be floating type, so if it's not it might be worth throwing an error. I think what we should do is close this PR to keep things clean and open a new one where you put a type check on top of normalize and throw an exception if the input is not floating point. What you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. I'll close this PR and open a new one for the type check. Thank you. |
||
|
||
def test_normalize_float(self): | ||
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) | ||
image = [torch.rand(3, 200, 300)] | ||
targets = [{'boxes': torch.rand(3, 4)}] | ||
image_list, _ = transform(image, targets) # noqa: F841 | ||
# check that original images still have float32 dtype | ||
for img in image: | ||
self.assertTrue(img.dtype == torch.float32) | ||
# check that the resulting images have float32 dtype | ||
self.assertTrue(image_list.tensors.dtype == torch.float32) | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Uh oh!
There was an error while loading. Please reload this page.