From 54dd34277db2533093a50c7d5dcc9ac78b0e2508 Mon Sep 17 00:00:00 2001
From: Alessio Falai
Date: Mon, 11 Jan 2021 09:00:20 +0100
Subject: [PATCH 1/2] add float cast in generalized rcnn normalize

---
 test/test_models_detection_utils.py       | 26 +++++++++++++++++++++++
 torchvision/models/detection/transform.py |  2 ++
 2 files changed, 28 insertions(+)

diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py
index bfb26f24eae..9f28fd0f07a 100644
--- a/test/test_models_detection_utils.py
+++ b/test/test_models_detection_utils.py
@@ -58,6 +58,32 @@ def test_transform_copy_targets(self):
         self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes']))
         self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes']))
 
+    def test_normalize_integer(self):
+        transform = GeneralizedRCNNTransform(
+            300, 500,
+            torch.randint(0, 255, (3,)),
+            torch.randint(0, 255, (3,))
+        )
+        image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
+        targets = [{'boxes': torch.randint(0, 255, (3, 4))}]
+        image_list, _ = transform(image, targets)  # noqa: F841
+        # check that original images still have uint8 dtype
+        for img in image:
+            self.assertTrue(img.dtype == torch.uint8)
+        # check that the resulting images have float32 dtype
+        self.assertTrue(image_list.tensors.dtype == torch.float32)
+
+    def test_normalize_float(self):
+        transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
+        image = [torch.rand(3, 200, 300)]
+        targets = [{'boxes': torch.rand(3, 4)}]
+        image_list, _ = transform(image, targets)  # noqa: F841
+        # check that original images still have float32 dtype
+        for img in image:
+            self.assertTrue(img.dtype == torch.float32)
+        # check that the resulting images have float32 dtype
+        self.assertTrue(image_list.tensors.dtype == torch.float32)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 0d95361eedb..7e14f123c84 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -117,6 +117,8 @@ def forward(self,
         return image_list, targets
 
     def normalize(self, image):
+        if image.dtype not in (torch.float, torch.double, torch.half):
+            image = image.float()
         dtype, device = image.dtype, image.device
         mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)

From ec595fd7c8b3b871d0ea7fa1ac9557dc908c0170 Mon Sep 17 00:00:00 2001
From: Alessio Falai
Date: Mon, 11 Jan 2021 21:45:02 +0100
Subject: [PATCH 2/2] edit float cast / add nan check to normalize integer test

---
 test/test_models_detection_utils.py       | 2 ++
 torchvision/models/detection/transform.py | 4 ++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py
index 9f28fd0f07a..4888eb0c573 100644
--- a/test/test_models_detection_utils.py
+++ b/test/test_models_detection_utils.py
@@ -72,6 +72,8 @@ def test_normalize_integer(self):
             self.assertTrue(img.dtype == torch.uint8)
         # check that the resulting images have float32 dtype
         self.assertTrue(image_list.tensors.dtype == torch.float32)
+        # check that no NaN values are produced
+        self.assertFalse(torch.any(torch.isnan(image_list.tensors)))
 
     def test_normalize_float(self):
         transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 7e14f123c84..e870bb7c63d 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -117,8 +117,8 @@ def forward(self,
         return image_list, targets
 
     def normalize(self, image):
-        if image.dtype not in (torch.float, torch.double, torch.half):
-            image = image.float()
+        if not image.is_floating_point():
+            image = image.to(torch.float32)
         dtype, device = image.dtype, image.device
         mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)