Skip to content

Commit 7500373

Browse files
vfdev-5fmassa
authored andcommitted
Add RandomApply, RandomChoice, RandomOrder transformations (#402)
* Add RandomApply, RandomChoice, RandomOrder transformations * Rename argument `proba` to `p`
1 parent 5985869 commit 7500373

File tree

2 files changed

+159
-3
lines changed

2 files changed

+159
-3
lines changed

test/test_transforms.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,91 @@ def test_lambda(self):
258258
# Checking if Lambda can be printed as string
259259
trans.__repr__()
260260

261+
def test_random_apply(self):
262+
random_state = random.getstate()
263+
random.seed(42)
264+
random_apply_transform = transforms.RandomApply(
265+
[
266+
transforms.RandomRotation((-45, 45)),
267+
transforms.RandomHorizontalFlip(),
268+
transforms.RandomVerticalFlip(),
269+
], p=0.75
270+
)
271+
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
272+
num_samples = 250
273+
num_applies = 0
274+
for _ in range(num_samples):
275+
out = random_apply_transform(img)
276+
if out != img:
277+
num_applies += 1
278+
279+
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
280+
random.setstate(random_state)
281+
assert p_value > 0.0001
282+
283+
# Checking if RandomApply can be printed as string
284+
random_apply_transform.__repr__()
285+
286+
def test_random_choice(self):
287+
random_state = random.getstate()
288+
random.seed(42)
289+
random_choice_transform = transforms.RandomChoice(
290+
[
291+
transforms.Resize(15),
292+
transforms.Resize(20),
293+
transforms.CenterCrop(10)
294+
]
295+
)
296+
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
297+
num_samples = 250
298+
num_resize_15 = 0
299+
num_resize_20 = 0
300+
num_crop_10 = 0
301+
for _ in range(num_samples):
302+
out = random_choice_transform(img)
303+
if out.size == (15, 15):
304+
num_resize_15 += 1
305+
elif out.size == (20, 20):
306+
num_resize_20 += 1
307+
elif out.size == (10, 10):
308+
num_crop_10 += 1
309+
310+
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
311+
assert p_value > 0.0001
312+
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
313+
assert p_value > 0.0001
314+
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
315+
assert p_value > 0.0001
316+
317+
random.setstate(random_state)
318+
# Checking if RandomChoice can be printed as string
319+
random_choice_transform.__repr__()
320+
321+
def test_random_order(self):
322+
random_state = random.getstate()
323+
random.seed(42)
324+
random_order_transform = transforms.RandomOrder(
325+
[
326+
transforms.Resize(20),
327+
transforms.CenterCrop(10)
328+
]
329+
)
330+
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
331+
num_samples = 250
332+
num_normal_order = 0
333+
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
334+
for _ in range(num_samples):
335+
out = random_order_transform(img)
336+
if out == resize_crop_out:
337+
num_normal_order += 1
338+
339+
p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
340+
random.setstate(random_state)
341+
assert p_value > 0.0001
342+
343+
# Checking if RandomOrder can be printed as string
344+
random_order_transform.__repr__()
345+
261346
def test_to_tensor(self):
262347
test_channels = [1, 3, 4]
263348
height, width = 4, 4

torchvision/transforms/transforms.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from . import functional as F
1717

1818
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
19-
"Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
20-
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation",
21-
"Grayscale", "RandomGrayscale"]
19+
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
20+
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
21+
"ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]
2222

2323

2424
class Compose(object):
@@ -261,6 +261,77 @@ def __repr__(self):
261261
return self.__class__.__name__ + '()'
262262

263263

264+
class RandomTransforms(object):
265+
"""Base class for a list of transformations with randomness
266+
267+
Args:
268+
transforms (list or tuple): list of transformations
269+
"""
270+
271+
def __init__(self, transforms):
272+
assert isinstance(transforms, (list, tuple))
273+
self.transforms = transforms
274+
275+
def __call__(self, *args, **kwargs):
276+
raise NotImplementedError()
277+
278+
def __repr__(self):
279+
format_string = self.__class__.__name__ + '('
280+
for t in self.transforms:
281+
format_string += '\n'
282+
format_string += ' {0}'.format(t)
283+
format_string += '\n)'
284+
return format_string
285+
286+
287+
class RandomApply(RandomTransforms):
288+
"""Apply randomly a list of transformations with a given probability
289+
290+
Args:
291+
transforms (list or tuple): list of transformations
292+
p (float): probability
293+
"""
294+
295+
def __init__(self, transforms, p=0.5):
296+
super(RandomApply, self).__init__(transforms)
297+
self.p = p
298+
299+
def __call__(self, img):
300+
if self.p < random.random():
301+
return img
302+
for t in self.transforms:
303+
img = t(img)
304+
return img
305+
306+
def __repr__(self):
307+
format_string = self.__class__.__name__ + '('
308+
format_string += '\n p={}'.format(self.p)
309+
for t in self.transforms:
310+
format_string += '\n'
311+
format_string += ' {0}'.format(t)
312+
format_string += '\n)'
313+
return format_string
314+
315+
316+
class RandomOrder(RandomTransforms):
317+
"""Apply a list of transformations in a random order
318+
"""
319+
def __call__(self, img):
320+
order = list(range(len(self.transforms)))
321+
random.shuffle(order)
322+
for i in order:
323+
img = self.transforms[i](img)
324+
return img
325+
326+
327+
class RandomChoice(RandomTransforms):
328+
"""Apply single transformation randomly picked from a list
329+
"""
330+
def __call__(self, img):
331+
t = random.choice(self.transforms)
332+
return t(img)
333+
334+
264335
class RandomCrop(object):
265336
"""Crop the given PIL Image at a random location.
266337

0 commit comments

Comments
 (0)