Skip to content

Commit c547e5c

Browse files
authored
Deprecated F_t.center_crop, F_t.five_crop, F_t.ten_crop (#2568)
- Updated docs - Put warning in the code - Updated tests
1 parent a75fdd4 commit c547e5c

File tree

2 files changed

+91
-80
lines changed

2 files changed

+91
-80
lines changed

test/test_functional_tensor.py

Lines changed: 55 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,24 @@ def test_hflip(self):
6464

6565
def test_crop(self):
6666
script_crop = torch.jit.script(F_t.crop)
67-
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
68-
img_tensor_clone = img_tensor.clone()
69-
top = random.randint(0, 15)
70-
left = random.randint(0, 15)
71-
height = random.randint(1, 16 - top)
72-
width = random.randint(1, 16 - left)
73-
img_cropped = F_t.crop(img_tensor, top, left, height, width)
74-
img_PIL = transforms.ToPILImage()(img_tensor)
75-
img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
76-
img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
77-
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
78-
self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
79-
"functional_tensor crop not working")
80-
# scriptable function test
81-
cropped_img_script = script_crop(img_tensor, top, left, height, width)
82-
self.assertTrue(torch.equal(img_cropped, cropped_img_script))
67+
68+
img_tensor, pil_img = self._create_data(16, 18)
69+
70+
test_configs = [
71+
(1, 2, 4, 5), # crop inside top-left corner
72+
(2, 12, 3, 4), # crop inside top-right corner
73+
(8, 3, 5, 6), # crop inside bottom-left corner
74+
(8, 11, 4, 3), # crop inside bottom-right corner
75+
]
76+
77+
for top, left, height, width in test_configs:
78+
pil_img_cropped = F.crop(pil_img, top, left, height, width)
79+
80+
img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
81+
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
82+
83+
img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
84+
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
8385

8486
def test_hsv2rgb(self):
8587
shape = (3, 100, 150)
@@ -198,71 +200,47 @@ def test_rgb_to_grayscale(self):
198200
self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
199201

200202
def test_center_crop(self):
201-
script_center_crop = torch.jit.script(F_t.center_crop)
202-
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
203-
img_tensor_clone = img_tensor.clone()
204-
cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
205-
cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
206-
cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
207-
self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
208-
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
209-
# scriptable function test
210-
cropped_script = script_center_crop(img_tensor, [10, 10])
211-
self.assertTrue(torch.equal(cropped_script, cropped_tensor))
203+
script_center_crop = torch.jit.script(F.center_crop)
204+
205+
img_tensor, pil_img = self._create_data(32, 34)
206+
207+
cropped_pil_image = F.center_crop(pil_img, [10, 11])
208+
209+
cropped_tensor = F.center_crop(img_tensor, [10, 11])
210+
self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
211+
212+
cropped_tensor = script_center_crop(img_tensor, [10, 11])
213+
self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
212214

