
Commit 7d1e564

Cleanups, remove v1, enforce float32, and tests

1 parent 65e5677

File tree

11 files changed: +148, -72 lines

docs/source/transforms.rst

Lines changed: 2 additions & 1 deletion
@@ -350,6 +350,7 @@ Color
 v2.RGB
 v2.RandomGrayscale
 v2.GaussianBlur
+v2.GaussianNoise
 v2.RandomInvert
 v2.RandomPosterize
 v2.RandomSolarize
@@ -368,6 +369,7 @@ Functionals
 v2.functional.grayscale_to_rgb
 v2.functional.to_grayscale
 v2.functional.gaussian_blur
+v2.functional.gaussian_noise
 v2.functional.invert
 v2.functional.posterize
 v2.functional.solarize
@@ -555,7 +557,6 @@ Color
 RandomAdjustSharpness
 RandomAutocontrast
 RandomEqualize
-GaussianNoise
 
 Composition
 ^^^^^^^^^^^

test/test_transforms_v2.py

Lines changed: 76 additions & 2 deletions
@@ -111,8 +111,10 @@ def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
 
     input = input.as_subclass(torch.Tensor)
     with ignore_jit_no_profile_information_warning():
-        actual = kernel_scripted(input, *args, **kwargs)
-        expected = kernel(input, *args, **kwargs)
+        with freeze_rng_state():
+            actual = kernel_scripted(input, *args, **kwargs)
+        with freeze_rng_state():
+            expected = kernel(input, *args, **kwargs)
 
     assert_close(actual, expected, rtol=rtol, atol=atol)
 
@@ -3238,6 +3240,78 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp
         torch.testing.assert_close(actual, expected, rtol=0, atol=1)
 
 
+class TestGaussianNoise:
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_video],
+    )
+    def test_kernel(self, make_input):
+        check_kernel(
+            F.gaussian_noise,
+            make_input(dtype=torch.float32),
+            # This cannot pass because the noise on a batch is not per-image
+            check_batched_vs_unbatched=False,
+        )
+
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_video],
+    )
+    def test_functional(self, make_input):
+        check_functional(F.gaussian_noise, make_input(dtype=torch.float32))
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.gaussian_noise, torch.Tensor),
+            (F.gaussian_noise_image, tv_tensors.Image),
+            (F.gaussian_noise_video, tv_tensors.Video),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_video],
+    )
+    def test_transform(self, make_input):
+        def adapter(_, input, __):
+            # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
+            # Same for PIL images
+            for key, value in input.items():
+                if isinstance(value, torch.Tensor) and not value.is_floating_point():
+                    input[key] = value.to(torch.float32)
+                if isinstance(value, PIL.Image.Image):
+                    input[key] = F.pil_to_tensor(value).to(torch.float32)
+            return input
+
+        check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
+            F.gaussian_noise(make_image_pil())
+        with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
+            F.gaussian_noise(make_image(dtype=torch.uint8))
+        with pytest.raises(ValueError, match="sigma shouldn't be negative"):
+            F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)
+
+    def test_clip(self):
+        img = make_image(dtype=torch.float32)
+
+        out = F.gaussian_noise(img, mean=100, clip=False)
+        assert out.min() > 50
+
+        out = F.gaussian_noise(img, mean=100, clip=True)
+        assert (out == 1).all()
+
+        out = F.gaussian_noise(img, mean=-100, clip=False)
+        assert out.min() < -50
+
+        out = F.gaussian_noise(img, mean=-100, clip=True)
+        assert (out == 0).all()
+
+
 class TestAutoAugmentTransforms:
     # These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
     # It's typically very hard to test the effect on some parameters without heavy mocking logic.
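The adapter in test_transform works around the float-only requirement by converting uint8 inputs to float32 before running the check. In user code the same conversion would typically be done with ToDtype ahead of the noise step. The following is a minimal sketch of such a pipeline, assuming a torchvision build that includes this commit; the image shape and sigma value are illustrative, not taken from the diff:

import torch
from torchvision.transforms import v2

# GaussianNoise only accepts floating-point inputs in [0, 1], so scale uint8 first.
pipeline = v2.Compose(
    [
        v2.ToDtype(torch.float32, scale=True),   # uint8 [0, 255] -> float32 [0, 1]
        v2.GaussianNoise(sigma=0.1, clip=True),  # adds noise, then clamps back to [0, 1]
    ]
)

img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # stand-in uint8 image
noisy = pipeline(img)
print(noisy.dtype, noisy.min().item() >= 0.0, noisy.max().item() <= 1.0)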

torchvision/transforms/_functional_pil.py

Lines changed: 0 additions & 19 deletions
@@ -391,22 +391,3 @@ def equalize(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
         raise TypeError(f"img should be PIL Image. Got {type(img)}")
     return ImageOps.equalize(img)
-
-@torch.jit.unused
-def gaussian_noise(img: Image.Image, mean: float = 0., var: float = 1.0) -> Image.Image:
-    if not _is_pil_image(img):
-        raise TypeError(f"img should be PIL Image. Got {type(img)}")
-
-    if var < 0:
-        raise ValueError(f"var shouldn't be negative. Got {var}")
-
-    z = np.random.normal(
-        loc=mean,
-        scale=var,
-        size=(
-            *get_image_size(img),
-            get_image_num_channels(img),
-        ),
-    )
-
-    return img + z

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions
@@ -1506,7 +1506,7 @@ def equalize(img: Tensor) -> Tensor:
     return F_t.equalize(img)
 
 
-def gaussian_noise(img: Tensor, mean: float = 0., var: float = 1.) -> Tensor:
+def gaussian_noise(img: Tensor, mean: float = 0.0, var: float = 1.0) -> Tensor:
     """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default.
 
     Args:
@@ -1521,7 +1521,7 @@ def gaussian_noise(img: Tensor, mean: float = 0., var: float = 1.) -> Tensor:
     _log_api_usage_once(gaussian_noise)
     if not isinstance(img, torch.Tensor):
        F_pil.gaussian_noise(img, mean, var)
-
+
     return F_t.gaussian_noise(img, mean, var)
 

torchvision/transforms/transforms.py

Lines changed: 0 additions & 1 deletion
@@ -54,7 +54,6 @@
     "RandomAutocontrast",
     "RandomEqualize",
     "ElasticTransform",
-    "GaussianNoise",
 ]
 

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,6 @@
     RandomPhotometricDistort,
     RandomPosterize,
     RandomSolarize,
-    GaussianNoise,
     RGB,
 )
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
@@ -46,6 +45,7 @@
 from ._misc import (
     ConvertImageDtype,
     GaussianBlur,
+    GaussianNoise,
     Identity,
     Lambda,
     LinearTransformation,

torchvision/transforms/v2/_color.py

Lines changed: 0 additions & 21 deletions
@@ -374,24 +374,3 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
-
-
-class GaussianNoise(Transform):
-    """Add gaussian noise to the image. Samples from `N(0, 1)` (standard normal distribution) by default.
-
-    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
-    where ... means it can have an arbitrary number of leading dimensions.
-    If img is PIL Image, it is expected to be in mode "L" or "RGB".
-
-    Args:
-        mean (float): Mean of the sampled gaussian distribution. Default is 0.
-        var (float): Variance of the sampled gaussian distribution. Default is 1.
-    """
-
-    def __init__(self, mean: float = 0., var: float = 1.) -> None:
-        super().__init__()
-        self.mean = mean
-        self.var = var
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, var=self.var)

torchvision/transforms/v2/_misc.py

Lines changed: 25 additions & 0 deletions
@@ -205,6 +205,31 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)
 
 
+class GaussianNoise(Transform):
+    """Add gaussian noise to the image.
+
+    The input tensor is expected to be in [..., 1 or 3, H, W] format,
+    where ... means it can have an arbitrary number of leading dimensions.
+
+    The input tensor is also expected to be of float dtype in ``[0, 1]``.
+    This transform does not support PIL images.
+
+    Args:
+        mean (float): Mean of the sampled normal distribution. Default is 0.
+        sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
+        clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
+    """
+
+    def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:
+        super().__init__()
+        self.mean = mean
+        self.sigma = sigma
+        self.clip = clip
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip)
+
+
 class ToDtype(Transform):
     """Converts the input to a specific dtype, optionally scaling the values for images or videos.
 
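As a quick illustration of the class added above (not part of the commit), the sketch below applies the transform directly to a float image tensor; the shape, sigma, and seeding are arbitrary choices for the example. Noise is resampled on every call, so seeding the global RNG is what makes the two calls comparable:

import torch
from torchvision.transforms import v2

noise = v2.GaussianNoise(mean=0.0, sigma=0.05, clip=True)
img = torch.rand(3, 64, 64)  # float32 image in [0, 1]

torch.manual_seed(0)
out_a = noise(img)           # fresh noise sampled via torch.randn_like
torch.manual_seed(0)
out_b = noise(img)           # same seed -> identical noise

print(torch.equal(out_a, out_b))                             # True
print(out_a.min().item() >= 0.0, out_a.max().item() <= 1.0)  # clipped to [0, 1]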

torchvision/transforms/v2/functional/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -136,6 +136,9 @@
     gaussian_blur,
     gaussian_blur_image,
     gaussian_blur_video,
+    gaussian_noise,
+    gaussian_noise_image,
+    gaussian_noise_video,
     normalize,
     normalize_image,
     normalize_video,

torchvision/transforms/v2/functional/_color.py

Lines changed: 0 additions & 25 deletions
@@ -737,28 +737,3 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int])
 @_register_kernel_internal(permute_channels, tv_tensors.Video)
 def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     return permute_channels_image(video, permutation=permutation)
-
-def gaussian_noise(inpt: torch.Tensor, mean: float = 0., var: float = 1.) -> torch.Tensor:
-    """See :class:`~torchvision.transforms.v2.GaussianNoise`"""
-    if torch.jit.is_scripting():
-        return gaussian_noise_image(inpt, mean=mean, var=var)
-
-    _log_api_usage_once(gaussian_noise)
-
-    kernel = _get_kernel(gaussian_noise, type(inpt))
-    return kernel(inpt, mean=mean, var=var)
-
-@_register_kernel_internal(gaussian_noise, torch.Tensor)
-@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
-def gaussian_noise_image(image: torch.Tensor, mean: float = 0., var: float = 1.) -> torch.Tensor:
-    if var < 0:
-        raise ValueError(f"var shouldn't be negative. Got {var}")
-
-    if image.numel() == 0:
-        return image
-
-    z = mean + torch.randn_like(image) * var
-
-    return image + z
-
-_gaussian_noise_pil = _register_kernel_internal(gaussian_noise, PIL.Image.Image)(_FP.gaussian_noise)

torchvision/transforms/v2/functional/_misc.py

Lines changed: 39 additions & 0 deletions
@@ -6,6 +6,7 @@
 from torch.nn.functional import conv2d, pad as torch_pad
 
 from torchvision import tv_tensors
+from torchvision.transforms import _functional_pil as _FP
 from torchvision.transforms._functional_tensor import _max_value
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
@@ -181,6 +182,44 @@ def gaussian_blur_video(
     return gaussian_blur_image(video, kernel_size, sigma)
 
 
+def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
+    """See :class:`~torchvision.transforms.v2.GaussianNoise`"""
+    if torch.jit.is_scripting():
+        return gaussian_noise_image(inpt, mean=mean, sigma=sigma)
+
+    _log_api_usage_once(gaussian_noise)
+
+    kernel = _get_kernel(gaussian_noise, type(inpt))
+    return kernel(inpt, mean=mean, sigma=sigma, clip=clip)
+
+
+@_register_kernel_internal(gaussian_noise, torch.Tensor)
+@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
+def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
+    if not image.is_floating_point():
+        raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
+    if sigma < 0:
+        raise ValueError(f"sigma shouldn't be negative. Got {sigma}")
+
+    noise = mean + torch.randn_like(image) * sigma
+    out = image + noise
+    if clip:
+        out = torch.clamp(out, 0, 1)
+    return out
+
+
+@_register_kernel_internal(gaussian_noise, tv_tensors.Video)
+def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
+    return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip)
+
+
+@_register_kernel_internal(gaussian_noise, PIL.Image.Image)
+def _gaussian_noise_pil(
+    video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True
+) -> PIL.Image.Image:
+    raise ValueError("Gaussian Noise is not implemented for PIL images.")
+
+
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     """See :func:`~torchvision.transforms.v2.ToDtype` for details."""
     if torch.jit.is_scripting():
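To make the dispatcher's behavior concrete, here is a small usage sketch of the functional entry point defined above; it is not part of the commit, assumes a torchvision build that includes this change, and uses illustrative tensor shapes and sigma values:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.rand(3, 32, 32))        # float input in [0, 1]
video = tv_tensors.Video(torch.rand(4, 3, 32, 32))   # dispatches to gaussian_noise_video

noisy_img = F.gaussian_noise(img, mean=0.0, sigma=0.1)        # clip=True by default
noisy_video = F.gaussian_noise(video, sigma=0.2, clip=False)  # may leave [0, 1]

print(type(noisy_img).__name__, noisy_img.min().item() >= 0.0)  # Image, clipped
print(type(noisy_video).__name__)                               # Video

# Integer inputs are rejected by gaussian_noise_image:
try:
    F.gaussian_noise(torch.zeros(3, 8, 8, dtype=torch.uint8))
except ValueError as exc:
    print(exc)  # Input tensor is expected to be in float dtype, got dtype=torch.uint8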
