Skip to content

WIP Allowing for grouped (random) transformations of tensors/tensor-like objects [proof-of-concept] #4267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
return batch_tensor


assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0, check_stride=False)


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
Expand Down
99 changes: 77 additions & 22 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,23 +205,23 @@ def test_to_tensor(self, channels):
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data)
output = trans(img)
torch.testing.assert_close(output, input_data)
torch.testing.assert_close(output, input_data, check_stride=False)

ndarray = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)
torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False, check_stride=False)

ndarray = np_rng.rand(height, width, channels).astype(np.float32)
output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1))
torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)
torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False, check_stride=False)

# separate test for mode '1' PIL images
input_data = torch.ByteTensor(1, height, width).bernoulli_()
img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
output = trans(img)
torch.testing.assert_close(input_data, output, check_dtype=False)
torch.testing.assert_close(input_data, output, check_dtype=False, check_stride=False)

def test_to_tensor_errors(self):
height, width = 4, 4
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_pil_to_tensor(self, channels):
input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
img = transforms.ToPILImage()(input_data)
output = trans(img)
torch.testing.assert_close(input_data, output)
torch.testing.assert_close(input_data, output, check_stride=False)

input_data = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
img = transforms.ToPILImage()(input_data)
Expand All @@ -273,13 +273,13 @@ def test_pil_to_tensor(self, channels):
img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte()
output = trans(img) # HWC -> CHW
expected_output = (input_data * 255).byte()
torch.testing.assert_close(output, expected_output)
torch.testing.assert_close(output, expected_output, check_stride=False)

# separate test for mode '1' PIL images
input_data = torch.ByteTensor(1, height, width).bernoulli_()
img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
output = trans(img).view(torch.uint8).bool().to(torch.uint8)
torch.testing.assert_close(input_data, output)
torch.testing.assert_close(input_data, output, check_stride=False)

def test_pil_to_tensor_errors(self):
height, width = 4, 4
Expand Down Expand Up @@ -424,10 +424,10 @@ def test_pad(self):
h_padded = result[:, :padding, :]
w_padded = result[:, :, :padding]
torch.testing.assert_close(
h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps
h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps, check_stride=False,
)
torch.testing.assert_close(
w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps
w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps, check_stride=False,
)
pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)),
transforms.ToPILImage()(img))
Expand Down Expand Up @@ -528,9 +528,9 @@ def test_randomness(fn, trans, config, p):
num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
transformation = trans(p=p, **config)
transformation.__repr__()
out = transformation(img)
if out == inv_img:
counts += 1

Expand Down Expand Up @@ -583,7 +583,7 @@ def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_outpu

img = transform(img_data)
assert img.mode == expected_mode
torch.testing.assert_close(expected_output, to_tensor(img).numpy())
torch.testing.assert_close(expected_output, to_tensor(img).numpy(), check_stride=False)

def test_1_channel_float_tensor_to_pil_image(self):
img_data = torch.Tensor(1, 4, 4).uniform_()
Expand Down Expand Up @@ -621,7 +621,7 @@ def test_2_channel_ndarray_to_pil_image(self, expected_mode):
assert img.mode == expected_mode
split = img.split()
for i in range(2):
torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))
torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)

def test_2_channel_ndarray_to_pil_image_error(self):
img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
Expand Down Expand Up @@ -725,7 +725,7 @@ def test_3_channel_ndarray_to_pil_image(self, expected_mode):
assert img.mode == expected_mode
split = img.split()
for i in range(3):
torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))
torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)

def test_3_channel_ndarray_to_pil_image_error(self):
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
Expand Down Expand Up @@ -782,7 +782,7 @@ def test_4_channel_ndarray_to_pil_image(self, expected_mode):
assert img.mode == expected_mode
split = img.split()
for i in range(4):
torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))
torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)

def test_4_channel_ndarray_to_pil_image_error(self):
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
Expand Down Expand Up @@ -1532,7 +1532,7 @@ def test_random_crop():

t = transforms.RandomCrop(48)
img = torch.ones(3, 32, 32)
with pytest.raises(ValueError, match=r"Required crop size .+ is larger then input image size .+"):
with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"):
t(img)


