Skip to content

Commit 5985869

Browse files
vfdev-5fmassa
authored andcommitted
Probability parameter in RandomHorizontalFlip, RandomHorizontalFlip (#417)
* Set probability as configuration parameter in RandomHorizontalFlip and RandomHorizontalFlip (#414) * Fix documentation
1 parent 22385bc commit 5985869

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

test/test_transforms.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,17 @@ def test_random_vertical_flip(self):
475475
random.setstate(random_state)
476476
assert p_value > 0.0001
477477

478+
num_samples = 250
479+
num_vertical = 0
480+
for _ in range(num_samples):
481+
out = transforms.RandomVerticalFlip(p=0.7)(img)
482+
if out == vimg:
483+
num_vertical += 1
484+
485+
p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
486+
random.setstate(random_state)
487+
assert p_value > 0.0001
488+
478489
# Checking if RandomVerticalFlip can be printed as string
479490
transforms.RandomVerticalFlip().__repr__()
480491

@@ -496,6 +507,17 @@ def test_random_horizontal_flip(self):
496507
random.setstate(random_state)
497508
assert p_value > 0.0001
498509

510+
num_samples = 250
511+
num_horizontal = 0
512+
for _ in range(num_samples):
513+
out = transforms.RandomHorizontalFlip(p=0.7)(img)
514+
if out == himg:
515+
num_horizontal += 1
516+
517+
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
518+
random.setstate(random_state)
519+
assert p_value > 0.0001
520+
499521
# Checking if RandomHorizontalFlip can be printed as string
500522
transforms.RandomHorizontalFlip().__repr__()
501523

torchvision/transforms/transforms.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,14 @@ def __repr__(self):
321321

322322

323323
class RandomHorizontalFlip(object):
324-
"""Horizontally flip the given PIL Image randomly with a probability of 0.5."""
324+
"""Horizontally flip the given PIL Image randomly with a given probability.
325+
326+
Args:
327+
p (float): probability of the image being flipped. Default value is 0.5
328+
"""
329+
330+
def __init__(self, p=0.5):
331+
self.p = p
325332

326333
def __call__(self, img):
327334
"""
@@ -331,16 +338,23 @@ def __call__(self, img):
331338
Returns:
332339
PIL Image: Randomly flipped image.
333340
"""
334-
if random.random() < 0.5:
341+
if random.random() < self.p:
335342
return F.hflip(img)
336343
return img
337344

338345
def __repr__(self):
339-
return self.__class__.__name__ + '()'
346+
return self.__class__.__name__ + '(p={})'.format(self.p)
340347

341348

342349
class RandomVerticalFlip(object):
343-
"""Vertically flip the given PIL Image randomly with a probability of 0.5."""
350+
"""Vertically flip the given PIL Image randomly with a given probability.
351+
352+
Args:
353+
p (float): probability of the image being flipped. Default value is 0.5
354+
"""
355+
356+
def __init__(self, p=0.5):
357+
self.p = p
344358

345359
def __call__(self, img):
346360
"""
@@ -350,12 +364,12 @@ def __call__(self, img):
350364
Returns:
351365
PIL Image: Randomly flipped image.
352366
"""
353-
if random.random() < 0.5:
367+
if random.random() < self.p:
354368
return F.vflip(img)
355369
return img
356370

357371
def __repr__(self):
358-
return self.__class__.__name__ + '()'
372+
return self.__class__.__name__ + '(p={})'.format(self.p)
359373

360374

361375
class RandomResizedCrop(object):

0 commit comments

Comments
 (0)