Skip to content

Commit a99b6bd

Browse files
authored
Unified Tensor/PIL crop (#2342)
* [WIP] Unified Tensor/PIL crop * Fixed misplaced type annotation * Fixed tests - crop with padding - other tests using mising private functions: _is_pil_image, _get_image_size * Unified CenterCrop and F.center_crop - sorted includes in transforms.py - used py3 annotations * Unified FiveCrop and F.five_crop * Improved tests and docs * Unified TenCrop and F.ten_crop * Removed useless typing in functional_pil * Updated code according to the review - removed useless torch.jit.export - added missing typing return type - fixed F.F_pil._is_pil_image -> F._is_pil_image * Removed useless torch.jit.export * Improved code according to the review
1 parent 446eac6 commit a99b6bd

File tree

5 files changed

+398
-169
lines changed

5 files changed

+398
-169
lines changed

test/test_transforms_tensor.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,120 @@ def test_pad(self):
9999
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
100100
)
101101

102+
def test_crop(self):
103+
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
104+
# Test transforms.RandomCrop with size and padding as tuple
105+
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
106+
self._test_geom_op(
107+
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
108+
)
109+
110+
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
111+
# Test torchscript of transforms.RandomCrop with size as int
112+
f = T.RandomCrop(size=5)
113+
scripted_fn = torch.jit.script(f)
114+
scripted_fn(tensor)
115+
116+
# Test torchscript of transforms.RandomCrop with size as [int, ]
117+
f = T.RandomCrop(size=[5, ], padding=[2, ])
118+
scripted_fn = torch.jit.script(f)
119+
scripted_fn(tensor)
120+
121+
# Test torchscript of transforms.RandomCrop with size as list
122+
f = T.RandomCrop(size=[6, 6])
123+
scripted_fn = torch.jit.script(f)
124+
scripted_fn(tensor)
125+
126+
def test_center_crop(self):
127+
fn_kwargs = {"output_size": (4, 5)}
128+
meth_kwargs = {"size": (4, 5), }
129+
self._test_geom_op(
130+
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
131+
)
132+
fn_kwargs = {"output_size": (5,)}
133+
meth_kwargs = {"size": (5, )}
134+
self._test_geom_op(
135+
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
136+
)
137+
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
138+
# Test torchscript of transforms.CenterCrop with size as int
139+
f = T.CenterCrop(size=5)
140+
scripted_fn = torch.jit.script(f)
141+
scripted_fn(tensor)
142+
143+
# Test torchscript of transforms.CenterCrop with size as [int, ]
144+
f = T.CenterCrop(size=[5, ])
145+
scripted_fn = torch.jit.script(f)
146+
scripted_fn(tensor)
147+
148+
# Test torchscript of transforms.CenterCrop with size as tuple
149+
f = T.CenterCrop(size=(6, 6))
150+
scripted_fn = torch.jit.script(f)
151+
scripted_fn(tensor)
152+
153+
def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
154+
if fn_kwargs is None:
155+
fn_kwargs = {}
156+
if meth_kwargs is None:
157+
meth_kwargs = {}
158+
tensor, pil_img = self._create_data(height=20, width=20)
159+
transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
160+
transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
161+
self.assertEqual(len(transformed_t_list), len(transformed_p_list))
162+
self.assertEqual(len(transformed_t_list), out_length)
163+
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
164+
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
165+
166+
scripted_fn = torch.jit.script(getattr(F, func))
167+
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
168+
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
169+
self.assertEqual(len(transformed_t_list_script), out_length)
170+
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
171+
self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
172+
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
173+
174+
# test for class interface
175+
f = getattr(T, method)(**meth_kwargs)
176+
scripted_fn = torch.jit.script(f)
177+
output = scripted_fn(tensor)
178+
self.assertEqual(len(output), len(transformed_t_list_script))
179+
180+
def test_five_crop(self):
181+
fn_kwargs = meth_kwargs = {"size": (5,)}
182+
self._test_geom_op_list_output(
183+
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
184+
)
185+
fn_kwargs = meth_kwargs = {"size": [5, ]}
186+
self._test_geom_op_list_output(
187+
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
188+
)
189+
fn_kwargs = meth_kwargs = {"size": (4, 5)}
190+
self._test_geom_op_list_output(
191+
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
192+
)
193+
fn_kwargs = meth_kwargs = {"size": [4, 5]}
194+
self._test_geom_op_list_output(
195+
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
196+
)
197+
198+
def test_ten_crop(self):
199+
fn_kwargs = meth_kwargs = {"size": (5,)}
200+
self._test_geom_op_list_output(
201+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
202+
)
203+
fn_kwargs = meth_kwargs = {"size": [5, ]}
204+
self._test_geom_op_list_output(
205+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
206+
)
207+
fn_kwargs = meth_kwargs = {"size": (4, 5)}
208+
self._test_geom_op_list_output(
209+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
210+
)
211+
fn_kwargs = meth_kwargs = {"size": [4, 5]}
212+
self._test_geom_op_list_output(
213+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
214+
)
215+
102216

103217
if __name__ == '__main__':
104218
unittest.main()

0 commit comments

Comments
 (0)