Expand Down Expand Up @@ -1659,7 +1659,7 @@ def test_random_erasing():
img = torch.ones(3, 128, 128)

t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.))
y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ])
y, x, h, w, v = t.get_params_transform(img, t.scale, t.ratio, [t.value, ])
aspect_ratio = h / w
# Add some tolerance due to the rounding and int conversion used in the transform
tol = 0.05
Expand All @@ -1669,7 +1669,7 @@ def test_random_erasing():
random.seed(42)
trial = 1000
for _ in range(trial):
y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ])
y, x, h, w, v = t.get_params_transform(img, t.scale, t.ratio, [t.value, ])
aspect_ratios.append(h / w)

count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1])
Expand Down Expand Up @@ -1730,7 +1730,7 @@ def test_randomperspective():
to_pil_image = transforms.ToPILImage()
img = to_pil_image(img)
perp = transforms.RandomPerspective()
startpoints, endpoints = perp.get_params(width, height, 0.5)
startpoints, endpoints = perp.get_start_endpoints(width, height, 0.5)
tr_img = F.perspective(img, startpoints, endpoints)
tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
tr_img = F.to_tensor(tr_img)
Expand Down Expand Up @@ -1767,7 +1767,7 @@ def test_randomperspective_fill(mode):
pixel = (pixel,)
assert pixel == tuple([fill] * num_bands)

startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
startpoints, endpoints = transforms.RandomPerspective.get_start_endpoints(width, height, 0.5)
tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
pixel = tr_img.getpixel((0, 0))

Expand Down Expand Up @@ -2062,7 +2062,7 @@ def test_random_affine():

t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
for _ in range(100):
angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear,
angle, translations, scale, shear = t.get_params(img, t.degrees, t.translate, t.scale, t.shear,
img_size=img.size)
assert -10 < angle < 10
assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, ("{} vs {}"
Expand Down Expand Up @@ -2094,5 +2094,60 @@ def test_random_affine():
assert t.interpolation == transforms.InterpolationMode.BILINEAR


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@pytest.mark.parametrize('trans, config', [
(transforms.RandomInvert, {}),
(transforms.RandomPosterize, {"bits": 4}),
(transforms.RandomSolarize, {"threshold": 192}),
(transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
(transforms.RandomAutocontrast, {}),
(transforms.RandomEqualize, {})])
@pytest.mark.parametrize('p', (.5, .7))
def test_reset_randomness(trans, config, p):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))

num_samples = 250
counts = 0
for _ in range(num_samples):
transformation = trans(p=p, **config, reset_auto=False)
transformation.__repr__()
out1 = transformation(img)
assert out1 == transformation(img)
transformation.wipeout_()
out2 = transformation(img)
if out1 == out2:
counts += 1

p_repeat = p**2 + (1 - p)**2
p_value = stats.binom_test(counts, num_samples, p=p_repeat)
random.setstate(random_state)
assert p_value > 0.0001, f'got counts={counts} for num_samples={num_samples}'


@pytest.mark.parametrize('trans, config', [
(transforms.RandomCrop, {'size': 10}),
(transforms.RandomOrder, {"transforms":
[transforms.GaussianBlur(kernel_size=3, reset_auto=False),
transforms.RandomCrop(size=10, reset_auto=False)]}),
(transforms.RandomResizedCrop, {'size': 10}),
(transforms.ColorJitter, {}),
(transforms.RandomRotation, {'degrees': 120}),
(transforms.RandomAffine, {'degrees': 120, 'translate': (0.1, 0.1)}),
(transforms.RandomErasing, {}),
(transforms.GaussianBlur, {"kernel_size": 3})])
def test_grouptransform(trans, config):
num_samples = 250
for i in range(num_samples):
t = transforms.GroupTransform(trans(**config, reset_auto=False))
assert t.stochastic
img = torch.arange(1024, dtype=torch.float).view(1, 32, 32).expand(3, 32, 32).contiguous()
mask = img[:1]
imgs = (img, mask)
imgs_out = t(imgs)
torch.testing.assert_close(imgs_out[0][0], imgs_out[1][0], rtol=1e-6, atol=1e-6, check_stride=False)


if __name__ == '__main__':
pytest.main([__file__])
Loading