diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index f4edc0f7f07..aab9d3d9b02 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -64,22 +64,24 @@ def test_hflip(self): def test_crop(self): script_crop = torch.jit.script(F_t.crop) - img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) - img_tensor_clone = img_tensor.clone() - top = random.randint(0, 15) - left = random.randint(0, 15) - height = random.randint(1, 16 - top) - width = random.randint(1, 16 - left) - img_cropped = F_t.crop(img_tensor, top, left, height, width) - img_PIL = transforms.ToPILImage()(img_tensor) - img_PIL_cropped = F.crop(img_PIL, top, left, height, width) - img_cropped_GT = transforms.ToTensor()(img_PIL_cropped) - self.assertTrue(torch.equal(img_tensor, img_tensor_clone)) - self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)), - "functional_tensor crop not working") - # scriptable function test - cropped_img_script = script_crop(img_tensor, top, left, height, width) - self.assertTrue(torch.equal(img_cropped, cropped_img_script)) + + img_tensor, pil_img = self._create_data(16, 18) + + test_configs = [ + (1, 2, 4, 5), # crop inside top-left corner + (2, 12, 3, 4), # crop inside top-right corner + (8, 3, 5, 6), # crop inside bottom-left corner + (8, 11, 4, 3), # crop inside bottom-right corner + ] + + for top, left, height, width in test_configs: + pil_img_cropped = F.crop(pil_img, top, left, height, width) + + img_tensor_cropped = F.crop(img_tensor, top, left, height, width) + self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped) + + img_tensor_cropped = script_crop(img_tensor, top, left, height, width) + self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped) def test_hsv2rgb(self): shape = (3, 100, 150) @@ -198,71 +200,47 @@ def test_rgb_to_grayscale(self): self.assertTrue(torch.equal(grayscale_script, grayscale_tensor)) def test_center_crop(self): - script_center_crop = torch.jit.script(F_t.center_crop) - img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) - img_tensor_clone = img_tensor.clone() - cropped_tensor = F_t.center_crop(img_tensor, [10, 10]) - cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10]) - cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8) - self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor)) - self.assertTrue(torch.equal(img_tensor, img_tensor_clone)) - # scriptable function test - cropped_script = script_center_crop(img_tensor, [10, 10]) - self.assertTrue(torch.equal(cropped_script, cropped_tensor)) + script_center_crop = torch.jit.script(F.center_crop) + + img_tensor, pil_img = self._create_data(32, 34) + + cropped_pil_image = F.center_crop(pil_img, [10, 11]) + + cropped_tensor = F.center_crop(img_tensor, [10, 11]) + self.compareTensorToPIL(cropped_tensor, cropped_pil_image) + + cropped_tensor = script_center_crop(img_tensor, [10, 11]) + self.compareTensorToPIL(cropped_tensor, cropped_pil_image) def test_five_crop(self): - script_five_crop = torch.jit.script(F_t.five_crop) - img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) - img_tensor_clone = img_tensor.clone() - cropped_tensor = F_t.five_crop(img_tensor, [10, 10]) - cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10]) - self.assertTrue(torch.equal(cropped_tensor[0], - (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[1], - (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[2], - (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[3], - (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[4], - (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(img_tensor, img_tensor_clone)) - # scriptable function test - cropped_script = script_five_crop(img_tensor, [10, 10]) - for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor): - self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img)) + script_five_crop = torch.jit.script(F.five_crop) + + img_tensor, pil_img = self._create_data(32, 34) + + cropped_pil_images = F.five_crop(pil_img, [10, 11]) + + cropped_tensors = F.five_crop(img_tensor, [10, 11]) + for i in range(5): + self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) + + cropped_tensors = script_five_crop(img_tensor, [10, 11]) + for i in range(5): + self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) def test_ten_crop(self): - script_ten_crop = torch.jit.script(F_t.ten_crop) - img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) - img_tensor_clone = img_tensor.clone() - cropped_tensor = F_t.ten_crop(img_tensor, [10, 10]) - cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10]) - self.assertTrue(torch.equal(cropped_tensor[0], - (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[1], - (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[2], - (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[3], - (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[4], - (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[5], - (transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[6], - (transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[7], - (transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[8], - (transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(cropped_tensor[9], - (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8))) - self.assertTrue(torch.equal(img_tensor, img_tensor_clone)) - # scriptable function test - cropped_script = script_ten_crop(img_tensor, [10, 10]) - for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor): - self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img)) + script_ten_crop = torch.jit.script(F.ten_crop) + + img_tensor, pil_img = self._create_data(32, 34) + + cropped_pil_images = F.ten_crop(pil_img, [10, 11]) + + cropped_tensors = F.ten_crop(img_tensor, [10, 11]) + for i in range(10): + self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) + + cropped_tensors = script_ten_crop(img_tensor, [10, 11]) + for i in range(10): + self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i]) def test_pad(self): script_fn = torch.jit.script(F_t.pad) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 433575ac6e2..0b14e9acab7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -240,7 +240,12 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: - """Crop the Image Tensor and resize it to desired size. + """DEPRECATED. Crop the Image Tensor and resize it to desired size. + + .. warning:: + + This method is deprecated and will be removed in future releases. + Please, use ``F.center_crop`` instead. Args: img (Tensor): Image to be cropped. @@ -250,6 +255,11 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: Returns: Tensor: Cropped image. """ + warnings.warn( + "This method is deprecated and will be removed in future releases. " + "Please, use ``F.center_crop`` instead." + ) + if not _is_tensor_a_torch_image(img): raise TypeError('tensor is not a torch image.') @@ -268,8 +278,15 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: - """Crop the given Image Tensor into four corners and the central crop. + """DEPRECATED. Crop the given Image Tensor into four corners and the central crop. + + .. warning:: + + This method is deprecated and will be removed in future releases. + Please, use ``F.five_crop`` instead. + .. Note:: + This transform returns a List of Tensors and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. @@ -283,6 +300,11 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: List: List (tl, tr, bl, br, center) Corresponding top left, top right, bottom left, bottom right and center crop. """ + warnings.warn( + "This method is deprecated and will be removed in future releases. " + "Please, use ``F.five_crop`` instead." + ) + if not _is_tensor_a_torch_image(img): raise TypeError('tensor is not a torch image.') @@ -304,10 +326,16 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]: - """Crop the given Image Tensor into four corners and the central crop plus the + """DEPRECATED. Crop the given Image Tensor into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). + .. warning:: + + This method is deprecated and will be removed in future releases. + Please, use ``F.ten_crop`` instead. + .. Note:: + This transform returns a List of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. @@ -323,6 +351,11 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa Corresponding top left, top right, bottom left, bottom right and center crop and same for the flipped image's tensor. """ + warnings.warn( + "This method is deprecated and will be removed in future releases. " + "Please, use ``F.ten_crop`` instead." + ) + if not _is_tensor_a_torch_image(img): raise TypeError('tensor is not a torch image.')