diff --git a/test/test_transforms.py b/test/test_transforms.py
index 9319fb0664a..cb458f6e4fe 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -61,6 +61,66 @@ def test_crop(self):
         assert sum2 > sum1, "height: " + str(height) + " width: " \
             + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
 
+    def test_five_crop(self):
+        to_pil_image = transforms.ToPILImage()
+        h = random.randint(5, 25)
+        w = random.randint(5, 25)
+        for single_dim in [True, False]:
+            crop_h = random.randint(1, h)
+            crop_w = random.randint(1, w)
+            if single_dim:
+                crop_h = min(crop_h, crop_w)
+                crop_w = crop_h
+                transform = transforms.FiveCrop(crop_h)
+            else:
+                transform = transforms.FiveCrop((crop_h, crop_w))
+
+            img = torch.FloatTensor(3, h, w).uniform_()
+            results = transform(to_pil_image(img))
+
+            assert len(results) == 5
+            for crop in results:
+                assert crop.size == (crop_w, crop_h)
+
+            to_pil_image = transforms.ToPILImage()
+            tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
+            tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
+            bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
+            br = to_pil_image(img[:, h - crop_h:, w - crop_w:])
+            center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
+            expected_output = (tl, tr, bl, br, center)
+            assert results == expected_output
+
+    def test_ten_crop(self):
+        to_pil_image = transforms.ToPILImage()
+        h = random.randint(5, 25)
+        w = random.randint(5, 25)
+        for should_vflip in [True, False]:
+            for single_dim in [True, False]:
+                crop_h = random.randint(1, h)
+                crop_w = random.randint(1, w)
+                if single_dim:
+                    crop_h = min(crop_h, crop_w)
+                    crop_w = crop_h
+                    transform = transforms.TenCrop(crop_h, vflip=should_vflip)
+                    five_crop = transforms.FiveCrop(crop_h)
+                else:
+                    transform = transforms.TenCrop((crop_h, crop_w), vflip=should_vflip)
+                    five_crop = transforms.FiveCrop((crop_h, crop_w))
+
+                img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
+                results = transform(img)
+                expected_output = five_crop(img)
+                if should_vflip:
+                    vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
+                    expected_output += five_crop(vflipped_img)
+                else:
+                    hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
+                    expected_output += five_crop(hflipped_img)
+
+                assert len(results) == 10
+                assert expected_output == results
+
     def test_scale(self):
         height = random.randint(24, 32) * 2
         width = random.randint(24, 32) * 2
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index c8a2911e74e..27fe4e2d1bc 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -638,3 +638,72 @@ def __call__(self, img):
         """
         i, j, h, w = self.get_params(img)
         return scaled_crop(img, i, j, h, w, self.size, self.interpolation)
+
+
+class FiveCrop(object):
+    """Crop the given PIL.Image into four corners and the central crop.
+
+    Note: this transform returns a tuple of images, so there may be a mismatch
+    between the number of inputs and targets your `Dataset` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+
+    def __call__(self, img):
+        w, h = img.size
+        crop_h, crop_w = self.size
+        if crop_w > w or crop_h > h:
+            raise ValueError("Requested crop size {} is bigger than input size {}".format(self.size,
+                                                                                          (h, w)))
+        tl = img.crop((0, 0, crop_w, crop_h))
+        tr = img.crop((w - crop_w, 0, w, crop_h))
+        bl = img.crop((0, h - crop_h, crop_w, h))
+        br = img.crop((w - crop_w, h - crop_h, w, h))
+        center = CenterCrop((crop_h, crop_w))(img)
+        return (tl, tr, bl, br, center)
+
+
+class TenCrop(object):
+    """Crop the given PIL.Image into four corners and the central crop plus the
+    flipped versions of these (horizontal flipping is used by default).
+
+    Note: this transform returns a tuple of images, so there may be a mismatch
+    between the number of inputs and targets your `Dataset` returns.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vflip (bool): Use vertical flipping instead of horizontal flipping.
+    """
+
+    def __init__(self, size, vflip=False):
+        self.size = size
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            self.size = size
+        self.vflip = vflip
+
+    def __call__(self, img):
+        five_crop = FiveCrop(self.size)
+        first_five = five_crop(img)
+        if self.vflip:
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+        else:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+
+        second_five = five_crop(img)
+        return first_five + second_five
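
Note on usage: because FiveCrop and TenCrop return a tuple of PIL.Images rather than a single image, callers need a batching step before the output can feed a model. The sketch below shows one way to do that, assuming the existing transforms.Compose, transforms.Lambda, and transforms.ToTensor utilities; `pil_image` and `model` are hypothetical placeholders, and the averaging step is illustrative rather than part of this patch.

import torch
import torchvision.transforms as transforms

# Stack the ten crops into a single 4D tensor of shape (10, C, H, W)
# so they can be passed through a network as one batch.
transform = transforms.Compose([
    transforms.TenCrop(224),
    transforms.Lambda(
        lambda crops: torch.stack([transforms.ToTensor()(c) for c in crops])),
])

crops = transform(pil_image)   # 10 x 3 x 224 x 224 (pil_image is a placeholder)
output = model(crops)          # 10 x num_classes, one prediction per crop
avg = output.mean(0)           # average the predictions over the ten crops

The same pattern applies to FiveCrop, with a first dimension of 5 instead of 10.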