Skip to content
114 changes: 114 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,120 @@ def test_pad(self):
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_geom_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
# Test torchscript of transforms.RandomCrop with size as int
f = T.RandomCrop(size=5)
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

# Test torchscript of transforms.RandomCrop with size as [int, ]
f = T.RandomCrop(size=[5, ], padding=[2, ])
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

# Test torchscript of transforms.RandomCrop with size as list
f = T.RandomCrop(size=[6, 6])
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
self._test_geom_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
self._test_geom_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
# Test torchscript of transforms.CenterCrop with size as int
f = T.CenterCrop(size=5)
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

# Test torchscript of transforms.CenterCrop with size as [int, ]
f = T.CenterCrop(size=[5, ])
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

# Test torchscript of transforms.CenterCrop with size as tuple
f = T.CenterCrop(size=(6, 6))
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
meth_kwargs = {}
tensor, pil_img = self._create_data(height=20, width=20)
transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_p_list))
self.assertEqual(len(transformed_t_list), out_length)
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

scripted_fn = torch.jit.script(getattr(F, func))
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length)
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))

# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
output = scripted_fn(tensor)
self.assertEqual(len(output), len(transformed_t_list_script))

def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_ten_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)


if __name__ == '__main__':
unittest.main()
Loading