Note: line breaks and indentation in this patch were reconstructed from a
whitespace-collapsed copy. If a hunk is rejected because of whitespace, retry
with `git apply --ignore-whitespace` (TODO: confirm against the target tree).

diff --git a/test/test_transforms.py b/test/test_transforms.py
index cb458f6e4fe..a53f14d4d92 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -102,10 +102,12 @@ def test_ten_crop(self):
                 if single_dim:
                     crop_h = min(crop_h, crop_w)
                     crop_w = crop_h
-                    transform = transforms.TenCrop(crop_h, vflip=should_vflip)
+                    transform = transforms.TenCrop(crop_h,
+                                                   vertical_flip=should_vflip)
                     five_crop = transforms.FiveCrop(crop_h)
                 else:
-                    transform = transforms.TenCrop((crop_h, crop_w), vflip=should_vflip)
+                    transform = transforms.TenCrop((crop_h, crop_w),
+                                                   vertical_flip=should_vflip)
                     five_crop = transforms.FiveCrop((crop_h, crop_w))
 
                 img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index 27fe4e2d1bc..deff17115d2 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -281,6 +281,75 @@ def vflip(img):
     return img.transpose(Image.FLIP_TOP_BOTTOM)
 
 
+def five_crop(img, size):
+    """Crop the given PIL.Image into four corners and the central crop.
+
+    Note: this transform returns a tuple of images and there may be a mismatch in the number of
+    inputs and targets your `Dataset` returns.
+
+    Args:
+        img (PIL.Image): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center) corresponding top left,
+            top right, bottom left, bottom right and center crop.
+    """
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    else:
+        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    w, h = img.size
+    crop_h, crop_w = size
+    if crop_w > w or crop_h > h:
+        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
+                                                                                      (h, w)))
+    tl = img.crop((0, 0, crop_w, crop_h))
+    tr = img.crop((w - crop_w, 0, w, crop_h))
+    bl = img.crop((0, h - crop_h, crop_w, h))
+    br = img.crop((w - crop_w, h - crop_h, w, h))
+    center = CenterCrop((crop_h, crop_w))(img)
+    return (tl, tr, bl, br, center)
+
+
+def ten_crop(img, size, vertical_flip=False):
+    """Crop the given PIL.Image into four corners and the central crop plus the
+    flipped version of these (horizontal flipping is used by default).
+
+    Note: this transform returns a tuple of images and there may be a mismatch in the number of
+    inputs and targets your `Dataset` returns.
+
+    Args:
+        img (PIL.Image): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal
+
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
+            br_flip, center_flip) corresponding top left, top right,
+            bottom left, bottom right and center crop and same for the
+            flipped image.
+    """
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    else:
+        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    first_five = five_crop(img, size)
+
+    if vertical_flip:
+        img = vflip(img)
+    else:
+        img = hflip(img)
+
+    second_five = five_crop(img, size)
+    return first_five + second_five
+
+
 class Compose(object):
     """Composes several transforms together.
 
@@ -661,17 +730,7 @@ def __init__(self, size):
         self.size = size
 
     def __call__(self, img):
-        w, h = img.size
-        crop_h, crop_w = self.size
-        if crop_w > w or crop_h > h:
-            raise ValueError("Requested crop size {} is bigger than input size {}".format(self.size,
-                                                                                          (h, w)))
-        tl = img.crop((0, 0, crop_w, crop_h))
-        tr = img.crop((w - crop_w, 0, w, crop_h))
-        bl = img.crop((0, h - crop_h, crop_w, h))
-        br = img.crop((w - crop_w, h - crop_h, w, h))
-        center = CenterCrop((crop_h, crop_w))(img)
-        return (tl, tr, bl, br, center)
+        return five_crop(img, self.size)
 
 
 class TenCrop(object):
@@ -685,25 +744,17 @@ class TenCrop(object):
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
             made.
-        vflip bool: Use vertical flipping instead of horizontal
+        vertical_flip (bool): Use vertical flipping instead of horizontal
     """
 
-    def __init__(self, size, vflip=False):
+    def __init__(self, size, vertical_flip=False):
         self.size = size
         if isinstance(size, numbers.Number):
             self.size = (int(size), int(size))
         else:
             assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
             self.size = size
-        self.vflip = vflip
+        self.vertical_flip = vertical_flip
 
     def __call__(self, img):
-        five_crop = FiveCrop(self.size)
-        first_five = five_crop(img)
-        if self.vflip:
-            img = img.transpose(Image.FLIP_TOP_BOTTOM)
-        else:
-            img = img.transpose(Image.FLIP_LEFT_RIGHT)
-
-        second_five = five_crop(img)
-        return first_five + second_five
+        return ten_crop(img, self.size, self.vertical_flip)