Skip to content

Commit e50c2e3

Browse files
authored
Improved docs and tests for (#2371)
- RandomCrop: crop with padding using all commonly supported modes
1 parent 4480603 commit e50c2e3

File tree

4 files changed

+35
-29
lines changed

4 files changed

+35
-29
lines changed

test/test_transforms_tensor.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,29 @@ def _test_functional_geom_op(self, func, fn_kwargs):
2626
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
2727
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
2828

29-
def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
30-
if fn_kwargs is None:
31-
fn_kwargs = {}
29+
def _test_class_geom_op(self, method, meth_kwargs=None):
3230
if meth_kwargs is None:
3331
meth_kwargs = {}
32+
3433
tensor, pil_img = self._create_data(height=10, width=10)
35-
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
36-
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
34+
# test for class interface
35+
f = getattr(T, method)(**meth_kwargs)
36+
scripted_fn = torch.jit.script(f)
37+
38+
# set seed to reproduce the same transformation for tensor and PIL image
39+
torch.manual_seed(12)
40+
transformed_tensor = f(tensor)
41+
torch.manual_seed(12)
42+
transformed_pil_img = f(pil_img)
3743
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
3844

39-
scripted_fn = torch.jit.script(getattr(F, func))
40-
transformed_tensor_script = scripted_fn(tensor, **fn_kwargs)
45+
torch.manual_seed(12)
46+
transformed_tensor_script = scripted_fn(tensor)
4147
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
4248

43-
# test for class interface
44-
f = getattr(T, method)(**meth_kwargs)
45-
scripted_fn = torch.jit.script(f)
46-
scripted_fn(tensor)
49+
def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
50+
self._test_functional_geom_op(func, fn_kwargs)
51+
self._test_class_geom_op(method, meth_kwargs)
4752

4853
def test_random_horizontal_flip(self):
4954
self._test_geom_op('hflip', 'RandomHorizontalFlip')
@@ -107,21 +112,20 @@ def test_crop(self):
107112
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
108113
)
109114

110-
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
111-
# Test torchscript of transforms.RandomCrop with size as int
112-
f = T.RandomCrop(size=5)
113-
scripted_fn = torch.jit.script(f)
114-
scripted_fn(tensor)
115-
116-
# Test torchscript of transforms.RandomCrop with size as [int, ]
117-
f = T.RandomCrop(size=[5, ], padding=[2, ])
118-
scripted_fn = torch.jit.script(f)
119-
scripted_fn(tensor)
120-
121-
# Test torchscript of transforms.RandomCrop with size as list
122-
f = T.RandomCrop(size=[6, 6])
123-
scripted_fn = torch.jit.script(f)
124-
scripted_fn(tensor)
115+
sizes = [5, [5, ], [6, 6]]
116+
padding_configs = [
117+
{"padding_mode": "constant", "fill": 0},
118+
{"padding_mode": "constant", "fill": 10},
119+
{"padding_mode": "constant", "fill": 20},
120+
{"padding_mode": "edge"},
121+
{"padding_mode": "reflect"},
122+
]
123+
124+
for size in sizes:
125+
for padding_config in padding_configs:
126+
config = dict(padding_config)
127+
config["size"] = size
128+
self._test_class_geom_op("RandomCrop", config)
125129

126130
def test_center_crop(self):
127131
fn_kwargs = {"output_size": (4, 5)}

torchvision/transforms/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
371371
length 3, it is used to fill R, G, B channels respectively.
372372
This value is only used when the padding_mode is constant. Only int value is supported for Tensors.
373373
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
374-
Only "constant" is supported for Tensors as of now.
374+
Mode symmetric is not yet supported for Tensor inputs.
375375
376376
- constant: pads with a constant value, this value is specified with fill
377377

torchvision/transforms/functional_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
368368
list of length 1: ``[padding, ]``.
369369
fill (int): Pixel fill value for constant fill. Default is 0.
370370
This value is only used when the padding_mode is constant
371-
padding_mode (str): Type of padding. Only "constant" is supported for Tensors as of now.
371+
padding_mode (str): Type of padding. Should be: constant, edge or reflect. Default is constant.
372+
Mode symmetric is not yet supported for Tensor inputs.
372373
373374
- constant: pads with a constant value, this value is specified with fill
374375

torchvision/transforms/transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class Pad(torch.nn.Module):
305305
length 3, it is used to fill R, G, B channels respectively.
306306
This value is only used when the padding_mode is constant
307307
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
308-
Default is constant. Only "constant" is supported for Tensors as of now.
308+
Default is constant. Mode symmetric is not yet supported for Tensor inputs.
309309
310310
- constant: pads with a constant value, this value is specified with fill
311311
@@ -469,6 +469,7 @@ class RandomCrop(torch.nn.Module):
469469
length 3, it is used to fill R, G, B channels respectively.
470470
This value is only used when the padding_mode is constant
471471
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
472+
Mode symmetric is not yet supported for Tensor inputs.
472473
473474
- constant: pads with a constant value, this value is specified with fill
474475

0 commit comments

Comments
 (0)