diff --git a/test/test_transforms.py b/test/test_transforms.py
index 2968282cfc1..35ee2329590 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -23,48 +23,45 @@
     os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
 
 
+def transform_helper(t, is_pil=True):
+    t = [t]
+    if is_pil:
+        t.insert(0, transforms.ToPILImage())
+        t.append(transforms.ToTensor())
+    return transforms.Compose(t)
+
+
 class Tester(unittest.TestCase):
 
     def test_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
-        oheight = random.randint(5, (height - 2) / 2) * 2
-        owidth = random.randint(5, (width - 2) / 2) * 2
-        img = torch.ones(3, height, width)
-        oh1 = (height - oheight) // 2
-        ow1 = (width - owidth) // 2
-        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
-        imgnarrow.fill_(0)
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        assert result.sum() == 0, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
-        oheight += 1
-        owidth += 1
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        sum1 = result.sum()
-        assert sum1 > 1, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
-        oheight += 1
-        owidth += 1
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        sum2 = result.sum()
-        assert sum2 > 0, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
-        assert sum2 > sum1, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+        for is_pil in [True, False]:
+            oheight = random.randint(5, (height - 2) / 2) * 2
+            owidth = random.randint(5, (width - 2) / 2) * 2
+            img = torch.ones(3, height, width)
+            oh1 = (height - oheight) // 2
+            ow1 = (width - owidth) // 2
+            imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
+            imgnarrow.fill_(0)
+            result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
+            assert result.sum() == 0, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+            oheight += 1
+            owidth += 1
+            result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
+            sum1 = result.sum()
+            assert sum1 > 1, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+            oheight += 1
+            owidth += 1
+            result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
+            sum2 = result.sum()
+            assert sum2 > 0, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+            assert sum2 > sum1, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
 
     def test_five_crop(self):
         to_pil_image = transforms.ToPILImage()
@@ -87,7 +84,6 @@ def test_five_crop(self):
             for crop in results:
                 assert crop.size == (crop_w, crop_h)
 
-            to_pil_image = transforms.ToPILImage()
             tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
             tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
             bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
@@ -175,46 +171,37 @@ def test_randomperspective(self):
 
     def test_resize(self):
         height = random.randint(24, 32) * 2
         width = random.randint(24, 32) * 2
-        osize = random.randint(5, 12) * 2
-
-        img = torch.ones(3, height, width)
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(osize),
-            transforms.ToTensor(),
-        ])(img)
-        assert osize in result.size()
-        if height < width:
-            assert result.size(1) <= result.size(2)
-        elif width < height:
-            assert result.size(1) >= result.size(2)
-
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize([osize, osize]),
-            transforms.ToTensor(),
-        ])(img)
-        assert osize in result.size()
-        assert result.size(1) == osize
-        assert result.size(2) == osize
 
+        osize = random.randint(5, 12) * 2
         oheight = random.randint(5, 12) * 2
         owidth = random.randint(5, 12) * 2
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        assert result.size(1) == oheight
-        assert result.size(2) == owidth
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize([oheight, owidth]),
-            transforms.ToTensor(),
-        ])(img)
-        assert result.size(1) == oheight
-        assert result.size(2) == owidth
+        img = torch.rand(3, height, width)
+
+        for is_pil in [True, False]:
+            result = transform_helper(transforms.Resize(osize), is_pil)(img)
+            self.assertIn(osize, result.size())
+            if height < width:
+                self.assertTrue(result.size(1) <= result.size(2))
+            elif width < height:
+                self.assertTrue(result.size(1) >= result.size(2))
+
+            for size in [[osize, osize], (oheight, owidth), [oheight, owidth]]:
+                result = transform_helper(transforms.Resize(size), is_pil)(img)
+                self.assertTrue(result.size(1) == size[0])
+                self.assertTrue(result.size(2) == size[1])
+
+        # test resize on 3d and 4d images for tensor inputs
+        t = transform_helper(transforms.Resize((oheight, owidth)), is_pil=False)
+        img = torch.rand(3, height, width)
+        r = t(img)
+        self.assertEqual(tuple(r.shape), (3, oheight, owidth))
+        img = torch.rand(1, 3, height, width)
+        r = t(img)
+        self.assertEqual(tuple(r.shape), (1, 3, oheight, owidth))
+        img = torch.rand(2, 3, height, width)
+        r = t(img)
+        self.assertEqual(tuple(r.shape), (2, 3, oheight, owidth))
 
     def test_random_crop(self):
         height = random.randint(10, 32) * 2
@@ -737,6 +724,27 @@ def test_ndarray_bad_types_to_pil_image(self):
         with self.assertRaises(ValueError):
             transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
 
+    def _test_flip(self, method):
+        img = torch.rand(3, 10, 10)
+        pil_img = transforms.functional.to_pil_image(img)
+
+        func = getattr(transforms.functional, method)
+
+        f_img = func(img)
+        f_pil_img = func(pil_img)
+        f_pil_img = transforms.functional.to_tensor(f_pil_img)
+        # there are rounding differences with PIL due to uint8 conversion
+        self.assertTrue((f_img - f_pil_img).abs().max() < 1.0 / 255)
+
+        ff_img = func(f_img)
+        self.assertTrue(img.equal(ff_img))
+
+    def test_vertical_flip(self):
+        self._test_flip('vflip')
+
+    def test_horizontal_flip(self):
+        self._test_flip('hflip')
+
     @unittest.skipIf(stats is None, 'scipy.stats not available')
     def test_random_vertical_flip(self):
         random_state = random.getstate()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 14448e01b00..2bd20c5d1b3 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -31,6 +31,15 @@ def _is_tensor_image(img):
     return torch.is_tensor(img) and img.ndimension() == 3
 
 
+def _get_image_size(img):
+    if _is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
 def _is_numpy(img):
     return isinstance(img, np.ndarray)
 
@@ -234,26 +243,42 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))
 
     if isinstance(size, int):
-        w, h = img.size
+        w, h = _get_image_size(img)
         if (w <= h and w == size) or (h <= w and h == size):
             return img
         if w < h:
             ow = size
             oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
         else:
             oh = size
             ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
-    else:
+        size = (oh, ow)
+    if _is_pil_image(img):
         return img.resize(size[::-1], interpolation)
+    # tensor codepath
+    # TODO maybe move this outside
+    _PIL_TO_TORCH_INTERP_MODE = {
+        Image.NEAREST: "nearest",
+        Image.BILINEAR: "bilinear"
+    }
+    should_unsqueeze = False
+    if img.dim() == 3:
+        img = img[None]
+        should_unsqueeze = True
+    mode = _PIL_TO_TORCH_INTERP_MODE[interpolation]
+    align_corners = False if mode == "bilinear" else None  # interpolate() rejects align_corners for "nearest"
+    out = torch.nn.functional.interpolate(img, size=size, mode=mode, align_corners=align_corners)
+    if should_unsqueeze:
+        out = out[0]
+    return out
+
 
 
 def scale(*args, **kwargs):
     warnings.warn("The use of the transforms.Scale transform is deprecated, " +
@@ -362,16 +387,19 @@ def crop(img, i, j, h, w):
     Returns:
         PIL Image: Cropped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
 
-    return img.crop((j, i, j + w, i + h))
+    if _is_pil_image(img):
+        return img.crop((j, i, j + w, i + h))
+
+    return img[..., i:(i + h), j:(j + w)]
 
 
 def center_crop(img, output_size):
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
-    w, h = img.size
+    w, h = _get_image_size(img)
     th, tw = output_size
     i = int(round((h - th) / 2.))
     j = int(round((w - tw) / 2.))
@@ -410,10 +438,13 @@ def hflip(img):
     Returns:
         PIL Image: Horizontall flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
 
-    return img.transpose(Image.FLIP_LEFT_RIGHT)
+    if _is_pil_image(img):
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+    return img.flip(dims=(-1,))
 
 
 def _get_perspective_coeffs(startpoints, endpoints):
@@ -468,10 +499,13 @@ def vflip(img):
     Returns:
         PIL Image: Vertically flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
 
-    return img.transpose(Image.FLIP_TOP_BOTTOM)
+    if _is_pil_image(img):
+        return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+    return img.flip(dims=(-2,))
 
 
 def five_crop(img, size):
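
Usage sketch (not part of the patch): a minimal illustration of the tensor codepaths this diff adds to torchvision.transforms.functional, assuming the patch above is applied. The function names and signatures come from the diff itself; the concrete shapes below are arbitrary example values.

    import torch
    import torchvision.transforms.functional as F

    img = torch.rand(3, 64, 48)                # 3D (C, H, W) float tensor input

    resized = F.resize(img, (32, 24))          # tensor codepath -> shape (3, 32, 24)
    cropped = F.crop(img, 4, 2, 16, 16)        # pure indexing   -> shape (3, 16, 16)
    flipped = F.hflip(F.vflip(img))            # flips are involutive: F.hflip(F.hflip(img)).equal(img)

    batch = torch.rand(2, 3, 64, 48)           # resize also accepts 4D (N, C, H, W) batches
    resized_batch = F.resize(batch, (32, 24))  # -> shape (2, 3, 32, 24)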