Skip to content

Add float cast in GeneralizedRCNN normalize #3238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/test_models_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ def test_transform_copy_targets(self):
self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes']))
self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes']))

def test_normalize_integer(self):
transform = GeneralizedRCNNTransform(
300, 500,
torch.randint(0, 255, (3,)),
torch.randint(0, 255, (3,))
)
image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
targets = [{'boxes': torch.randint(0, 255, (3, 4))}]
image_list, _ = transform(image, targets) # noqa: F841
# check that original images still have uint8 dtype
for img in image:
self.assertTrue(img.dtype == torch.uint8)
# check that the resulting images have float32 dtype
self.assertTrue(image_list.tensors.dtype == torch.float32)
# check that no NaN values are produced
self.assertFalse(torch.any(torch.isnan(image_list.tensors)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wadaboa I'm still unable to see this test failing on master without your mitigation. We need to reproduce the problem, so that we are certain that we provide the right fix. Let me know your thoughts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just went through the code that initially produced the issue for me and I realized the following: the mentioned error can only happen when the input image has uint8 dtype (so, in [0, 255] range) and the mean/std parameters are instead in [0, 1]. In this way, the output given by normalization will be all inf (not NaN as shown in the tests).

But in this case image and mean/std would have different ranges. So, is it something that should still be checked? Or maybe I reported an error which should alert the user to check its input and this PR is not needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for investigating. I agree, there are too many corner-cases to cover for uint8 inputs. This specific piece of code generally expects the input image to be floating type, so if it's not it might be worth throwing an error.

I think what we should do is close this PR to keep things clean and open a new one where you put a type check on top of normalize and throw an exception if the input is not floating point. What you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I'll close this PR and open a new one for the type check. Thank you.


def test_normalize_float(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.rand(3, 200, 300)]
targets = [{'boxes': torch.rand(3, 4)}]
image_list, _ = transform(image, targets) # noqa: F841
# check that original images still have float32 dtype
for img in image:
self.assertTrue(img.dtype == torch.float32)
# check that the resulting images have float32 dtype
self.assertTrue(image_list.tensors.dtype == torch.float32)


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def forward(self,
return image_list, targets

def normalize(self, image):
if not image.is_floating_point():
image = image.to(torch.float32)
dtype, device = image.dtype, image.device
mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
Expand Down