
[BC-breaking] Unified inputs for grayscale ops and transforms #2586


Merged
merged 10 commits into from Aug 28, 2020
9 changes: 6 additions & 3 deletions test/common_utils.py
@@ -350,9 +350,12 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)

-    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"):
-        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
-        err = getattr(torch, method)(tensor - pil_tensor).item()
+    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean"):
+        np_pil_image = np.array(pil_image)
+        if np_pil_image.ndim == 2:
+            np_pil_image = np_pil_image[:, :, None]
+        pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
+        err = getattr(torch, agg_method)(tensor - pil_tensor).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
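
For reference, the comparison the reworked helper performs, written out as a standalone sketch (illustrative only; it mirrors the helper's logic, including the raw, non-absolute difference, and assumes a torchvision build with this patch applied):

```python
import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

# Matching tensor/PIL inputs, similar to what the tests' _create_data builds.
np_img = np.random.randint(0, 256, size=(32, 34, 3), dtype=np.uint8)
pil_img = Image.fromarray(np_img)
img_tensor = torch.as_tensor(np_img.transpose((2, 0, 1)))

gray_tensor = F.rgb_to_grayscale(img_tensor).float()             # (1, H, W)
np_gray_pil = np.array(F.rgb_to_grayscale(pil_img))[None, :, :]  # grow the channel axis

# agg_method="max" with tol = 1.0 + 1e-10 tolerates at most one gray level of
# rounding disagreement between the tensor and PIL conversions.
err = torch.max(gray_tensor - torch.as_tensor(np_gray_pil).float()).item()
assert err < 1.0 + 1e-10
```
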
33 changes: 22 additions & 11 deletions test/test_functional_tensor.py
@@ -194,18 +194,29 @@ def test_adjustments(self):
def test_adjustments_cuda(self):
self._test_adjustments("cuda")

+    def _test_rgb_to_grayscale(self, device):
+        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
+
+        img_tensor, pil_img = self._create_data(32, 34, device=device)
+
+        for num_output_channels in (3, 1):
+            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
+            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
+
+            self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
+
+            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
+            self.assertTrue(s_gray_tensor.equal(gray_tensor))

def test_rgb_to_grayscale(self):
-        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
-        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
-        img_tensor_clone = img_tensor.clone()
-        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
-        grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
-        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
-        self.assertLess(max_diff, 1.0001)
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
-        # scriptable function test
-        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
-        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
+        self._test_rgb_to_grayscale("cpu")

+    @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
+    def test_rgb_to_grayscale_cuda(self):
+        self._test_rgb_to_grayscale("cuda")

def _test_center_crop(self, device):
script_center_crop = torch.jit.script(F.center_crop)
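
Condensed, the parity guarantee the new test pins down can be reproduced outside the test class (a sketch under the same assumptions as above):

```python
import torch
import torchvision.transforms.functional as F

# rgb_to_grayscale scripts cleanly because its PIL fallback lives behind
# @torch.jit.unused helpers; scripted and eager outputs must match exactly.
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

img_tensor = torch.randint(0, 256, (3, 32, 34), dtype=torch.uint8)
for num_output_channels in (3, 1):
    eager = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
    scripted = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
    assert eager.equal(scripted)
```
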
71 changes: 46 additions & 25 deletions test/test_transforms_tensor.py
@@ -13,15 +13,15 @@

class Tester(TransformsTester):

-    def _test_functional_geom_op(self, func, fn_kwargs):
+    def _test_functional_op(self, func, fn_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
tensor, pil_img = self._create_data(height=10, width=10)
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

-    def _test_class_geom_op(self, method, meth_kwargs=None):
+    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
meth_kwargs = {}

@@ -35,21 +35,24 @@ def _test_class_geom_op(self, method, meth_kwargs=None):
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
-        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
+        if test_exact_match:
+            self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
+        else:
+            self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

-    def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
-        self._test_functional_geom_op(func, fn_kwargs)
-        self._test_class_geom_op(method, meth_kwargs)
+    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
+        self._test_functional_op(func, fn_kwargs)
+        self._test_class_op(method, meth_kwargs)

def test_random_horizontal_flip(self):
-        self._test_geom_op('hflip', 'RandomHorizontalFlip')
+        self._test_op('hflip', 'RandomHorizontalFlip')

def test_random_vertical_flip(self):
-        self._test_geom_op('vflip', 'RandomVerticalFlip')
+        self._test_op('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
@@ -80,30 +83,30 @@ def test_adjustments(self):
def test_pad(self):

# Test functional.pad (PIL and Tensor) with padding as single int
-        self._test_functional_geom_op(
+        self._test_functional_op(
"pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
)
# Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
-        self._test_geom_op(
+        self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
-        self._test_geom_op(
+        self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
-        self._test_geom_op(
+        self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_crop(self):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
-        self._test_geom_op(
+        self._test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

@@ -120,17 +123,17 @@ def test_crop(self):
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
-            self._test_class_geom_op("RandomCrop", config)
+            self._test_class_op("RandomCrop", config)

def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
-        self._test_geom_op(
+        self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
-        self._test_geom_op(
+        self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
Expand All @@ -149,7 +152,7 @@ def test_center_crop(self):
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

-    def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
+    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
@@ -178,37 +181,37 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):

def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

def test_ten_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

@@ -312,6 +315,24 @@ def test_random_perspective(self):
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))

+    def test_to_grayscale(self):
+
+        meth_kwargs = {"num_output_channels": 1}
+        tol = 1.0 + 1e-10
+        self._test_class_op(
+            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+        )
+
+        meth_kwargs = {"num_output_channels": 3}
+        self._test_class_op(
+            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+        )
+
+        meth_kwargs = {}
+        self._test_class_op(
+            "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+        )


if __name__ == '__main__':
unittest.main()
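
Taken together, these tests assert that `Grayscale` and `RandomGrayscale` now accept both PIL images and tensors and survive `torch.jit.script`. A minimal usage sketch (illustrative, not part of the diff):

```python
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

transform = T.Grayscale(num_output_channels=3)
scripted = torch.jit.script(transform)  # scriptable after this change

# Tensor input: tensor out, with all three output channels equal (r == g == b).
tensor = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
out = transform(tensor)
assert out.shape == (3, 16, 16) and out[0].equal(out[1]) and out[1].equal(out[2])
assert scripted(tensor).equal(out)

# PIL input keeps working and still returns a PIL image.
pil_out = transform(Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8)))
assert isinstance(pil_out, Image.Image)
```
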
54 changes: 39 additions & 15 deletions torchvision/transforms/functional.py
@@ -32,6 +32,13 @@ def _get_image_size(img: Tensor) -> List[int]:
return F_pil._get_image_size(img)


+def _get_image_num_channels(img: Tensor) -> int:
+    if isinstance(img, torch.Tensor):
+        return F_t._get_image_num_channels(img)
+
+    return F_pil._get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
@@ -951,32 +958,49 @@ def affine(
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)


+@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
"""Convert image to grayscale version of image.
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.

Args:
-        img (PIL Image): Image to be converted to grayscale.
+        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default: 1.

Returns:
PIL Image: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel

if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    if num_output_channels == 1:
-        img = img.convert('L')
-    elif num_output_channels == 3:
-        img = img.convert('L')
-        np_img = np.array(img, dtype=np.uint8)
-        np_img = np.dstack([np_img, np_img, np_img])
-        img = Image.fromarray(np_img, 'RGB')
-    else:
-        raise ValueError('num_output_channels should be either 1 or 3')
+    if isinstance(img, Image.Image):
+        return F_pil.to_grayscale(img, num_output_channels)

-    return img
+    raise TypeError("Input should be PIL Image")


+def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
+    """Convert RGB image to grayscale version of image.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+    Note:
+        This method supports only RGB images as input. For inputs in other color spaces,
+        consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
+
+    Args:
+        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
+        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default: 1.
+
+    Returns:
+        PIL Image or Tensor: Grayscale version of the image.
+        if num_output_channels = 1 : returned image is single channel
+
+        if num_output_channels = 3 : returned image is 3 channel with r = g = b
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.to_grayscale(img, num_output_channels)
+
+    return F_t.rgb_to_grayscale(img, num_output_channels)


def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
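
The resulting dispatch rules, summarized as a quick reference snippet (hypothetical usage, grounded in the branches shown above):

```python
import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

pil_img = Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))
tensor_img = torch.zeros(3, 8, 8, dtype=torch.uint8)

F.to_grayscale(pil_img)         # PIL in, PIL out: any input mode is accepted
F.rgb_to_grayscale(pil_img)     # also fine: falls through to the PIL path
F.rgb_to_grayscale(tensor_img)  # tensor path: RGB input only, torchscript-compatible

try:
    F.to_grayscale(tensor_img)  # to_grayscale itself stays PIL-only
except TypeError as e:
    print(e)                    # "Input should be PIL Image"
```
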
37 changes: 37 additions & 0 deletions torchvision/transforms/functional_pil.py
@@ -26,6 +26,13 @@ def _get_image_size(img: Any) -> List[int]:
raise TypeError("Unexpected type {}".format(type(img)))


+@torch.jit.unused
+def _get_image_num_channels(img: Any) -> int:
+    if _is_pil_image(img):
+        return 1 if img.mode == 'L' else 3
+    raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
def hflip(img):
"""Horizontally flip the given PIL Image.
@@ -480,3 +487,33 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
opts = _parse_fill(fill, img, '5.0.0')

return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)


+@torch.jit.unused
+def to_grayscale(img, num_output_channels):
+    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
+
+    Args:
+        img (PIL Image): Image to be converted to grayscale.
+        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default: 1.
+
+    Returns:
+        PIL Image: Grayscale version of the image.
+        if num_output_channels = 1 : returned image is single channel
+
+        if num_output_channels = 3 : returned image is 3 channel with r = g = b
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    if num_output_channels == 1:
+        img = img.convert('L')
+    elif num_output_channels == 3:
+        img = img.convert('L')
+        np_img = np.array(img, dtype=np.uint8)
+        np_img = np.dstack([np_img, np_img, np_img])
+        img = Image.fromarray(np_img, 'RGB')
+    else:
+        raise ValueError('num_output_channels should be either 1 or 3')
+
+    return img
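
A quick illustrative check of the `num_output_channels == 3` branch above (standalone, not part of the diff):

```python
import numpy as np
from PIL import Image
import torchvision.transforms.functional as F

img = Image.fromarray(np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8))
gray3 = F.to_grayscale(img, num_output_channels=3)

# The grayscale plane is stacked three times, so every channel is identical.
arr = np.array(gray3)
assert gray3.mode == 'RGB' and arr.shape == (8, 8, 3)
assert (arr[..., 0] == arr[..., 1]).all() and (arr[..., 1] == arr[..., 2]).all()
```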