
Commit cec7ea7

PyExtreme authored and fmassa committed
Add scriptable transform: center_crop, five crop and ten_crop (#1615)
* add scriptable transform: center_crop
* add test: center_crop
* add scriptable transform: five_crop
* add scriptable transform: five_crop
* add scriptable transform: fix minor issues
1 parent e3a1305 commit cec7ea7
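
"Scriptable" here means the new transforms are written in pure tensor operations, so they can run on GPU tensors and, in principle, be compiled with TorchScript instead of requiring a PIL round-trip. A minimal sketch of the pattern (the `crop_chw` helper below is a hypothetical illustration, not code from this commit):

    import torch

    # Hypothetical helper: a tensor-only crop that TorchScript can compile.
    # The transforms added in this commit are built from the same kind of ops.
    @torch.jit.script
    def crop_chw(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
        # Plain slicing on a CHW tensor; no PIL involved, so it works when scripted.
        return img[:, top:top + height, left:left + width]

    img = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
    print(crop_chw(img, 4, 4, 10, 10).shape)  # torch.Size([3, 10, 10])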

File tree

2 files changed: +138 -0 lines changed


test/test_functional_tensor.py

Lines changed: 47 additions & 0 deletions
@@ -76,6 +76,53 @@ def test_rgb_to_grayscale(self):
         max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
         self.assertLess(max_diff, 1.0001)
 
+    def test_center_crop(self):
+        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
+        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
+        cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
+        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
+        self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
+
+    def test_five_crop(self):
+        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
+        cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
+        cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
+        self.assertTrue(torch.equal(cropped_tensor[0],
+                        (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[1],
+                        (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[2],
+                        (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[3],
+                        (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[4],
+                        (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
+
+    def test_ten_crop(self):
+        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
+        cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
+        cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
+        self.assertTrue(torch.equal(cropped_tensor[0],
+                        (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[1],
+                        (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[2],
+                        (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[3],
+                        (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[4],
+                        (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[5],
+                        (transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[6],
+                        (transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[7],
+                        (transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[8],
+                        (transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
+        self.assertTrue(torch.equal(cropped_tensor[9],
+                        (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
+
 
 if __name__ == '__main__':
     unittest.main()
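
Each test validates the tensor implementation against the existing PIL-backed path: `transforms.ToTensor()` returns float32 values in [0, 1], so the PIL result is rescaled by 255 and cast back to `uint8` before the exact comparison. Note that the five/ten-crop tests pair tensor crops 1 and 2 with PIL crops 2 and 1: for the square inputs used here, the tensor implementation returns the top-right and bottom-left corner crops in the reverse of the PIL ordering (tl, tr, bl, br, center), so the tests match crops by content rather than by name. A standalone sketch of the round-trip used to build the reference values (standard torchvision imports assumed):

    import torch
    from torchvision import transforms
    from torchvision.transforms import functional as F

    img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
    pil_img = transforms.ToPILImage()(img_tensor)      # uint8 CHW tensor -> PIL image
    cropped_pil = F.center_crop(pil_img, [10, 10])     # PIL-backed reference crop
    # ToTensor yields float32 in [0, 1]; rescale and cast to compare against uint8.
    reference = (transforms.ToTensor()(cropped_pil) * 255).to(torch.uint8)
    print(reference.shape, reference.dtype)            # torch.Size([1, 10, 10]) torch.uint8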

torchvision/transforms/functional_tensor.py

Lines changed: 91 additions & 0 deletions
@@ -125,6 +125,97 @@ def adjust_saturation(img, saturation_factor):
     return _blend(img, rgb_to_grayscale(img), saturation_factor)
 
 
+def center_crop(img, output_size):
+    """Crop the Image Tensor and resize it to desired size.
+
+    Args:
+        img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
+        output_size (sequence or int): (height, width) of the crop box. If int,
+            it is used for both directions
+
+    Returns:
+        Tensor: Cropped image.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    _, image_width, image_height = img.size()
+    crop_height, crop_width = output_size
+    crop_top = int(round((image_height - crop_height) / 2.))
+    crop_left = int(round((image_width - crop_width) / 2.))
+
+    return crop(img, crop_top, crop_left, crop_height, crop_width)
+
+
+def five_crop(img, size):
+    """Crop the given Image Tensor into four corners and the central crop.
+    .. Note::
+        This transform returns a tuple of Tensors and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center)
+            Corresponding top left, top right, bottom left, bottom right and center crop.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    _, image_width, image_height = img.size()
+    crop_height, crop_width = size
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop(img, 0, 0, crop_width, crop_height)
+    tr = crop(img, image_width - crop_width, 0, image_width, crop_height)
+    bl = crop(img, 0, image_height - crop_height, crop_width, image_height)
+    br = crop(img, image_width - crop_width, image_height - crop_height, image_width, image_height)
+    center = center_crop(img, (crop_height, crop_width))
+
+    return (tl, tr, bl, br, center)
+
+
+def ten_crop(img, size, vertical_flip=False):
+    """Crop the given Image Tensor into four corners and the central crop plus the
+       flipped version of these (horizontal flipping is used by default).
+    .. Note::
+        This transform returns a tuple of images and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal
+
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+            Corresponding top left, top right, bottom left, bottom right and center crop
+            and same for the flipped image's tensor.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+    first_five = five_crop(img, size)
+
+    if vertical_flip:
+        img = vflip(img)
+    else:
+        img = hflip(img)
+
+    second_five = five_crop(img, size)
+
+    return first_five + second_five
+
+
 def _blend(img1, img2, ratio):
     bound = 1 if img1.dtype.is_floating_point else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
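
A brief usage sketch of the new tensor-path functions (imported as `F_t`, matching the tests above); `ten_crop` returns a tuple of crops that can be stacked into a batch, the usual test-time-augmentation pattern:

    import torch
    from torchvision.transforms import functional_tensor as F_t

    img = torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8)

    center = F_t.center_crop(img, [32, 32])     # single (3, 32, 32) center crop
    crops = F_t.ten_crop(img, [32, 32])         # tuple of ten (3, 32, 32) crops
    batch = torch.stack(crops).float() / 255.0  # (10, 3, 32, 32) batch, e.g. for averaging model outputs
    print(center.shape, batch.shape)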
