Skip to content

Commit cc53cd0

Browse files
vfdev-5datumbox
andauthored
Fixed issue with padding on CI (#5875)
* Fixed issue with padding on CI * Disabled failing tests with color_jitter * Remove Jitter workaround Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 01b0a00 commit cc53cd0

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

test/test_transforms_tensor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,11 @@ def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=
5959
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
6060

6161

62-
def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
63-
# TODO: change the name: it's not a method, it's a class.
62+
def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
6463
meth_kwargs = meth_kwargs or {}
6564

6665
# test for class interface
67-
f = method(**meth_kwargs)
66+
f = transform_cls(**meth_kwargs)
6867
scripted_fn = torch.jit.script(f)
6968

7069
tensor, pil_img = _create_data(26, 34, channels, device=device)
@@ -86,7 +85,7 @@ def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_matc
8685
_test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
8786

8887
with get_tmp_dir() as tmp_dir:
89-
scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))
88+
scripted_fn.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt"))
9089

9190

9291
def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):

torchvision/transforms/functional_tensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,10 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
412412
need_cast = True
413413
img = img.to(torch.float32)
414414

415-
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
415+
if padding_mode in ("reflect", "replicate"):
416+
img = torch_pad(img, p, mode=padding_mode)
417+
else:
418+
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
416419

417420
if need_squeeze:
418421
img = img.squeeze(dim=0)

0 commit comments

Comments
 (0)