
Commit d509156

Undeprecate ToGrayScale transforms and functionals (#7122)
1 parent 60c78f2 commit d509156

11 files changed: +111 −124 lines

test/test_prototype_transforms_consistency.py

Lines changed: 5 additions & 0 deletions
@@ -149,6 +149,8 @@ def __init__(
             ArgsKwargs(num_output_channels=3),
         ],
         make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
+        # Use default tolerances of `torch.testing.assert_close`
+        closeness_kwargs=dict(rtol=None, atol=None),
     ),
     ConsistencyConfig(
         prototype_transforms.ConvertDtype,
@@ -271,6 +273,9 @@ def __init__(
             ArgsKwargs(p=0),
             ArgsKwargs(p=1),
         ],
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
+        # Use default tolerances of `torch.testing.assert_close`
+        closeness_kwargs=dict(rtol=None, atol=None),
     ),
     ConsistencyConfig(
         prototype_transforms.RandomResizedCrop,
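
Setting `closeness_kwargs=dict(rtol=None, atol=None)` makes these consistency checks fall back to the dtype-based default tolerances of `torch.testing.assert_close` (both arguments must be `None` for the defaults to apply). A minimal illustration, with made-up tensors rather than test-suite data:

import torch

actual = torch.tensor([1.0 + 1e-7])
expected = torch.tensor([1.0])
# rtol=None and atol=None select defaults based on dtype
# (for float32: rtol=1.3e-6, atol=1e-5), so this tiny mismatch passes.
torch.testing.assert_close(actual, expected, rtol=None, atol=None)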

torchvision/prototype/datapoints/_datapoint.py

Lines changed: 3 additions & 0 deletions
@@ -230,6 +230,9 @@ def elastic(
     ) -> Datapoint:
         return self

+    def to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
+        return self
+
     def adjust_brightness(self, brightness_factor: float) -> Datapoint:
         return self

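As with the other dispatch hooks on the base class, the new `to_grayscale` defaults to a no-op, so datapoints without a meaningful grayscale conversion (e.g. bounding boxes and masks) pass through unchanged; `Image` and `Video` override it below.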

torchvision/prototype/datapoints/_image.py

Lines changed: 6 additions & 0 deletions
@@ -169,6 +169,12 @@ def elastic(
         )
         return Image.wrap_like(self, output)

+    def to_grayscale(self, num_output_channels: int = 1) -> Image:
+        output = self._F.rgb_to_grayscale_image_tensor(
+            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
+        )
+        return Image.wrap_like(self, output)
+
     def adjust_brightness(self, brightness_factor: float) -> Image:
         output = self._F.adjust_brightness_image_tensor(
             self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
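
A minimal usage sketch of the new method, assuming the prototype API at this commit (the input values are illustrative):

import torch
from torchvision.prototype import datapoints

img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
gray = img.to_grayscale()                        # (1, 32, 32), still a datapoints.Image
gray3 = img.to_grayscale(num_output_channels=3)  # single channel expanded back to (3, 32, 32)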

torchvision/prototype/datapoints/_video.py

Lines changed: 6 additions & 0 deletions
@@ -173,6 +173,12 @@ def elastic(
         )
         return Video.wrap_like(self, output)

+    def to_grayscale(self, num_output_channels: int = 1) -> Video:
+        output = self._F.rgb_to_grayscale_image_tensor(
+            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
+        )
+        return Video.wrap_like(self, output)
+
     def adjust_brightness(self, brightness_factor: float) -> Video:
         output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
         return Video.wrap_like(self, output)
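
`Video.to_grayscale` reuses the image kernel: `rgb_to_grayscale_image_tensor` unbinds the channels at `dim=-3`, so it broadcasts over the leading time dimension of a `(T, C, H, W)` video without needing a dedicated video kernel. A small sketch (shapes are illustrative):

import torch
from torchvision.prototype.transforms.functional import rgb_to_grayscale_image_tensor

video = torch.rand(16, 3, 8, 8)             # (T, C, H, W)
out = rgb_to_grayscale_image_tensor(video)  # -> (16, 1, 8, 8)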

torchvision/prototype/transforms/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -9,9 +9,11 @@
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
+    Grayscale,
     RandomAdjustSharpness,
     RandomAutocontrast,
     RandomEqualize,
+    RandomGrayscale,
     RandomInvert,
     RandomPhotometricDistort,
     RandomPosterize,
@@ -54,4 +56,4 @@
 from ._temporal import UniformTemporalSubsample
 from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

-from ._deprecated import Grayscale, RandomGrayscale, ToTensor  # usort: skip
+from ._deprecated import ToTensor  # usort: skip

torchvision/prototype/transforms/_color.py

Lines changed: 35 additions & 0 deletions
@@ -11,6 +11,41 @@
 from .utils import is_simple_tensor, query_chw


+class Grayscale(Transform):
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )
+
+    def __init__(self, num_output_channels: int = 1):
+        super().__init__()
+        self.num_output_channels = num_output_channels
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+
+
+class RandomGrayscale(_RandomApplyTransform):
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )
+
+    def __init__(self, p: float = 0.1) -> None:
+        super().__init__(p=p)
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        num_input_channels, *_ = query_chw(flat_inputs)
+        return dict(num_input_channels=num_input_channels)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+
+
 class ColorJitter(Transform):
     def __init__(
         self,
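
Since the classes now live in `_color.py` instead of `_deprecated.py`, they behave like any other prototype transform. A minimal usage sketch, assuming the prototype namespace at this commit:

import torch
from torchvision.prototype import datapoints, transforms

img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
gray = transforms.Grayscale()(img)                   # always converts; 1 output channel by default
maybe_gray = transforms.RandomGrayscale(p=0.1)(img)  # converts with probability p, keeping the
                                                     # input channel count found via query_chw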
torchvision/prototype/transforms/_deprecated.py

Lines changed: 1 addition & 81 deletions
@@ -1,17 +1,12 @@
 import warnings
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, Union

 import numpy as np
 import PIL.Image
 import torch

-from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
-from typing_extensions import Literal
-
-from ._transform import _RandomApplyTransform
-from .utils import is_simple_tensor, query_chw


 class ToTensor(Transform):
@@ -26,78 +21,3 @@ def __init__(self) -> None:

     def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
         return _F.to_tensor(inpt)
-
-
-# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
-class Grayscale(Transform):
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
-    def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
-        deprecation_msg = (
-            f"The transform `Grayscale(num_output_channels={num_output_channels})` "
-            f"is deprecated and will be removed in a future release."
-        )
-        if num_output_channels == 1:
-            replacement_msg = (
-                "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
-            )
-        else:
-            replacement_msg = (
-                "transforms.Compose(\n"
-                "    transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
-                "    transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
-                ")"
-            )
-        warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}")
-
-        super().__init__()
-        self.num_output_channels = num_output_channels
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output
-
-
-class RandomGrayscale(_RandomApplyTransform):
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
-    def __init__(self, p: float = 0.1) -> None:
-        warnings.warn(
-            "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
-            "Instead, please use\n\n"
-            "transforms.RandomApply(\n"
-            "    transforms.Compose(\n"
-            "        transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
-            "        transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
-            "    )\n"
-            "    p=...,\n"
-            ")"
-        )
-
-        super().__init__(p=p)
-
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        num_input_channels, *_ = query_chw(flat_inputs)
-        return dict(num_input_channels=num_input_channels)
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -71,6 +71,9 @@
     posterize_image_pil,
     posterize_image_tensor,
     posterize_video,
+    rgb_to_grayscale,
+    rgb_to_grayscale_image_pil,
+    rgb_to_grayscale_image_tensor,
     solarize,
     solarize_image_pil,
     solarize_image_tensor,
@@ -167,4 +170,4 @@
 from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
 from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image

-from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor  # usort: skip
+from ._deprecated import get_image_size, to_grayscale, to_tensor  # usort: skip

torchvision/prototype/transforms/functional/_color.py

Lines changed: 48 additions & 3 deletions
@@ -1,3 +1,5 @@
+from typing import Union
+
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d
@@ -7,10 +9,53 @@

 from torchvision.utils import _log_api_usage_once

-from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
+from ._meta import _num_value_bits, convert_dtype_image_tensor
 from ._utils import is_simple_tensor


+def _rgb_to_grayscale_image_tensor(
+    image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
+) -> torch.Tensor:
+    if image.shape[-3] == 1:
+        return image.clone()
+
+    r, g, b = image.unbind(dim=-3)
+    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    l_img = l_img.unsqueeze(dim=-3)
+    if preserve_dtype:
+        l_img = l_img.to(image.dtype)
+    if num_output_channels == 3:
+        l_img = l_img.expand(image.shape)
+    return l_img
+
+
+def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
+    return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
+
+
+rgb_to_grayscale_image_pil = _FP.to_grayscale
+
+
+def rgb_to_grayscale(
+    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
+) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(rgb_to_grayscale)
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        return inpt.to_grayscale(num_output_channels=num_output_channels)
+    elif isinstance(inpt, PIL.Image.Image):
+        return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
@@ -68,7 +113,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
     if c == 1:  # Match PIL behaviour
         return image

-    grayscale_image = _rgb_to_gray(image, cast=False)
+    grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
     if not image.is_floating_point():
         grayscale_image = grayscale_image.floor_()

@@ -110,7 +155,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
         raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
     fp = image.is_floating_point()
     if c == 3:
-        grayscale_image = _rgb_to_gray(image, cast=False)
+        grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
         if not fp:
             grayscale_image = grayscale_image.floor_()
     else:
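
The undeprecated dispatcher follows the usual prototype pattern: plain tensors go to the tensor kernel, datapoints dispatch through their `to_grayscale` method, and PIL images go to the PIL kernel; the conversion itself uses the ITU-R BT.601 luma weights (0.2989, 0.587, 0.114). A minimal sketch of the three paths (inputs are illustrative):

import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import rgb_to_grayscale

t = torch.rand(3, 8, 8)
rgb_to_grayscale(t)                                           # tensor kernel -> (1, 8, 8)
rgb_to_grayscale(datapoints.Image(t), num_output_channels=3)  # datapoint method, re-wrapped as Image
rgb_to_grayscale(PIL.Image.new("RGB", (8, 8)))                # PIL kernel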

torchvision/prototype/transforms/functional/_deprecated.py

Lines changed: 0 additions & 29 deletions
@@ -7,8 +7,6 @@
 from torchvision.prototype import datapoints
 from torchvision.transforms import functional as _F

-from ._utils import is_simple_tensor
-

 @torch.jit.unused
 def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
@@ -24,33 +22,6 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)


-def rgb_to_grayscale(
-    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    old_color_space = None  # TODO: remove when un-deprecating
-    if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
-        inpt, (datapoints.Image, datapoints.Video)
-    ):
-        inpt = inpt.as_subclass(torch.Tensor)
-
-    call = ", num_output_channels=3" if num_output_channels == 3 else ""
-    replacement = (
-        f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
-        f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
-    )
-    if num_output_channels == 3:
-        replacement = (
-            f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB"
-            f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})"
-        )
-    warnings.warn(
-        f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
-        f"Instead, please use `{replacement}`.",
-    )
-
-    return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
-
-
 @torch.jit.unused
 def to_tensor(inpt: Any) -> torch.Tensor:
     warnings.warn(

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 0 additions & 9 deletions
@@ -225,15 +225,6 @@ def clamp_bounding_box(
     return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


-def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
-    r, g, b = image.unbind(dim=-3)
-    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
-    if cast:
-        l_img = l_img.to(image.dtype)
-    l_img = l_img.unsqueeze(dim=-3)
-    return l_img
-
-
 def _num_value_bits(dtype: torch.dtype) -> int:
     if dtype == torch.uint8:
         return 8
