diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 09161d506de..981fff33b43 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -162,19 +162,36 @@ def __call__(self, img): return self.lambd(img) +class RandomCropGenerator(object): + def __init__(self): + self.x1 = 0 + self.y1 = 0 + + def generate(self): + self.x1 = random.random() + self.y1 = random.random() + + class RandomCrop(object): """Crops the given PIL.Image at a random location to have a region of the given size. size can be a tuple (target_height, target_width) or an integer, in which case the target will be of a square shape (size, size) """ - def __init__(self, size, padding=0): + def __init__(self, size, padding=0, generator=None): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size self.padding = padding + self._generator = generator + self._need_generation = False + if generator is None: + self._generator = RandomCropGenerator() + self._need_generation = True + + def __call__(self, img): if self.padding > 0: img = ImageOps.expand(img, border=self.padding, fill=0) @@ -184,17 +201,35 @@ def __call__(self, img): if w == tw and h == th: return img - x1 = random.randint(0, w - tw) - y1 = random.randint(0, h - th) + if self._need_generation: + self._generator.generate() + + x1 = math.floor(self._generator.x1 * (w - tw)) + y1 = math.floor(self._generator.y1 * (h - th)) + return img.crop((x1, y1, x1 + tw, y1 + th)) +class RandomFlipGenerator(object): + def generate(self): + self.flip = random.random() < 0.5 + + class RandomHorizontalFlip(object): """Randomly horizontally flips the given PIL.Image with a probability of 0.5 """ + def __init__(self, generator=None): + self._generator = generator + self._need_generation = False + if generator is None: + self._generator = RandomFlipGenerator() + self._need_generation = True def __call__(self, img): - if random.random() < 0.5: + if self._need_generation: + self._generator.generate() + + if self._generator.flip: return img.transpose(Image.FLIP_LEFT_RIGHT) return img