Skip to content

Commit 695605e

Browse files
committed
Added ref tests for shear X/Y
1 parent 97eddc5 commit 695605e

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

test/test_transforms_tensor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,40 @@ def test_autoaugment_save(augmentation, tmpdir):
725725
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
726726

727727

728+
def test_autoaugment__op_apply_shear():
729+
# We check that torchvision's implementation of shear is equivalent
730+
# to official one:
731+
# https://github.com/google-research/augmix/blob/master/augmentations.py#L81-L114
732+
from PIL import Image
733+
734+
image_size = 10
735+
736+
def shear(pil_img, level, mode="X"):
737+
if mode == "X":
738+
matrix = (1, level, 0, 0, 1, 0)
739+
elif mode == "Y":
740+
matrix = (1, 0, 0, level, 1, 0)
741+
742+
return pil_img.transform(
743+
(image_size, image_size),
744+
Image.AFFINE,
745+
matrix,
746+
resample=Image.NEAREST
747+
)
748+
749+
from torchvision.transforms.autoaugment import _apply_op
750+
751+
t_img, pil_img = _create_data(image_size, image_size)
752+
753+
level = 0.24
754+
for mode in ["X", "Y"]:
755+
expected_out = shear(pil_img, level, mode=mode)
756+
out = _apply_op(
757+
t_img, op_name=f"Shear{mode}", magnitude=level, interpolation=F.InterpolationMode.NEAREST, fill=0
758+
)
759+
_assert_approx_equal_tensor_to_pil(out, expected_out)
760+
761+
728762
@pytest.mark.parametrize("device", cpu_and_gpu())
729763
@pytest.mark.parametrize(
730764
"config",

0 commit comments

Comments
 (0)