From f9788f3ea786815422617d961dd76eaf73796c68 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Wed, 3 Jul 2019 09:50:00 -0700
Subject: [PATCH 1/8] [WIP] Add tensor transform for resize

---
 torchvision/transforms/functional.py | 35 ++++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 14448e01b00..1faac857add 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -31,6 +31,15 @@ def _is_tensor_image(img):
     return torch.is_tensor(img) and img.ndimension() == 3


+def _get_image_size(img):
+    if _is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
 def _is_numpy(img):
     return isinstance(img, np.ndarray)

@@ -234,26 +243,42 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))

     if isinstance(size, int):
-        w, h = img.size
+        w, h = _get_image_size(img)
         if (w <= h and w == size) or (h <= w and h == size):
             return img
         if w < h:
             ow = size
             oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
         else:
             oh = size
             ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
-    else:
+        size = (oh, ow)
+    if _is_pil_image(img):
         return img.resize(size[::-1], interpolation)
+    # tensor codepath
+    # TODO maybe move this outside
+    _PIL_TO_TORCH_INTERP_MODE = {
+        Image.NEAREST: "nearest",
+        Image.BILINEAR: "bilinear"
+    }
+    should_unsqueeze = False
+    if img.dim() == 3:
+        img = img[None]
+        should_unsqueeze = True
+    out = torch.nn.functional.interpolate(img, size=size,
+                                          mode=_PIL_TO_TORCH_INTERP_MODE[interpolation],
+                                          align_corners=False)
+    if should_unsqueeze:
+        out = out[0]
+    return out
+


 def scale(*args, **kwargs):
     warnings.warn("The use of the transforms.Scale transform is deprecated, " +
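Note: a few remarks on the tensor codepath introduced above. torch.nn.functional.interpolate only operates on batched NCHW input, which is why a 3D CHW image is temporarily unsqueezed and re-squeezed. Two loose ends worth flagging: the TypeError text still reads "img should be PIL Image" even though tensors are now accepted (the same stale message recurs in the later patches), and interpolate appears to reject align_corners together with mode="nearest", so the Image.NEAREST entry of the mapping would raise as written. Below is a standalone sketch of the same logic; resize_tensor is a hypothetical name, not part of the patch, and it sidesteps the NEAREST issue by forwarding align_corners only for interpolating modes.

import torch
import torch.nn.functional as F

def resize_tensor(img, size, mode="bilinear"):
    # Int size: match the smaller edge and keep the aspect ratio, exactly
    # as the patched resize() computes (oh, ow) before dispatching.
    if isinstance(size, int):
        h, w = img.shape[-2:]
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow, oh = size, int(size * h / w)
        else:
            oh, ow = size, int(size * w / h)
        size = (oh, ow)
    # interpolate expects NCHW, hence the temporary batch dimension.
    should_unsqueeze = img.dim() == 3
    if should_unsqueeze:
        img = img[None]
    # align_corners is only valid for interpolating modes such as bilinear.
    kwargs = {} if mode == "nearest" else {"align_corners": False}
    out = F.interpolate(img, size=size, mode=mode, **kwargs)
    return out[0] if should_unsqueeze else out

print(resize_tensor(torch.rand(3, 64, 48), 32).shape)  # torch.Size([3, 42, 32])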
From f8a5a17f46ab937dff08a1fa911f6af581a2d032 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 9 Jul 2019 04:48:22 -0700
Subject: [PATCH 2/8] WIP: start improving resize tests

---
 test/test_transforms.py | 57 ++++++++++++++++------------------------
 1 file changed, 22 insertions(+), 35 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 2968282cfc1..5762646b0a6 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -173,48 +173,35 @@ def test_randomperspective(self):
         torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))

     def test_resize(self):
+        def resize_helper(size, is_pil=True):
+            t = [transforms.Resize(size)]
+            if is_pil:
+                t.insert(0, transforms.ToPILImage())
+                t.append(transforms.ToTensor())
+            return transforms.Compose(t)
+
         height = random.randint(24, 32) * 2
         width = random.randint(24, 32) * 2
+
         osize = random.randint(5, 12) * 2
+        oheight = random.randint(5, 12) * 2
+        owidth = random.randint(5, 12) * 2

-        img = torch.ones(3, height, width)
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(osize),
-            transforms.ToTensor(),
-        ])(img)
-        assert osize in result.size()
-        if height < width:
-            assert result.size(1) <= result.size(2)
-        elif width < height:
-            assert result.size(1) >= result.size(2)
+        img = torch.rand(3, height, width)

-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize([osize, osize]),
-            transforms.ToTensor(),
-        ])(img)
-        assert osize in result.size()
-        assert result.size(1) == osize
-        assert result.size(2) == osize
+        for is_pil in [True, False]:
+            result = resize_helper(osize, is_pil)(img)
+            assert osize in result.size()
+            if height < width:
+                assert result.size(1) <= result.size(2)
+            elif width < height:
+                assert result.size(1) >= result.size(2)

-        oheight = random.randint(5, 12) * 2
-        owidth = random.randint(5, 12) * 2
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        assert result.size(1) == oheight
-        assert result.size(2) == owidth
+            for size in [[osize, osize], (oheight, owidth), [oheight, owidth]]:
+                result = resize_helper(size, is_pil)(img)
+                assert result.size(1) == size[0]
+                assert result.size(2) == size[1]

-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize([oheight, owidth]),
-            transforms.ToTensor(),
-        ])(img)
-        assert result.size(1) == oheight
-        assert result.size(2) == owidth


     def test_random_crop(self):

From b13c0208ff89cfe62cbcd897dd3b3e7580299f84 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Wed, 10 Jul 2019 08:34:00 -0700
Subject: [PATCH 3/8] Add more tests

---
 test/test_transforms.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 5762646b0a6..b090d5aa0b3 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -191,16 +191,28 @@ def resize_helper(size, is_pil=True):

         for is_pil in [True, False]:
             result = resize_helper(osize, is_pil)(img)
-            assert osize in result.size()
+            self.assertIn(osize, result.size())
             if height < width:
-                assert result.size(1) <= result.size(2)
+                self.assertTrue(result.size(1) <= result.size(2))
             elif width < height:
-                assert result.size(1) >= result.size(2)
+                self.assertTrue(result.size(1) >= result.size(2))

             for size in [[osize, osize], (oheight, owidth), [oheight, owidth]]:
                 result = resize_helper(size, is_pil)(img)
-                assert result.size(1) == size[0]
-                assert result.size(2) == size[1]
+                self.assertTrue(result.size(1) == size[0])
+                self.assertTrue(result.size(2) == size[1])
+
+        # test resize on 3d and 4d images for tensor inputs
+        t = resize_helper((oheight, owidth), is_pil=False)
+        img = torch.rand(3, height, width)
+        r = t(img)
+        self.assertEqual(tuple(r.shape), (3, oheight, owidth))
+        img = torch.rand(1, 3, height, width)
+        r = t(img)
+        self.assertEqual(tuple(r.shape), (1, 3, oheight, owidth))
+        img = torch.rand(2, 3, height, width)
+        r = t(img)
+        self.assertEqual(tuple(r.shape), (2, 3, oheight, owidth))


     def test_random_crop(self):

From 68aa2586d65b1384c7654b619cae0f8596254b77 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Thu, 11 Jul 2019 02:35:45 -0700
Subject: [PATCH 4/8] Fix lint

---
 test/test_transforms.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index b090d5aa0b3..617a90cd4d0 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -214,7 +214,6 @@ def resize_helper(size, is_pil=True):
         r = t(img)
         self.assertEqual(tuple(r.shape), (2, 3, oheight, owidth))

-
     def test_random_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
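Note: with patch 1 applied, transforms.Resize reaches the tensor codepath whenever it is handed a tensor, so resize_helper merely chooses between two pipelines. Below is a sketch of the two compositions it builds (the sizes are arbitrary); the assertions in these tests compare shapes only, since PIL's bilinear filter and interpolate do not produce bit-identical pixels.

import torch
from torchvision import transforms

img = torch.rand(3, 60, 50)

# is_pil=True: round-trip through PIL, so the output is a tensor again.
pil_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((12, 18)),
    transforms.ToTensor(),
])
# is_pil=False: Resize is applied to the tensor directly (patch 1).
tensor_pipeline = transforms.Compose([transforms.Resize((12, 18))])

print(pil_pipeline(img).shape)     # torch.Size([3, 12, 18])
print(tensor_pipeline(img).shape)  # torch.Size([3, 12, 18])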
From 7062c78dbc2ab4dfba9c65d57ce0ed3eac56e811 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Thu, 11 Jul 2019 06:32:13 -0700
Subject: [PATCH 5/8] Add tensor crop

---
 test/test_transforms.py              | 14 +++++++++++---
 torchvision/transforms/functional.py |  9 ++++++---
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 617a90cd4d0..fb59ef29cab 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -23,6 +23,14 @@
     os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


+def transform_helper(t, is_pil=True):
+    t = [t]
+    if is_pil:
+        t.insert(0, transforms.ToPILImage())
+        t.append(transforms.ToTensor())
+    return transforms.Compose(t)
+
+
 class Tester(unittest.TestCase):

     def test_crop(self):
@@ -190,7 +198,7 @@ def resize_helper(size, is_pil=True):
         img = torch.rand(3, height, width)

         for is_pil in [True, False]:
-            result = resize_helper(osize, is_pil)(img)
+            result = transform_helper(transforms.Resize(osize), is_pil)(img)
             self.assertIn(osize, result.size())
             if height < width:
                 self.assertTrue(result.size(1) <= result.size(2))
@@ -198,12 +206,12 @@ def resize_helper(size, is_pil=True):
                 self.assertTrue(result.size(1) >= result.size(2))

             for size in [[osize, osize], (oheight, owidth), [oheight, owidth]]:
-                result = resize_helper(size, is_pil)(img)
+                result = transform_helper(transforms.Resize(size), is_pil)(img)
                 self.assertTrue(result.size(1) == size[0])
                 self.assertTrue(result.size(2) == size[1])

         # test resize on 3d and 4d images for tensor inputs
-        t = resize_helper((oheight, owidth), is_pil=False)
+        t = transform_helper(transforms.Resize((oheight, owidth)), is_pil=False)
         img = torch.rand(3, height, width)
         r = t(img)
         self.assertEqual(tuple(r.shape), (3, oheight, owidth))
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 1faac857add..8fdbf013c52 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -387,16 +387,19 @@ def crop(img, i, j, h, w):
     Returns:
         PIL Image: Cropped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.crop((j, i, j + w, i + h))
+    if _is_pil_image(img):
+        return img.crop((j, i, j + w, i + h))
+
+    return im[..., i:(i + h), j:(j + w)]


 def center_crop(img, output_size):
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
-    w, h = img.size
+    w, h = _get_image_size(img)
     th, tw = output_size
     i = int(round((h - th) / 2.))
     j = int(round((w - tw) / 2.))
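Note: two things stand out in the crop hunk. First, the tensor branch returns im[...] while only img is in scope, so any tensor crop raises a NameError; the next patch adds tests that exercise this path and fixes it. Second, the coordinate conventions differ: PIL's crop takes a (left, upper, right, lower) box, whereas the tensor branch slices rows from i and columns from j, i.e. (top, left). A standalone check that the two conventions select the same pixels (crop_tensor is our name for the corrected branch, not part of the patch):

import numpy as np
import torch
from PIL import Image

def crop_tensor(img, i, j, h, w):
    # i = top row, j = left column, mirroring the fixed tensor branch.
    return img[..., i:(i + h), j:(j + w)]

arr = np.arange(5 * 6, dtype=np.uint8).reshape(5, 6)
t = torch.from_numpy(arr)[None]  # 1 x 5 x 6, channel-first

pil_crop = Image.fromarray(arr).crop((2, 1, 2 + 3, 1 + 2))  # (j, i, j + w, i + h)
tensor_crop = crop_tensor(t, 1, 2, 2, 3)                    # i=1, j=2, h=2, w=3

assert (torch.from_numpy(np.asarray(pil_crop)) == tensor_crop[0]).all()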
From 6c0b75ee96ce3c631d86a97ebdb2ee54d1a44217 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Thu, 11 Jul 2019 06:50:21 -0700
Subject: [PATCH 6/8] Add tests for crop and bugfix

---
 test/test_transforms.py              | 68 ++++++++++------------------
 torchvision/transforms/functional.py |  2 +-
 2 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index fb59ef29cab..ee5e1f82e34 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -36,43 +36,32 @@ class Tester(unittest.TestCase):

     def test_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
-        oheight = random.randint(5, (height - 2) / 2) * 2
-        owidth = random.randint(5, (width - 2) / 2) * 2
-        img = torch.ones(3, height, width)
-        oh1 = (height - oheight) // 2
-        ow1 = (width - owidth) // 2
-        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
-        imgnarrow.fill_(0)
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        assert result.sum() == 0, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
-        oheight += 1
-        owidth += 1
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        sum1 = result.sum()
-        assert sum1 > 1, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
-        oheight += 1
-        owidth += 1
-        result = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
-        ])(img)
-        sum2 = result.sum()
-        assert sum2 > 0, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
-        assert sum2 > sum1, "height: " + str(height) + " width: " \
-            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+        for is_pil in [True, False]:
+            oheight = random.randint(5, (height - 2) / 2) * 2
+            owidth = random.randint(5, (width - 2) / 2) * 2
+            img = torch.ones(3, height, width)
+            oh1 = (height - oheight) // 2
+            ow1 = (width - owidth) // 2
+            imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
+            imgnarrow.fill_(0)
+            result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
+            assert result.sum() == 0, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+            oheight += 1
+            owidth += 1
+            result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
+            sum1 = result.sum()
+            assert sum1 > 1, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+            oheight += 1
+            owidth += 1
+            result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
+            sum2 = result.sum()
+            assert sum2 > 0, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+            assert sum2 > sum1, "height: " + str(height) + " width: " \
+                + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)

     def test_five_crop(self):
         to_pil_image = transforms.ToPILImage()
@@ -181,13 +170,6 @@ def test_randomperspective(self):
         torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))

     def test_resize(self):
-        def resize_helper(size, is_pil=True):
-            t = [transforms.Resize(size)]
-            if is_pil:
-                t.insert(0, transforms.ToPILImage())
-                t.append(transforms.ToTensor())
-            return transforms.Compose(t)
-
         height = random.randint(24, 32) * 2
         width = random.randint(24, 32) * 2

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 8fdbf013c52..75ed293186a 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -393,7 +393,7 @@ def crop(img, i, j, h, w):
     if _is_pil_image(img):
         return img.crop((j, i, j + w, i + h))

-    return im[..., i:(i + h), j:(j + w)]
+    return img[..., i:(i + h), j:(j + w)]


 def center_crop(img, output_size):
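Note: the bugfix here is the one-character im -> img rename, caught once test_crop started driving CenterCrop through the tensor backend. One behavioral difference worth keeping in mind: PIL's crop copies pixels, while the tensor branch is basic indexing and therefore returns a view, so in-place writes to a crop reach the original image. A small sketch of that, using the plain slice equivalent to crop(img, 2, 2, 4, 4); callers that mutate crops probably want a .clone():

import torch

img = torch.zeros(3, 8, 8)
patch = img[..., 2:6, 2:6]  # what the tensor branch of crop(img, 2, 2, 4, 4) returns
patch.fill_(1.0)
print(img.sum())            # tensor(48.) = 3 * 4 * 4: the write went through the view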
From 934a2578c872bc44a95cfae451a37cd01e009c62 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 12 Jul 2019 01:58:11 -0700
Subject: [PATCH 7/8] Flip supports tensors

---
 test/test_transforms.py              |  1 -
 torchvision/transforms/functional.py | 14 ++++++++++----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index ee5e1f82e34..7b0a378db5f 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -84,7 +84,6 @@ def test_five_crop(self):
         for crop in results:
             assert crop.size == (crop_w, crop_h)

-        to_pil_image = transforms.ToPILImage()
         tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
         tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
         bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 75ed293186a..2bd20c5d1b3 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -438,10 +438,13 @@ def hflip(img):
     Returns:
         PIL Image: Horizontall flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.transpose(Image.FLIP_LEFT_RIGHT)
+    if _is_pil_image(img):
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+    return img.flip(dims=(-1,))


 def _get_perspective_coeffs(startpoints, endpoints):
@@ -496,10 +499,13 @@ def vflip(img):
     Returns:
         PIL Image: Vertically flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.transpose(Image.FLIP_TOP_BOTTOM)
+    if _is_pil_image(img):
+        return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+    return img.flip(dims=(-2,))


 def five_crop(img, size):
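Note: in torch.Tensor.flip, dim -1 is width and dim -2 is height, so the same call handles CHW and NCHW inputs without the unsqueeze dance that resize needs; unlike the slicing in crop, flip returns a copy, not a view. A quick sketch of both flips and of the double-flip identity that the tests in the next patch rely on:

import torch

img = torch.arange(6.).reshape(1, 2, 3)  # one channel, 2 x 3

print(img.flip(dims=(-1,)))  # hflip: tensor([[[2., 1., 0.], [5., 4., 3.]]])
print(img.flip(dims=(-2,)))  # vflip: tensor([[[3., 4., 5.], [0., 1., 2.]]])

# Flipping twice restores the input exactly, so tests can use equal().
assert img.equal(img.flip(dims=(-1,)).flip(dims=(-1,)))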
From ed3424fed6e8155575c5da59dc474d45c3c35eea Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 12 Jul 2019 03:00:37 -0700
Subject: [PATCH 8/8] Add tests for [h/v]flip

---
 test/test_transforms.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 7b0a378db5f..35ee2329590 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -724,6 +724,27 @@ def test_ndarray_bad_types_to_pil_image(self):
         with self.assertRaises(ValueError):
             transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

+    def _test_flip(self, method):
+        img = torch.rand(3, 10, 10)
+        pil_img = transforms.functional.to_pil_image(img)
+
+        func = getattr(transforms.functional, method)
+
+        f_img = func(img)
+        f_pil_img = func(pil_img)
+        f_pil_img = transforms.functional.to_tensor(f_pil_img)
+        # there are rounding differences with PIL due to uint8 conversion
+        self.assertTrue((f_img - f_pil_img).abs().max() < 1.0 / 255)
+
+        ff_img = func(f_img)
+        self.assertTrue(img.equal(ff_img))
+
+    def test_vertical_flip(self):
+        self._test_flip('vflip')
+
+    def test_horizontal_flip(self):
+        self._test_flip('hflip')
+
     @unittest.skipIf(stats is None, 'scipy.stats not available')
     def test_random_vertical_flip(self):
         random_state = random.getstate()
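Note: the 1.0 / 255 tolerance in _test_flip covers the PIL round trip rather than the flip itself: to_pil_image quantizes a float tensor to uint8, so values can shift by up to one grey level before any transform runs, while the flip is lossless (hence the exact img.equal(ff_img) check after flipping twice). A sketch isolating the quantization effect:

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 10, 10)
roundtrip = F.to_tensor(F.to_pil_image(img))  # no flip, just uint8 and back

print((img - roundtrip).abs().max() < 1.0 / 255)  # True: error stays below one grey level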