213215
def test_five_crop(self):
214-
script_five_crop = torch.jit.script(F_t.five_crop)
215-
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
216-
img_tensor_clone = img_tensor.clone()
217-
cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
218-
cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
219-
self.assertTrue(torch.equal(cropped_tensor[0],
220-
(transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
221-
self.assertTrue(torch.equal(cropped_tensor[1],
222-
(transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
223-
self.assertTrue(torch.equal(cropped_tensor[2],
224-
(transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
225-
self.assertTrue(torch.equal(cropped_tensor[3],
226-
(transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
227-
self.assertTrue(torch.equal(cropped_tensor[4],
228-
(transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
229-
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
230-
# scriptable function test
231-
cropped_script = script_five_crop(img_tensor, [10, 10])
232-
for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
233-
self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
216+
script_five_crop = torch.jit.script(F.five_crop)
217+
218+
img_tensor, pil_img = self._create_data(32, 34)
219+
220+
cropped_pil_images = F.five_crop(pil_img, [10, 11])
221+
222+
cropped_tensors = F.five_crop(img_tensor, [10, 11])
223+
for i in range(5):
224+
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
225+
226+
cropped_tensors = script_five_crop(img_tensor, [10, 11])
227+
for i in range(5):
228+
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
234229

235230
def test_ten_crop(self):
236-
script_ten_crop = torch.jit.script(F_t.ten_crop)
237-
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
238-
img_tensor_clone = img_tensor.clone()
239-
cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
240-
cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
241-
self.assertTrue(torch.equal(cropped_tensor[0],
242-
(transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
243-
self.assertTrue(torch.equal(cropped_tensor[1],
244-
(transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
245-
self.assertTrue(torch.equal(cropped_tensor[2],
246-
(transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
247-
self.assertTrue(torch.equal(cropped_tensor[3],
248-
(transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
249-
self.assertTrue(torch.equal(cropped_tensor[4],
250-
(transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
251-
self.assertTrue(torch.equal(cropped_tensor[5],
252-
(transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8)))
253-
self.assertTrue(torch.equal(cropped_tensor[6],
254-
(transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8)))
255-
self.assertTrue(torch.equal(cropped_tensor[7],
256-
(transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8)))
257-
self.assertTrue(torch.equal(cropped_tensor[8],
258-
(transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
259-
self.assertTrue(torch.equal(cropped_tensor[9],
260-
(transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
261-
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
262-
# scriptable function test
263-
cropped_script = script_ten_crop(img_tensor, [10, 10])
264-
for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
265-
self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
231+
script_ten_crop = torch.jit.script(F.ten_crop)
232+
233+
img_tensor, pil_img = self._create_data(32, 34)
234+
235+
cropped_pil_images = F.ten_crop(pil_img, [10, 11])
236+
237+
cropped_tensors = F.ten_crop(img_tensor, [10, 11])
238+
for i in range(10):
239+
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
240+
241+
cropped_tensors = script_ten_crop(img_tensor, [10, 11])
242+
for i in range(10):
243+
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
266244

267245
def test_pad(self):
268246
script_fn = torch.jit.script(F_t.pad)

torchvision/transforms/functional_tensor.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
240240

241241

242242
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
243-
"""Crop the Image Tensor and resize it to desired size.
243+
"""DEPRECATED. Crop the Image Tensor and resize it to desired size.
244+
245+
.. warning::
246+
247+
This method is deprecated and will be removed in future releases.
248+
Please, use ``F.center_crop`` instead.
244249
245250
Args:
246251
img (Tensor): Image to be cropped.
@@ -250,6 +255,11 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
250255
Returns:
251256
Tensor: Cropped image.
252257
"""
258+
warnings.warn(
259+
"This method is deprecated and will be removed in future releases. "
260+
"Please, use ``F.center_crop`` instead."
261+
)
262+
253263
if not _is_tensor_a_torch_image(img):
254264
raise TypeError('tensor is not a torch image.')
255265

@@ -268,8 +278,15 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
268278

269279

270280
def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
271-
"""Crop the given Image Tensor into four corners and the central crop.
281+
"""DEPRECATED. Crop the given Image Tensor into four corners and the central crop.
282+
283+
.. warning::
284+
285+
This method is deprecated and will be removed in future releases.
286+
Please, use ``F.five_crop`` instead.
287+
272288
.. Note::
289+
273290
This transform returns a List of Tensors and there may be a
274291
mismatch in the number of inputs and targets your ``Dataset`` returns.
275292
@@ -283,6 +300,11 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
283300
List: List (tl, tr, bl, br, center)
284301
Corresponding top left, top right, bottom left, bottom right and center crop.
285302
"""
303+
warnings.warn(
304+
"This method is deprecated and will be removed in future releases. "
305+
"Please, use ``F.five_crop`` instead."
306+
)
307+
286308
if not _is_tensor_a_torch_image(img):
287309
raise TypeError('tensor is not a torch image.')
288310

@@ -304,10 +326,16 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
304326

305327

306328
def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
307-
"""Crop the given Image Tensor into four corners and the central crop plus the
329+
"""DEPRECATED. Crop the given Image Tensor into four corners and the central crop plus the
308330
flipped version of these (horizontal flipping is used by default).
309331
332+
.. warning::
333+
334+
This method is deprecated and will be removed in future releases.
335+
Please, use ``F.ten_crop`` instead.
336+
310337
.. Note::
338+
311339
This transform returns a List of images and there may be a
312340
mismatch in the number of inputs and targets your ``Dataset`` returns.
313341
@@ -323,6 +351,11 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
323351
Corresponding top left, top right, bottom left, bottom right and center crop
324352
and same for the flipped image's tensor.
325353
"""
354+
warnings.warn(
355+
"This method is deprecated and will be removed in future releases. "
356+
"Please, use ``F.ten_crop`` instead."
357+
)
358+
326359
if not _is_tensor_a_torch_image(img):
327360
raise TypeError('tensor is not a torch image.')
328361

0 commit comments

Comments
 (0)