diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 165c23dbdb8..8cebe666b50 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -59,12 +59,11 @@ def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match= _assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs) -def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs): - # TODO: change the name: it's not a method, it's a class. +def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs): meth_kwargs = meth_kwargs or {} # test for class interface - f = method(**meth_kwargs) + f = transform_cls(**meth_kwargs) scripted_fn = torch.jit.script(f) tensor, pil_img = _create_data(26, 34, channels, device=device) @@ -86,7 +85,7 @@ def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_matc _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors) with get_tmp_dir() as tmp_dir: - scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt")) + scripted_fn.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt")) def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs): diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index cf75034ee6c..c43e52db852 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -412,7 +412,10 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con need_cast = True img = img.to(torch.float32) - img = torch_pad(img, p, mode=padding_mode, value=float(fill)) + if padding_mode in ("reflect", "replicate"): + img = torch_pad(img, p, mode=padding_mode) + else: + img = torch_pad(img, p, mode=padding_mode, value=float(fill)) if need_squeeze: img = img.squeeze(dim=0)