From 8fcab625bece77199500b18b99edf14b9c049e61 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 9 Aug 2021 16:14:46 +0100 Subject: [PATCH 1/2] grouped_transforms init --- test/common_utils.py | 2 +- test/test_transforms.py | 99 ++- torchvision/transforms/transforms.py | 1115 ++++++++++++++++++-------- 3 files changed, 866 insertions(+), 350 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a8f5a91ef6b..cbf4e98f2d8 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -145,7 +145,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu return batch_tensor -assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) +assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0, check_stride=False) def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): diff --git a/test/test_transforms.py b/test/test_transforms.py index 74757bcb4e6..8acb4e99376 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -204,23 +204,23 @@ def test_to_tensor(self, channels): input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255) img = transforms.ToPILImage()(input_data) output = trans(img) - torch.testing.assert_close(output, input_data) + torch.testing.assert_close(output, input_data, check_stride=False) ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) output = trans(ndarray) expected_output = ndarray.transpose((2, 0, 1)) / 255.0 - torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False) + torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False, check_stride=False) ndarray = np.random.rand(height, width, channels).astype(np.float32) output = trans(ndarray) expected_output = ndarray.transpose((2, 0, 1)) - torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False) + torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False, check_stride=False) # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() img = transforms.ToPILImage()(input_data.mul(255)).convert('1') output = trans(img) - torch.testing.assert_close(input_data, output, check_dtype=False) + torch.testing.assert_close(input_data, output, check_dtype=False, check_stride=False) def test_to_tensor_errors(self): height, width = 4, 4 @@ -257,7 +257,7 @@ def test_pil_to_tensor(self, channels): input_data = torch.ByteTensor(channels, height, width).random_(0, 255) img = transforms.ToPILImage()(input_data) output = trans(img) - torch.testing.assert_close(input_data, output) + torch.testing.assert_close(input_data, output, check_stride=False) input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) img = transforms.ToPILImage()(input_data) @@ -269,13 +269,13 @@ def test_pil_to_tensor(self, channels): img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() output = trans(img) # HWC -> CHW expected_output = (input_data * 255).byte() - torch.testing.assert_close(output, expected_output) + torch.testing.assert_close(output, expected_output, check_stride=False) # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() img = transforms.ToPILImage()(input_data.mul(255)).convert('1') output = trans(img).view(torch.uint8).bool().to(torch.uint8) - torch.testing.assert_close(input_data, output) + torch.testing.assert_close(input_data, output, check_stride=False) def test_pil_to_tensor_errors(self): height, width = 4, 4 @@ -419,10 +419,10 @@ def test_pad(self): h_padded = result[:, :padding, :] w_padded = result[:, :, :padding] torch.testing.assert_close( - h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps + h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps, check_stride=False, ) torch.testing.assert_close( - w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps + w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps, check_stride=False, ) pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img)) @@ -523,9 +523,9 @@ def test_randomness(fn, trans, config, p): num_samples = 250 counts = 0 for _ in range(num_samples): - tranformation = trans(p=p, **config) - tranformation.__repr__() - out = tranformation(img) + transformation = trans(p=p, **config) + transformation.__repr__() + out = transformation(img) if out == inv_img: counts += 1 @@ -578,7 +578,7 @@ def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_outpu img = transform(img_data) assert img.mode == expected_mode - torch.testing.assert_close(expected_output, to_tensor(img).numpy()) + torch.testing.assert_close(expected_output, to_tensor(img).numpy(), check_stride=False) def test_1_channel_float_tensor_to_pil_image(self): img_data = torch.Tensor(1, 4, 4).uniform_() @@ -616,7 +616,7 @@ def test_2_channel_ndarray_to_pil_image(self, expected_mode): assert img.mode == expected_mode split = img.split() for i in range(2): - torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i])) + torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False) def test_2_channel_ndarray_to_pil_image_error(self): img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy() @@ -720,7 +720,7 @@ def test_3_channel_ndarray_to_pil_image(self, expected_mode): assert img.mode == expected_mode split = img.split() for i in range(3): - torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i])) + torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False) def test_3_channel_ndarray_to_pil_image_error(self): img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() @@ -777,7 +777,7 @@ def test_4_channel_ndarray_to_pil_image(self, expected_mode): assert img.mode == expected_mode split = img.split() for i in range(4): - torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i])) + torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False) def test_4_channel_ndarray_to_pil_image_error(self): img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() @@ -1526,7 +1526,7 @@ def test_random_crop(): t = transforms.RandomCrop(48) img = torch.ones(3, 32, 32) - with pytest.raises(ValueError, match=r"Required crop size .+ is larger then input image size .+"): + with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"): t(img) @@ -1653,7 +1653,7 @@ def test_random_erasing(): img = torch.ones(3, 128, 128) t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.)) - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + y, x, h, w, v = t.get_params_transform(img, t.scale, t.ratio, [t.value, ]) aspect_ratio = h / w # Add some tolerance due to the rounding and int conversion used in the transform tol = 0.05 @@ -1663,7 +1663,7 @@ def test_random_erasing(): random.seed(42) trial = 1000 for _ in range(trial): - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + y, x, h, w, v = t.get_params_transform(img, t.scale, t.ratio, [t.value, ]) aspect_ratios.append(h / w) count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1]) @@ -1724,7 +1724,7 @@ def test_randomperspective(): to_pil_image = transforms.ToPILImage() img = to_pil_image(img) perp = transforms.RandomPerspective() - startpoints, endpoints = perp.get_params(width, height, 0.5) + startpoints, endpoints = perp.get_start_endpoints(width, height, 0.5) tr_img = F.perspective(img, startpoints, endpoints) tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints)) tr_img = F.to_tensor(tr_img) @@ -1761,7 +1761,7 @@ def test_randomperspective_fill(mode): pixel = (pixel,) assert pixel == tuple([fill] * num_bands) - startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5) + startpoints, endpoints = transforms.RandomPerspective.get_start_endpoints(width, height, 0.5) tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill) pixel = tr_img.getpixel((0, 0)) @@ -2056,7 +2056,7 @@ def test_random_affine(): t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40]) for _ in range(100): - angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, + angle, translations, scale, shear = t.get_params(img, t.degrees, t.translate, t.scale, t.shear, img_size=img.size) assert -10 < angle < 10 assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, ("{} vs {}" @@ -2088,5 +2088,60 @@ def test_random_affine(): assert t.interpolation == transforms.InterpolationMode.BILINEAR +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") +@pytest.mark.parametrize('trans, config', [ + (transforms.RandomInvert, {}), + (transforms.RandomPosterize, {"bits": 4}), + (transforms.RandomSolarize, {"threshold": 192}), + (transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), + (transforms.RandomAutocontrast, {}), + (transforms.RandomEqualize, {})]) +@pytest.mark.parametrize('p', (.5, .7)) +def test_reset_randomness(trans, config, p): + random_state = random.getstate() + random.seed(42) + img = transforms.ToPILImage()(torch.rand(3, 16, 18)) + + num_samples = 250 + counts = 0 + for _ in range(num_samples): + transformation = trans(p=p, **config, reset_auto=False) + transformation.__repr__() + out1 = transformation(img) + assert out1 == transformation(img) + transformation.wipeout_() + out2 = transformation(img) + if out1 == out2: + counts += 1 + + p_repeat = p**2 + (1 - p)**2 + p_value = stats.binom_test(counts, num_samples, p=p_repeat) + random.setstate(random_state) + assert p_value > 0.0001, f'got counts={counts} for num_samples={num_samples}' + + +@pytest.mark.parametrize('trans, config', [ + (transforms.RandomCrop, {'size': 10}), + (transforms.RandomOrder, {"transforms": + [transforms.GaussianBlur(kernel_size=3, reset_auto=False), + transforms.RandomCrop(size=10, reset_auto=False)]}), + (transforms.RandomResizedCrop, {'size': 10}), + (transforms.ColorJitter, {}), + (transforms.RandomRotation, {'degrees': 120}), + (transforms.RandomAffine, {'degrees': 120, 'translate': (0.1, 0.1)}), + (transforms.RandomErasing, {}), + (transforms.GaussianBlur, {"kernel_size": 3})]) +def test_grouptransform(trans, config): + num_samples = 250 + for i in range(num_samples): + t = transforms.GroupTransform(trans(**config, reset_auto=False)) + assert t.stochastic + img = torch.arange(1024, dtype=torch.float).view(1, 32, 32).expand(3, 32, 32).contiguous() + mask = img[:1] + imgs = (img, mask) + imgs_out = t(imgs) + torch.testing.assert_close(imgs_out[0][0], imgs_out[1][0], rtol=1e-6, atol=1e-6, check_stride=False) + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 954d5f5f064..80338a75824 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -6,6 +6,7 @@ from typing import Tuple, List, Optional import torch +from torch import nn from torch import Tensor try: @@ -17,15 +18,89 @@ from .functional import InterpolationMode, _interpolation_modes_from_int -__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", - "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", - "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", - "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] +__all__ = [ + "Compose", + "GroupTransform", + "ToTensor", + "PILToTensor", + "ConvertImageDtype", + "ToPILImage", + "Normalize", + "Resize", + "Scale", + "CenterCrop", + "Pad", + "Lambda", + "RandomApply", + "RandomChoice", + "RandomOrder", + "RandomCrop", + "RandomHorizontalFlip", + "RandomVerticalFlip", + "RandomResizedCrop", + "RandomSizedCrop", + "FiveCrop", + "TenCrop", + "LinearTransformation", + "ColorJitter", + "RandomRotation", + "RandomAffine", + "Grayscale", + "RandomGrayscale", + "RandomPerspective", + "RandomErasing", + "GaussianBlur", + "InterpolationMode", + "RandomInvert", + "RandomPosterize", + "RandomSolarize", + "RandomAdjustSharpness", + "RandomAutocontrast", + "RandomEqualize", +] + + +class Transform(nn.Module): + stochastic = False + + def __init__(self, reset_auto: bool = True) -> None: + super().__init__() + self.reset_auto = reset_auto + self._initialized = False + + @property + def initialized(self): + return self._initialized + + def _call(self, input, *params): + raise NotImplementedError() + + def get_params(self, input): + return tuple() + def reset_(self, input): + params = self.get_params(input) + if not isinstance(params, tuple): + self.params = (params,) + else: + self.params = params + self._initialized = True + return self.params -class Compose: + def wipeout_(self): + self._initialized = False + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.initialized: + self.reset_(input) + params = self.params + output = self._call(input, *params) + if self.stochastic and self.reset_auto: + self.wipeout_() + return output + + +class Compose(Transform): """Composes several transforms together. This transform does not support torchscript. Please, see the note below. @@ -39,6 +114,7 @@ class Compose: >>> ]) .. note:: + ===> TODO: check this <=== In order to script the transformations, please use ``torch.nn.Sequential`` as below. >>> transforms = torch.nn.Sequential( @@ -52,24 +128,78 @@ class Compose: """ - def __init__(self, transforms): + def __init__(self, transforms, reset_auto=True): + super().__init__(reset_auto=reset_auto) + if not isinstance(transforms, nn.Module) and ( + all(isinstance(t, nn.Module) for t in transforms) + ): + transforms = nn.Sequential(*transforms) + elif not all(isinstance(t, Transform) for t in transforms): + warnings.warn( + "All transforms should be of type torchvision.transforms.Transform. " + "Custom typed transforms will be forbidden in future releases." + ) self.transforms = transforms + for t in transforms: + assert isinstance( + t, Transform + ), f"class {type(t)} must inherit from trochvision.transforms.Transform" - def __call__(self, img): + @property + def stochastic(self): + return any(t.stochastic for t in self.transforms if isinstance(t, Transform)) + + @property + def initialized(self): + return all(t.initialized for t in self.transforms if isinstance(t, Transform)) + + def wipeout_(self): + for t in self.transforms: + if isinstance(t, Transform): + t.wipeout_() + + def _call(self, img): for t in self.transforms: img = t(img) return img def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string -class ToTensor: +class GroupTransform(Transform): + def __init__(self, transform, reset_auto=True): + assert isinstance( + transform, Transform + ), "GroupTransform only accepts transforms of type Transform." + assert not transform.stochastic or not transform.reset_auto + super().__init__(reset_auto=reset_auto) + self.transform = transform + + @property + def stochastic(self): + return self.transform.stochastic + + @property + def initialized(self): + return self.transform.initialized + + def wipeout_(self): + return self.transform.wipeout_() + + def _call(self, imgs): + imgs = [self.transform(img) for img in imgs] + if self.reset_auto: + self.transform.wipeout_() + return imgs + + +class ToTensor(Transform): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. Converts a PIL Image or numpy.ndarray (H x W x C) in the range @@ -86,7 +216,7 @@ class ToTensor: .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation """ - def __call__(self, pic): + def _call(self, pic): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. @@ -97,16 +227,16 @@ def __call__(self, pic): return F.to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" -class PILToTensor: +class PILToTensor(Transform): """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript. Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). """ - def __call__(self, pic): + def _call(self, pic): """ Args: pic (PIL Image): Image to be converted to tensor. @@ -117,10 +247,10 @@ def __call__(self, pic): return F.pil_to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" -class ConvertImageDtype(torch.nn.Module): +class ConvertImageDtype(Transform): """Convert a tensor image to the given ``dtype`` and scale the values accordingly This function does not support PIL Image. @@ -139,15 +269,15 @@ class ConvertImageDtype(torch.nn.Module): of the integer ``dtype``. """ - def __init__(self, dtype: torch.dtype) -> None: - super().__init__() + def __init__(self, dtype: torch.dtype, reset_auto=True) -> None: + super().__init__(reset_auto=reset_auto) self.dtype = dtype - def forward(self, image): + def _call(self, image): return F.convert_image_dtype(image, self.dtype) -class ToPILImage: +class ToPILImage(Transform): """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript. Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape @@ -164,10 +294,12 @@ class ToPILImage: .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes """ - def __init__(self, mode=None): + + def __init__(self, mode=None, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.mode = mode - def __call__(self, pic): + def _call(self, pic): """ Args: pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. @@ -179,14 +311,14 @@ def __call__(self, pic): return F.to_pil_image(pic, self.mode) def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" if self.mode is not None: - format_string += 'mode={0}'.format(self.mode) - format_string += ')' + format_string += "mode={0}".format(self.mode) + format_string += ")" return format_string -class Normalize(torch.nn.Module): +class Normalize(Transform): """Normalize a tensor image with mean and standard deviation. This transform does not support PIL Image. Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` @@ -210,7 +342,7 @@ def __init__(self, mean, std, inplace=False): self.std = std self.inplace = inplace - def forward(self, tensor: Tensor) -> Tensor: + def _call(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image to be normalized. @@ -221,10 +353,12 @@ def forward(self, tensor: Tensor) -> Tensor: return F.normalize(tensor, self.mean, self.std, self.inplace) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + return self.__class__.__name__ + "(mean={0}, std={1})".format( + self.mean, self.std + ) -class Resize(torch.nn.Module): +class Resize(Transform): """Resize the input image to the given size. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -268,7 +402,13 @@ class Resize(torch.nn.Module): """ - def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None): + def __init__( + self, + size, + interpolation=InterpolationMode.BILINEAR, + max_size=None, + antialias=None, + ): super().__init__() if not isinstance(size, (int, Sequence)): raise TypeError("Size should be int or sequence. Got {}".format(type(size))) @@ -288,7 +428,7 @@ def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None self.interpolation = interpolation self.antialias = antialias - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be scaled. @@ -296,25 +436,34 @@ def forward(self, img): Returns: PIL Image or Tensor: Rescaled image. """ - return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + return F.resize( + img, self.size, self.interpolation, self.max_size, self.antialias + ) def __repr__(self): interpolate_str = self.interpolation.value - return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( - self.size, interpolate_str, self.max_size, self.antialias) + return ( + self.__class__.__name__ + + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format( + self.size, interpolate_str, self.max_size, self.antialias + ) + ) class Scale(Resize): """ Note: This transform is deprecated in favor of Resize. """ + def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") + warnings.warn( + "The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead." + ) super(Scale, self).__init__(*args, **kwargs) -class CenterCrop(torch.nn.Module): +class CenterCrop(Transform): """Crops the given image at the center. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -328,9 +477,11 @@ class CenterCrop(torch.nn.Module): def __init__(self, size): super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. @@ -341,10 +492,10 @@ def forward(self, img): return F.center_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) -class Pad(torch.nn.Module): +class Pad(Transform): """Pad the given image on all sides with the given "pad" value. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, @@ -391,17 +542,21 @@ def __init__(self, padding, fill=0, padding_mode="constant"): raise TypeError("Got inappropriate fill arg") if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + raise ValueError( + "Padding mode should be either constant, edge, reflect or symmetric" + ) if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding)) + ) self.padding = padding self.fill = fill self.padding_mode = padding_mode - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be padded. @@ -412,11 +567,15 @@ def forward(self, img): return F.pad(img, self.padding, self.fill, self.padding_mode) def __repr__(self): - return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ - format(self.padding, self.fill, self.padding_mode) + return ( + self.__class__.__name__ + + "(padding={0}, fill={1}, padding_mode={2})".format( + self.padding, self.fill, self.padding_mode + ) + ) -class Lambda: +class Lambda(Transform): """Apply a user-defined lambda as a transform. This transform does not support torchscript. Args: @@ -424,42 +583,59 @@ class Lambda: """ def __init__(self, lambd): + super().__init__() if not callable(lambd): - raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) + raise TypeError( + "Argument lambd should be callable, got {}".format( + repr(type(lambd).__name__) + ) + ) self.lambd = lambd - def __call__(self, img): + def _call(self, img): return self.lambd(img) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" -class RandomTransforms: +class RandomTransforms(Transform): + stochastic = True """Base class for a list of transformations with randomness Args: transforms (sequence): list of transformations """ - def __init__(self, transforms): + def __init__(self, transforms, reset_auto=True): + super().__init__(reset_auto=reset_auto) if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence") + if not all(reset_auto == t.reset_auto for t in transforms): + raise Exception( + "RandomTransform must have the same reset_auto attribute than provided transforms" + ) self.transforms = transforms - def __call__(self, *args, **kwargs): + def _call(self, *args, **kwargs): raise NotImplementedError() + def wipeout_(self): + super().wipeout_() + for t in self.transforms: + t.wipeout_() + def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string -class RandomApply(torch.nn.Module): +class RandomApply(Transform): + stochastic = True """Apply randomly a list of transformations with a given probability. .. note:: @@ -479,48 +655,66 @@ class RandomApply(torch.nn.Module): p (float): probability """ - def __init__(self, transforms, p=0.5): - super().__init__() + def __init__(self, transforms, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) + if isinstance(transforms, (list, tuple)): + transforms = nn.Sequential(*transforms) + if not isinstance(transforms, nn.Module): + raise TypeError("transfroms should be of type [List, Tuple, nn.Module]") self.transforms = transforms self.p = p - def forward(self, img): - if self.p < torch.rand(1): + def get_params(self, *args): + r = torch.rand(1) + return r + + def _call(self, img, r): + if self.p < r: return img - for t in self.transforms: - img = t(img) + img = self.transforms(img) return img def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += '\n p={}'.format(self.p) + format_string = self.__class__.__name__ + "(" + format_string += "\n p={}".format(self.p) for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string class RandomOrder(RandomTransforms): + stochastic = True """Apply a list of transformations in a random order. This transform does not support torchscript. """ - def __call__(self, img): + + def get_params(self, *args): order = list(range(len(self.transforms))) random.shuffle(order) + return order + + def _call(self, img, order): for i in order: img = self.transforms[i](img) return img class RandomChoice(RandomTransforms): + stochastic = True """Apply single transformation randomly picked from a list. This transform does not support torchscript. """ - def __call__(self, img): + + def get_params(self, *args): t = random.choice(self.transforms) + return t + + def _call(self, img, t): return t(img) -class RandomCrop(torch.nn.Module): +class RandomCrop(Transform): + stochastic = True """Crop the given image at a random location. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions, @@ -564,8 +758,7 @@ class RandomCrop(torch.nn.Module): will result in [2, 1, 1, 2, 3, 4, 4, 3] """ - @staticmethod - def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + def get_params(self, img: Tensor) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random crop. Args: @@ -576,40 +769,48 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ w, h = F._get_image_size(img) - th, tw = output_size + th, tw = self.size if h + 1 < th or w + 1 < tw: raise ValueError( - "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) + "Required crop size {} is larger than input image size {}".format( + (th, tw), (h, w) + ) ) if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw - def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): - super().__init__() + def __init__( + self, + size, + padding=None, + pad_if_needed=False, + fill=0, + padding_mode="constant", + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) - self.size = tuple(_setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." - )) + self.size = tuple( + _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) + ) self.padding = padding self.pad_if_needed = pad_if_needed self.fill = fill self.padding_mode = padding_mode + self.register_forward_pre_hook(RandomCrop._pad) - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ + def _pad(self, img_tuple): + assert len(img_tuple) == 1 + img = img_tuple[0] if self.padding is not None: img = F.pad(img, self.padding, self.fill, self.padding_mode) @@ -622,16 +823,27 @@ def forward(self, img): if self.pad_if_needed and height < self.size[0]: padding = [0, self.size[0] - height] img = F.pad(img, padding, self.fill, self.padding_mode) + return img - i, j, h, w = self.get_params(img, self.size) + def _call(self, img, i, j, h, w): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ return F.crop(img, i, j, h, w) def __repr__(self): - return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) + return self.__class__.__name__ + "(size={0}, padding={1})".format( + self.size, self.padding + ) -class RandomHorizontalFlip(torch.nn.Module): +class RandomHorizontalFlip(Transform): + stochastic = True """Horizontally flip the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading @@ -641,11 +853,14 @@ class RandomHorizontalFlip(torch.nn.Module): p (float): probability of the image being flipped. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, *args): + return torch.rand(1) + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be flipped. @@ -653,15 +868,16 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly flipped image. """ - if torch.rand(1) < self.p: + if r < self.p: return F.hflip(img) return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) -class RandomVerticalFlip(torch.nn.Module): +class RandomVerticalFlip(Transform): + stochastic = True """Vertically flip the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading @@ -671,11 +887,14 @@ class RandomVerticalFlip(torch.nn.Module): p (float): probability of the image being flipped. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, *args): + return torch.rand(1) + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be flipped. @@ -683,15 +902,16 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly flipped image. """ - if torch.rand(1) < self.p: + if r < self.p: return F.vflip(img) return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) -class RandomPerspective(torch.nn.Module): +class RandomPerspective(Transform): + stochastic = True """Performs a random perspective transformation of the given image with a given probability. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -708,8 +928,15 @@ class RandomPerspective(torch.nn.Module): image. Default is ``0``. If given a number, the value is used for all bands respectively. """ - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0): - super().__init__() + def __init__( + self, + distortion_scale=0.5, + p=0.5, + interpolation=InterpolationMode.BILINEAR, + fill=0, + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) self.p = p # Backward compatibility with integer value @@ -730,7 +957,18 @@ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode. self.fill = fill - def forward(self, img): + def get_params(self, img): + r = torch.rand(1) + if r < self.p: + width, height = F._get_image_size(img) + startpoints, endpoints = self.get_start_endpoints( + width, height, self.distortion_scale + ) + else: + startpoints, endpoints = None, None + return r, startpoints, endpoints + + def _call(self, img, r, startpoints, endpoints): """ Args: img (PIL Image or Tensor): Image to be Perspectively transformed. @@ -746,14 +984,14 @@ def forward(self, img): else: fill = [float(f) for f in fill] - if torch.rand(1) < self.p: - width, height = F._get_image_size(img) - startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + if r < self.p: return F.perspective(img, startpoints, endpoints, self.interpolation, fill) return img @staticmethod - def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: + def get_start_endpoints( + width: int, height: int, distortion_scale: float + ) -> Tuple[List[List[int]], List[List[int]]]: """Get parameters for ``perspective`` for a random perspective transform. Args: @@ -768,30 +1006,63 @@ def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[L half_height = height // 2 half_width = width // 2 topleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + int( + torch.randint( + 0, int(distortion_scale * half_width) + 1, size=(1,) + ).item() + ), + int( + torch.randint( + 0, int(distortion_scale * half_height) + 1, size=(1,) + ).item() + ), ] topright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + int( + torch.randint( + width - int(distortion_scale * half_width) - 1, width, size=(1,) + ).item() + ), + int( + torch.randint( + 0, int(distortion_scale * half_height) + 1, size=(1,) + ).item() + ), ] botright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + int( + torch.randint( + width - int(distortion_scale * half_width) - 1, width, size=(1,) + ).item() + ), + int( + torch.randint( + height - int(distortion_scale * half_height) - 1, height, size=(1,) + ).item() + ), ] botleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + int( + torch.randint( + 0, int(distortion_scale * half_width) + 1, size=(1,) + ).item() + ), + int( + torch.randint( + height - int(distortion_scale * half_height) - 1, height, size=(1,) + ).item() + ), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] return startpoints, endpoints def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) -class RandomResizedCrop(torch.nn.Module): +class RandomResizedCrop(Transform): + stochastic = True """Crop a random portion of image and resize it to a given size. If the image is torch Tensor, it is expected @@ -820,9 +1091,18 @@ class RandomResizedCrop(torch.nn.Module): """ - def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=InterpolationMode.BILINEAR, + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) if not isinstance(scale, Sequence): raise TypeError("Scale should be a sequence") @@ -843,9 +1123,8 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat self.scale = scale self.ratio = ratio - @staticmethod def get_params( - img: Tensor, scale: List[float], ratio: List[float] + self, img: Tensor, scale: List[float] = [], ratio: List[float] = [] ) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random sized crop. @@ -858,6 +1137,10 @@ def get_params( tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ + if not len(scale): + scale = self.scale + if not len(ratio): + ratio = self.ratio width, height = F._get_image_size(img) area = height * width @@ -891,7 +1174,7 @@ def get_params( j = (width - w) // 2 return i, j, h, w - def forward(self, img): + def _call(self, img, i, j, h, w): """ Args: img (PIL Image or Tensor): Image to be cropped and resized. @@ -899,15 +1182,14 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly cropped and resized image. """ - i, j, h, w = self.get_params(img, self.scale, self.ratio) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) - format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) - format_string += ', interpolation={0})'.format(interpolate_str) + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) return format_string @@ -915,13 +1197,16 @@ class RandomSizedCrop(RandomResizedCrop): """ Note: This transform is deprecated in favor of RandomResizedCrop. """ + def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + - "please use transforms.RandomResizedCrop instead.") + warnings.warn( + "The use of the transforms.RandomSizedCrop transform is deprecated, " + + "please use transforms.RandomResizedCrop instead." + ) super(RandomSizedCrop, self).__init__(*args, **kwargs) -class FiveCrop(torch.nn.Module): +class FiveCrop(Transform): """Crop the given image into four corners and the central crop. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading @@ -951,9 +1236,11 @@ class FiveCrop(torch.nn.Module): def __init__(self, size): super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. @@ -964,10 +1251,10 @@ def forward(self, img): return F.five_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) -class TenCrop(torch.nn.Module): +class TenCrop(Transform): """Crop the given image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). If the image is torch Tensor, it is expected @@ -999,10 +1286,12 @@ class TenCrop(torch.nn.Module): def __init__(self, size, vertical_flip=False): super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.size = _setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + ) self.vertical_flip = vertical_flip - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. @@ -1013,10 +1302,12 @@ def forward(self, img): return F.ten_crop(img, self.size, self.vertical_flip) def __repr__(self): - return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format( + self.size, self.vertical_flip + ) -class LinearTransformation(torch.nn.Module): +class LinearTransformation(Transform): """Transform a tensor image with a square transformation matrix and a mean_vector computed offline. This transform does not support PIL Image. @@ -1038,22 +1329,30 @@ class LinearTransformation(torch.nn.Module): def __init__(self, transformation_matrix, mean_vector): super().__init__() if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError("transformation_matrix should be square. Got " + - "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) + raise ValueError( + "transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()) + ) if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + - " as any one of the dimensions of the transformation_matrix [{}]" - .format(tuple(transformation_matrix.size()))) + raise ValueError( + "mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{}]".format( + tuple(transformation_matrix.size()) + ) + ) if transformation_matrix.device != mean_vector.device: - raise ValueError("Input tensors should be on the same device. Got {} and {}" - .format(transformation_matrix.device, mean_vector.device)) + raise ValueError( + "Input tensors should be on the same device. Got {} and {}".format( + transformation_matrix.device, mean_vector.device + ) + ) self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector - def forward(self, tensor: Tensor) -> Tensor: + def _call(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image to be whitened. @@ -1064,13 +1363,17 @@ def forward(self, tensor: Tensor) -> Tensor: shape = tensor.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: - raise ValueError("Input tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + - "{}".format(self.transformation_matrix.shape[0])) + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0]) + ) if tensor.device.type != self.mean_vector.device.type: - raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " - "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + "Got {} vs {}".format(tensor.device, self.mean_vector.device) + ) flat_tensor = tensor.view(-1, n) - self.mean_vector transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) @@ -1078,13 +1381,14 @@ def forward(self, tensor: Tensor) -> Tensor: return tensor def __repr__(self): - format_string = self.__class__.__name__ + '(transformation_matrix=' - format_string += (str(self.transformation_matrix.tolist()) + ')') - format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') + format_string = self.__class__.__name__ + "(transformation_matrix=" + format_string += str(self.transformation_matrix.tolist()) + ")" + format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")" return format_string -class ColorJitter(torch.nn.Module): +class ColorJitter(Transform): + stochastic = True """Randomly change the brightness, contrast, saturation and hue of an image. If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1105,19 +1409,24 @@ class ColorJitter(torch.nn.Module): Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - super().__init__() - self.brightness = self._check_input(brightness, 'brightness') - self.contrast = self._check_input(contrast, 'contrast') - self.saturation = self._check_input(saturation, 'saturation') - self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), - clip_first_on_zero=False) + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, reset_auto=True): + super().__init__(reset_auto=reset_auto) + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input( + hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False + ) @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + def _check_input( + self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True + ): if isinstance(value, numbers.Number): if value < 0: - raise ValueError("If {} is a single number, it must be non negative.".format(name)) + raise ValueError( + "If {} is a single number, it must be non negative.".format(name) + ) value = [center - float(value), center + float(value)] if clip_first_on_zero: value[0] = max(value[0], 0.0) @@ -1125,7 +1434,11 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError("{} values should be between {}".format(name, bound)) else: - raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) + raise TypeError( + "{} should be a single number or a list/tuple with length 2.".format( + name + ) + ) # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing @@ -1133,12 +1446,16 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs value = None return value - @staticmethod - def get_params(brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + def get_params( + self, + img, + brightness: Optional[List[float]] = [], + contrast: Optional[List[float]] = [], + saturation: Optional[List[float]] = [], + hue: Optional[List[float]] = [], + ) -> Tuple[ + Tensor, Optional[float], Optional[float], Optional[float], Optional[float] + ]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1155,16 +1472,44 @@ def get_params(brightness: Optional[List[float]], tuple: The parameters used to apply the randomized transform along with their random order. """ + if not len(brightness): + brightness = self.brightness + if not len(contrast): + contrast = self.contrast + if not len(saturation): + saturation = self.saturation + if not len(hue): + hue = self.hue fn_idx = torch.randperm(4) - b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) - c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) - s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + b = ( + None + if brightness is None + else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + ) + c = ( + None + if contrast is None + else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + ) + s = ( + None + if saturation is None + else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + ) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) return fn_idx, b, c, s, h - def forward(self, img): + def _call( + self, + img, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ): """ Args: img (PIL Image or Tensor): Input image. @@ -1172,8 +1517,6 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1188,15 +1531,16 @@ def forward(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += 'brightness={0}'.format(self.brightness) - format_string += ', contrast={0}'.format(self.contrast) - format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) + format_string = self.__class__.__name__ + "(" + format_string += "brightness={0}".format(self.brightness) + format_string += ", contrast={0}".format(self.contrast) + format_string += ", saturation={0}".format(self.saturation) + format_string += ", hue={0})".format(self.hue) return format_string -class RandomRotation(torch.nn.Module): +class RandomRotation(Transform): + stochastic = True """Rotate the image by angle. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1225,9 +1569,16 @@ class RandomRotation(torch.nn.Module): """ def __init__( - self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None + self, + degrees, + interpolation=InterpolationMode.NEAREST, + expand=False, + center=None, + fill=0, + resample=None, + reset_auto=True, ): - super().__init__() + super().__init__(reset_auto=reset_auto) if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" @@ -1242,10 +1593,10 @@ def __init__( ) interpolation = _interpolation_modes_from_int(interpolation) - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if center is not None: - _check_sequence_input(center, "center", req_sizes=(2, )) + _check_sequence_input(center, "center", req_sizes=(2,)) self.center = center @@ -1259,17 +1610,20 @@ def __init__( self.fill = fill - @staticmethod - def get_params(degrees: List[float]) -> float: + def get_params(self, img, degrees: List[float] = []) -> float: """Get parameters for ``rotate`` for a random rotation. Returns: float: angle parameter to be passed to ``rotate`` for random rotation. """ - angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) + if not len(degrees): + degrees = self.degrees + angle = float( + torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item() + ) return angle - def forward(self, img): + def _call(self, img, angle): """ Args: img (PIL Image or Tensor): Image to be rotated. @@ -1283,24 +1637,24 @@ def forward(self, img): fill = [float(fill)] * F._get_image_num_channels(img) else: fill = [float(f) for f in fill] - angle = self.get_params(self.degrees) return F.rotate(img, angle, self.resample, self.expand, self.center, fill) def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) - format_string += ', interpolation={0}'.format(interpolate_str) - format_string += ', expand={0}'.format(self.expand) + format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees) + format_string += ", interpolation={0}".format(interpolate_str) + format_string += ", expand={0}".format(self.expand) if self.center is not None: - format_string += ', center={0}'.format(self.center) + format_string += ", center={0}".format(self.center) if self.fill is not None: - format_string += ', fill={0}'.format(self.fill) - format_string += ')' + format_string += ", fill={0}".format(self.fill) + format_string += ")" return format_string -class RandomAffine(torch.nn.Module): +class RandomAffine(Transform): + stochastic = True """Random affine transformation of the image keeping center invariant. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1337,10 +1691,18 @@ class RandomAffine(torch.nn.Module): """ def __init__( - self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, - fillcolor=None, resample=None + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + fillcolor=None, + resample=None, + reset_auto=True, ): - super().__init__() + super().__init__(reset_auto=reset_auto) if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" @@ -1361,17 +1723,17 @@ def __init__( ) fill = fillcolor - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: - _check_sequence_input(translate, "translate", req_sizes=(2, )) + _check_sequence_input(translate, "translate", req_sizes=(2,)) for t in translate: if not (0.0 <= t <= 1.0): raise ValueError("translation values should be between 0 and 1") self.translate = translate if scale is not None: - _check_sequence_input(scale, "scale", req_sizes=(2, )) + _check_sequence_input(scale, "scale", req_sizes=(2,)) for s in scale: if s <= 0: raise ValueError("scale values should be positive") @@ -1391,20 +1753,34 @@ def __init__( self.fillcolor = self.fill = fill - @staticmethod def get_params( - degrees: List[float], - translate: Optional[List[float]], - scale_ranges: Optional[List[float]], - shears: Optional[List[float]], - img_size: List[int] + self, + img, + degrees: List[float] = [], + translate: Optional[List[float]] = [], + scale_ranges: Optional[List[float]] = [], + shears: Optional[List[float]] = [], + img_size: List[int] = [], ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: """Get parameters for affine transformation Returns: params to be passed to the affine transformation """ - angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) + if not len(degrees): + degrees = self.degrees + if not len(translate): + translate = self.translate + if not len(scale_ranges): + scale_ranges = self.scale + if not len(shears): + shears = self.shear + if not len(img_size): + img_size = F._get_image_size(img) + + angle = float( + torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item() + ) if translate is not None: max_dx = float(translate[0] * img_size[0]) max_dy = float(translate[1] * img_size[1]) @@ -1415,7 +1791,9 @@ def get_params( translations = (0, 0) if scale_ranges is not None: - scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) + scale = float( + torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item() + ) else: scale = 1.0 @@ -1429,7 +1807,7 @@ def get_params( return angle, translations, scale, shear - def forward(self, img): + def _call(self, img, angle, translations, scale, shear): """ img (PIL Image or Tensor): Image to be transformed. @@ -1443,31 +1821,35 @@ def forward(self, img): else: fill = [float(f) for f in fill] - img_size = F._get_image_size(img) - - ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) - - return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) + return F.affine( + img, + angle, + translations, + scale, + shear, + interpolation=self.interpolation, + fill=fill, + ) def __repr__(self): - s = '{name}(degrees={degrees}' + s = "{name}(degrees={degrees}" if self.translate is not None: - s += ', translate={translate}' + s += ", translate={translate}" if self.scale is not None: - s += ', scale={scale}' + s += ", scale={scale}" if self.shear is not None: - s += ', shear={shear}' + s += ", shear={shear}" if self.interpolation != InterpolationMode.NEAREST: - s += ', interpolation={interpolation}' + s += ", interpolation={interpolation}" if self.fill != 0: - s += ', fill={fill}' - s += ')' + s += ", fill={fill}" + s += ")" d = dict(self.__dict__) - d['interpolation'] = self.interpolation.value + d["interpolation"] = self.interpolation.value return s.format(name=self.__class__.__name__, **d) -class Grayscale(torch.nn.Module): +class Grayscale(Transform): """Convert image to grayscale. If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions @@ -1487,7 +1869,7 @@ def __init__(self, num_output_channels=1): super().__init__() self.num_output_channels = num_output_channels - def forward(self, img): + def _call(self, img): """ Args: img (PIL Image or Tensor): Image to be converted to grayscale. @@ -1498,10 +1880,13 @@ def forward(self, img): return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) def __repr__(self): - return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) + return self.__class__.__name__ + "(num_output_channels={0})".format( + self.num_output_channels + ) -class RandomGrayscale(torch.nn.Module): +class RandomGrayscale(Transform): + stochastic = True """Randomly convert image to grayscale with a probability of p (default 0.1). If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions @@ -1517,11 +1902,15 @@ class RandomGrayscale(torch.nn.Module): """ - def __init__(self, p=0.1): - super().__init__() + def __init__(self, p=0.1, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1) + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be converted to grayscale. @@ -1530,16 +1919,17 @@ def forward(self, img): PIL Image or Tensor: Randomly grayscaled image. """ num_output_channels = F._get_image_num_channels(img) - if torch.rand(1) < self.p: + if r < self.p: return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) return img def __repr__(self): - return self.__class__.__name__ + '(p={0})'.format(self.p) + return self.__class__.__name__ + "(p={0})".format(self.p) -class RandomErasing(torch.nn.Module): - """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. +class RandomErasing(Transform): + stochastic = True + """Randomly selects a rectangle region in an torch Tensor image and erases its pixels. This transform does not support PIL Image. 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 @@ -1565,10 +1955,20 @@ class RandomErasing(torch.nn.Module): >>> ]) """ - def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): - super().__init__() + def __init__( + self, + p=0.5, + scale=(0.02, 0.33), + ratio=(0.3, 3.3), + value=0, + inplace=False, + reset_auto=True, + ): + super().__init__(reset_auto=reset_auto) if not isinstance(value, (numbers.Number, str, tuple, list)): - raise TypeError("Argument value should be either a number or str or a sequence") + raise TypeError( + "Argument value should be either a number or str or a sequence" + ) if isinstance(value, str) and value != "random": raise ValueError("If value is str, it should be 'random'") if not isinstance(scale, (tuple, list)): @@ -1588,9 +1988,12 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace self.value = value self.inplace = inplace - @staticmethod - def get_params( - img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + def get_params_transform( + self, + img: Tensor, + scale: Tuple[float, float], + ratio: Tuple[float, float], + value: Optional[List[float]] = None, ) -> Tuple[int, int, int, int, Tensor]: """Get parameters for ``erase`` for a random erasing. @@ -1625,26 +2028,22 @@ def get_params( else: v = torch.tensor(value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1, )).item() - j = torch.randint(0, img_w - w + 1, size=(1, )).item() + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() return i, j, h, w, v # Return original image return 0, 0, img_h, img_w, img - def forward(self, img): - """ - Args: - img (Tensor): Tensor image to be erased. - - Returns: - img (Tensor): Erased Tensor image. - """ - if torch.rand(1) < self.p: + def get_params(self, img): + r = torch.rand(1) + if r < self.p: # cast self.value to script acceptable type if isinstance(self.value, (int, float)): - value = [self.value, ] + value = [ + self.value, + ] elif isinstance(self.value, str): value = None elif isinstance(self.value, tuple): @@ -1658,20 +2057,35 @@ def forward(self, img): "{} (number of input channels)".format(img.shape[-3]) ) - x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) + x, y, h, w, v = self.get_params_transform( + img, scale=self.scale, ratio=self.ratio, value=value + ) + return r, x, y, h, w, v + return r, None, None, None, None, None + + def _call(self, img, r, x, y, h, w, v): + """ + Args: + img (Tensor): Tensor image to be erased. + + Returns: + img (Tensor): Erased Tensor image. + """ + if r < self.p: return F.erase(img, x, y, h, w, v, self.inplace) return img def __repr__(self): - s = '(p={}, '.format(self.p) - s += 'scale={}, '.format(self.scale) - s += 'ratio={}, '.format(self.ratio) - s += 'value={}, '.format(self.value) - s += 'inplace={})'.format(self.inplace) + s = "(p={}, ".format(self.p) + s += "scale={}, ".format(self.scale) + s += "ratio={}, ".format(self.ratio) + s += "value={}, ".format(self.value) + s += "inplace={})".format(self.inplace) return self.__class__.__name__ + s -class GaussianBlur(torch.nn.Module): +class GaussianBlur(Transform): + stochastic = True """Blurs image with randomly chosen Gaussian blur. If the image is torch Tensor, it is expected to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1688,27 +2102,36 @@ class GaussianBlur(torch.nn.Module): """ - def __init__(self, kernel_size, sigma=(0.1, 2.0)): - super().__init__() - self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + def __init__(self, kernel_size, sigma=(0.1, 2.0), reset_auto=True): + super().__init__(reset_auto=reset_auto) + self.kernel_size = _setup_size( + kernel_size, "Kernel size should be a tuple/list of two integers" + ) for ks in self.kernel_size: if ks <= 0 or ks % 2 == 0: - raise ValueError("Kernel size value should be an odd and positive number.") + raise ValueError( + "Kernel size value should be an odd and positive number." + ) if isinstance(sigma, numbers.Number): if sigma <= 0: raise ValueError("If sigma is a single number, it must be positive.") sigma = (sigma, sigma) elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0. < sigma[0] <= sigma[1]: - raise ValueError("sigma values should be positive and of the form (min, max).") + if not 0.0 < sigma[0] <= sigma[1]: + raise ValueError( + "sigma values should be positive and of the form (min, max)." + ) else: - raise ValueError("sigma should be a single number or a list/tuple with length 2.") + raise ValueError( + "sigma should be a single number or a list/tuple with length 2." + ) self.sigma = sigma - @staticmethod - def get_params(sigma_min: float, sigma_max: float) -> float: + def get_params( + self, img, sigma_min: float = -1.0, sigma_max: float = -1.0 + ) -> float: """Choose sigma for random gaussian blurring. Args: @@ -1718,9 +2141,13 @@ def get_params(sigma_min: float, sigma_max: float) -> float: Returns: float: Standard deviation to be passed to calculate kernel for gaussian blurring. """ + if sigma_min == -1.0: + sigma_min = self.sigma[0] + if sigma_max == -1.0: + sigma_max = self.sigma[1] return torch.empty(1).uniform_(sigma_min, sigma_max).item() - def forward(self, img: Tensor) -> Tensor: + def _call(self, img: Tensor, sigma: float) -> Tensor: """ Args: img (PIL Image or Tensor): image to be blurred. @@ -1728,12 +2155,11 @@ def forward(self, img: Tensor) -> Tensor: Returns: PIL Image or Tensor: Gaussian blurred image """ - sigma = self.get_params(self.sigma[0], self.sigma[1]) return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) def __repr__(self): - s = '(kernel_size={}, '.format(self.kernel_size) - s += 'sigma={})'.format(self.sigma) + s = "(kernel_size={}, ".format(self.kernel_size) + s += "sigma={})".format(self.sigma) return self.__class__.__name__ + s @@ -1751,17 +2177,21 @@ def _setup_size(size, error_msg): def _check_sequence_input(x, name, req_sizes): - msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) + msg = ( + req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) + ) if not isinstance(x, Sequence): raise TypeError("{} should be a sequence of length {}.".format(name, msg)) if len(x) not in req_sizes: raise ValueError("{} should be sequence of length {}.".format(name, msg)) -def _setup_angle(x, name, req_sizes=(2, )): +def _setup_angle(x, name, req_sizes=(2,)): if isinstance(x, numbers.Number): if x < 0: - raise ValueError("If {} is a single number, it must be positive.".format(name)) + raise ValueError( + "If {} is a single number, it must be positive.".format(name) + ) x = [-x, x] else: _check_sequence_input(x, name, req_sizes) @@ -1769,7 +2199,8 @@ def _setup_angle(x, name, req_sizes=(2, )): return [float(d) for d in x] -class RandomInvert(torch.nn.Module): +class RandomInvert(Transform): + stochastic = True """Inverts the colors of the given image randomly with a given probability. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. @@ -1779,11 +2210,15 @@ class RandomInvert(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be inverted. @@ -1791,15 +2226,16 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly color inverted image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.invert(img) return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) -class RandomPosterize(torch.nn.Module): +class RandomPosterize(Transform): + stochastic = True """Posterize the image randomly with a given probability by reducing the number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8, and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1810,12 +2246,16 @@ class RandomPosterize(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, bits, p=0.5): - super().__init__() + def __init__(self, bits, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.bits = bits self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be posterized. @@ -1823,15 +2263,16 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly posterized image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.posterize(img, self.bits) return img def __repr__(self): - return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) + return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p) -class RandomSolarize(torch.nn.Module): +class RandomSolarize(Transform): + stochastic = True """Solarize the image randomly with a given probability by inverting all pixel values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. @@ -1842,12 +2283,16 @@ class RandomSolarize(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, threshold, p=0.5): - super().__init__() + def __init__(self, threshold, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.threshold = threshold self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be solarized. @@ -1855,15 +2300,17 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly solarized image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.solarize(img, self.threshold) return img def __repr__(self): - return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + return self.__class__.__name__ + "(threshold={},p={})".format( + self.threshold, self.p + ) -class RandomAdjustSharpness(torch.nn.Module): +class RandomAdjustSharpness(Transform): """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1874,12 +2321,16 @@ class RandomAdjustSharpness(torch.nn.Module): p (float): probability of the image being color inverted. Default value is 0.5 """ - def __init__(self, sharpness_factor, p=0.5): - super().__init__() + def __init__(self, sharpness_factor, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.sharpness_factor = sharpness_factor self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be sharpened. @@ -1887,15 +2338,17 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly sharpened image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.adjust_sharpness(img, self.sharpness_factor) return img def __repr__(self): - return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + return self.__class__.__name__ + "(sharpness_factor={},p={})".format( + self.sharpness_factor, self.p + ) -class RandomAutocontrast(torch.nn.Module): +class RandomAutocontrast(Transform): """Autocontrast the pixels of the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1905,11 +2358,15 @@ class RandomAutocontrast(torch.nn.Module): p (float): probability of the image being autocontrasted. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be autocontrasted. @@ -1917,15 +2374,15 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly autocontrasted image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.autocontrast(img) return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) -class RandomEqualize(torch.nn.Module): +class RandomEqualize(Transform): """Equalize the histogram of the given image randomly with a given probability. If the image is torch Tensor, it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1935,11 +2392,15 @@ class RandomEqualize(torch.nn.Module): p (float): probability of the image being equalized. Default value is 0.5 """ - def __init__(self, p=0.5): - super().__init__() + def __init__(self, p=0.5, reset_auto=True): + super().__init__(reset_auto=reset_auto) self.p = p - def forward(self, img): + def get_params(self, img): + r = torch.rand(1).item() + return r + + def _call(self, img, r): """ Args: img (PIL Image or Tensor): Image to be equalized. @@ -1947,9 +2408,9 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly equalized image. """ - if torch.rand(1).item() < self.p: + if r < self.p: return F.equalize(img) return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) From e055536d990765c0a7a50d7e8e824a30f62711f2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Aug 2021 11:01:48 +0100 Subject: [PATCH 2/2] removing unnecessary formatting changes --- torchvision/transforms/transforms.py | 452 +++++++++------------------ 1 file changed, 143 insertions(+), 309 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 80338a75824..656168326ab 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -18,46 +18,12 @@ from .functional import InterpolationMode, _interpolation_modes_from_int -__all__ = [ - "Compose", - "GroupTransform", - "ToTensor", - "PILToTensor", - "ConvertImageDtype", - "ToPILImage", - "Normalize", - "Resize", - "Scale", - "CenterCrop", - "Pad", - "Lambda", - "RandomApply", - "RandomChoice", - "RandomOrder", - "RandomCrop", - "RandomHorizontalFlip", - "RandomVerticalFlip", - "RandomResizedCrop", - "RandomSizedCrop", - "FiveCrop", - "TenCrop", - "LinearTransformation", - "ColorJitter", - "RandomRotation", - "RandomAffine", - "Grayscale", - "RandomGrayscale", - "RandomPerspective", - "RandomErasing", - "GaussianBlur", - "InterpolationMode", - "RandomInvert", - "RandomPosterize", - "RandomSolarize", - "RandomAdjustSharpness", - "RandomAutocontrast", - "RandomEqualize", -] +__all__ = ["Compose", "GroupTransform", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", + "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", + "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", + "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", + "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] class Transform(nn.Module): @@ -114,7 +80,6 @@ class Compose(Transform): >>> ]) .. note:: - ===> TODO: check this <=== In order to script the transformations, please use ``torch.nn.Sequential`` as below. >>> transforms = torch.nn.Sequential( @@ -164,11 +129,11 @@ def _call(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + "(" + format_string = self.__class__.__name__ + '(' for t in self.transforms: - format_string += "\n" - format_string += " {0}".format(t) - format_string += "\n)" + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' return format_string @@ -227,7 +192,7 @@ def _call(self, pic): return F.to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + "()" + return self.__class__.__name__ + '()' class PILToTensor(Transform): @@ -247,7 +212,7 @@ def _call(self, pic): return F.pil_to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + "()" + return self.__class__.__name__ + '()' class ConvertImageDtype(Transform): @@ -269,7 +234,7 @@ class ConvertImageDtype(Transform): of the integer ``dtype``. """ - def __init__(self, dtype: torch.dtype, reset_auto=True) -> None: + def __init__(self, dtype: torch.dtype, reset_auto: bool=True) -> None: super().__init__(reset_auto=reset_auto) self.dtype = dtype @@ -311,10 +276,10 @@ def _call(self, pic): return F.to_pil_image(pic, self.mode) def __repr__(self): - format_string = self.__class__.__name__ + "(" + format_string = self.__class__.__name__ + '(' if self.mode is not None: - format_string += "mode={0}".format(self.mode) - format_string += ")" + format_string += 'mode={0}'.format(self.mode) + format_string += ')' return format_string @@ -353,9 +318,7 @@ def _call(self, tensor: Tensor) -> Tensor: return F.normalize(tensor, self.mean, self.std, self.inplace) def __repr__(self): - return self.__class__.__name__ + "(mean={0}, std={1})".format( - self.mean, self.std - ) + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) class Resize(Transform): @@ -402,13 +365,7 @@ class Resize(Transform): """ - def __init__( - self, - size, - interpolation=InterpolationMode.BILINEAR, - max_size=None, - antialias=None, - ): + def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None): super().__init__() if not isinstance(size, (int, Sequence)): raise TypeError("Size should be int or sequence. Got {}".format(type(size))) @@ -436,30 +393,21 @@ def _call(self, img): Returns: PIL Image or Tensor: Rescaled image. """ - return F.resize( - img, self.size, self.interpolation, self.max_size, self.antialias - ) - + return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + def __repr__(self): interpolate_str = self.interpolation.value - return ( - self.__class__.__name__ - + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format( - self.size, interpolate_str, self.max_size, self.antialias - ) - ) + return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( + self.size, interpolate_str, self.max_size, self.antialias) class Scale(Resize): """ Note: This transform is deprecated in favor of Resize. """ - def __init__(self, *args, **kwargs): - warnings.warn( - "The use of the transforms.Scale transform is deprecated, " - + "please use transforms.Resize instead." - ) + warnings.warn("The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") super(Scale, self).__init__(*args, **kwargs) @@ -477,9 +425,7 @@ class CenterCrop(Transform): def __init__(self, size): super().__init__() - self.size = _setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." - ) + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _call(self, img): """ @@ -492,7 +438,7 @@ def _call(self, img): return F.center_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + "(size={0})".format(self.size) + return self.__class__.__name__ + '(size={0})'.format(self.size) class Pad(Transform): @@ -542,16 +488,12 @@ def __init__(self, padding, fill=0, padding_mode="constant"): raise TypeError("Got inappropriate fill arg") if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: - raise ValueError( - "Padding mode should be either constant, edge, reflect or symmetric" - ) + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError( - "Padding must be an int or a 1, 2, or 4 element tuple, not a " - + "{} element tuple".format(len(padding)) - ) - + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + self.padding = padding self.fill = fill self.padding_mode = padding_mode @@ -567,13 +509,9 @@ def _call(self, img): return F.pad(img, self.padding, self.fill, self.padding_mode) def __repr__(self): - return ( - self.__class__.__name__ - + "(padding={0}, fill={1}, padding_mode={2})".format( - self.padding, self.fill, self.padding_mode - ) - ) - + return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ + format(self.padding, self.fill, self.padding_mode) + class Lambda(Transform): """Apply a user-defined lambda as a transform. This transform does not support torchscript. @@ -585,18 +523,14 @@ class Lambda(Transform): def __init__(self, lambd): super().__init__() if not callable(lambd): - raise TypeError( - "Argument lambd should be callable, got {}".format( - repr(type(lambd).__name__) - ) - ) + raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) self.lambd = lambd def _call(self, img): return self.lambd(img) def __repr__(self): - return self.__class__.__name__ + "()" + return self.__class__.__name__ + '()' class RandomTransforms(Transform): @@ -626,11 +560,11 @@ def wipeout_(self): t.wipeout_() def __repr__(self): - format_string = self.__class__.__name__ + "(" + format_string = self.__class__.__name__ + '(' for t in self.transforms: - format_string += "\n" - format_string += " {0}".format(t) - format_string += "\n)" + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' return format_string @@ -675,12 +609,12 @@ def _call(self, img, r): return img def __repr__(self): - format_string = self.__class__.__name__ + "(" - format_string += "\n p={}".format(self.p) + format_string = self.__class__.__name__ + '(' + format_string += '\n p={}'.format(self.p) for t in self.transforms: - format_string += "\n" - format_string += " {0}".format(t) - format_string += "\n)" + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' return format_string @@ -773,16 +707,14 @@ def get_params(self, img: Tensor) -> Tuple[int, int, int, int]: if h + 1 < th or w + 1 < tw: raise ValueError( - "Required crop size {} is larger than input image size {}".format( - (th, tw), (h, w) - ) + "Required crop size {} is larger than input image size {}".format((th, tw), (h, w)) ) if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1,)).item() - j = torch.randint(0, w - tw + 1, size=(1,)).item() + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() return i, j, th, tw def __init__( @@ -796,11 +728,9 @@ def __init__( ): super().__init__(reset_auto=reset_auto) - self.size = tuple( - _setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." - ) - ) + self.size = tuple(_setup_size( + size, error_msg="Please provide only two dimensions (h, w) for size." + )) self.padding = padding self.pad_if_needed = pad_if_needed @@ -837,9 +767,7 @@ def _call(self, img, i, j, h, w): return F.crop(img, i, j, h, w) def __repr__(self): - return self.__class__.__name__ + "(size={0}, padding={1})".format( - self.size, self.padding - ) + return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) class RandomHorizontalFlip(Transform): @@ -873,7 +801,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(p={})".format(self.p) + return self.__class__.__name__ + '(p={})'.format(self.p) class RandomVerticalFlip(Transform): @@ -907,7 +835,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(p={})".format(self.p) + return self.__class__.__name__ + '(p={})'.format(self.p) class RandomPerspective(Transform): @@ -1006,59 +934,27 @@ def get_start_endpoints( half_height = height // 2 half_width = width // 2 topleft = [ - int( - torch.randint( - 0, int(distortion_scale * half_width) + 1, size=(1,) - ).item() - ), - int( - torch.randint( - 0, int(distortion_scale * half_height) + 1, size=(1,) - ).item() - ), + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) ] topright = [ - int( - torch.randint( - width - int(distortion_scale * half_width) - 1, width, size=(1,) - ).item() - ), - int( - torch.randint( - 0, int(distortion_scale * half_height) + 1, size=(1,) - ).item() - ), + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) ] botright = [ - int( - torch.randint( - width - int(distortion_scale * half_width) - 1, width, size=(1,) - ).item() - ), - int( - torch.randint( - height - int(distortion_scale * half_height) - 1, height, size=(1,) - ).item() - ), + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) ] botleft = [ - int( - torch.randint( - 0, int(distortion_scale * half_width) + 1, size=(1,) - ).item() - ), - int( - torch.randint( - height - int(distortion_scale * half_height) - 1, height, size=(1,) - ).item() - ), + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] return startpoints, endpoints def __repr__(self): - return self.__class__.__name__ + "(p={})".format(self.p) + return self.__class__.__name__ + '(p={})'.format(self.p) class RandomResizedCrop(Transform): @@ -1186,10 +1082,10 @@ def _call(self, img, i, j, h, w): def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + "(size={0}".format(self.size) - format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) - format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) - format_string += ", interpolation={0})".format(interpolate_str) + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) return format_string @@ -1197,12 +1093,9 @@ class RandomSizedCrop(RandomResizedCrop): """ Note: This transform is deprecated in favor of RandomResizedCrop. """ - def __init__(self, *args, **kwargs): - warnings.warn( - "The use of the transforms.RandomSizedCrop transform is deprecated, " - + "please use transforms.RandomResizedCrop instead." - ) + warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + + "please use transforms.RandomResizedCrop instead.") super(RandomSizedCrop, self).__init__(*args, **kwargs) @@ -1251,7 +1144,7 @@ def _call(self, img): return F.five_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + "(size={0})".format(self.size) + return self.__class__.__name__ + '(size={0})'.format(self.size) class TenCrop(Transform): @@ -1302,9 +1195,7 @@ def _call(self, img): return F.ten_crop(img, self.size, self.vertical_flip) def __repr__(self): - return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format( - self.size, self.vertical_flip - ) + return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) class LinearTransformation(Transform): @@ -1329,25 +1220,17 @@ class LinearTransformation(Transform): def __init__(self, transformation_matrix, mean_vector): super().__init__() if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError( - "transformation_matrix should be square. Got " - + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()) - ) + raise ValueError("transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError( - "mean_vector should have the same length {}".format(mean_vector.size(0)) - + " as any one of the dimensions of the transformation_matrix [{}]".format( - tuple(transformation_matrix.size()) - ) - ) + raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{}]" + .format(tuple(transformation_matrix.size()))) if transformation_matrix.device != mean_vector.device: - raise ValueError( - "Input tensors should be on the same device. Got {} and {}".format( - transformation_matrix.device, mean_vector.device - ) - ) + raise ValueError("Input tensors should be on the same device. Got {} and {}" + .format(transformation_matrix.device, mean_vector.device)) self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -1381,9 +1264,9 @@ def _call(self, tensor: Tensor) -> Tensor: return tensor def __repr__(self): - format_string = self.__class__.__name__ + "(transformation_matrix=" - format_string += str(self.transformation_matrix.tolist()) + ")" - format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")" + format_string = self.__class__.__name__ + '(transformation_matrix=' + format_string += (str(self.transformation_matrix.tolist()) + ')') + format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') return format_string @@ -1411,22 +1294,18 @@ class ColorJitter(Transform): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, reset_auto=True): super().__init__(reset_auto=reset_auto) - self.brightness = self._check_input(brightness, "brightness") - self.contrast = self._check_input(contrast, "contrast") - self.saturation = self._check_input(saturation, "saturation") + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input( - hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False + hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False ) @torch.jit.unused - def _check_input( - self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True - ): + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): if isinstance(value, numbers.Number): if value < 0: - raise ValueError( - "If {} is a single number, it must be non negative.".format(name) - ) + raise ValueError("If {} is a single number, it must be non negative.".format(name)) value = [center - float(value), center + float(value)] if clip_first_on_zero: value[0] = max(value[0], 0.0) @@ -1434,11 +1313,7 @@ def _check_input( if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError("{} values should be between {}".format(name, bound)) else: - raise TypeError( - "{} should be a single number or a list/tuple with length 2.".format( - name - ) - ) + raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) # if value is 0 or (1., 1.) for brightness/contrast/saturation # or (0., 0.) for hue, do nothing @@ -1482,21 +1357,9 @@ def get_params( hue = self.hue fn_idx = torch.randperm(4) - b = ( - None - if brightness is None - else float(torch.empty(1).uniform_(brightness[0], brightness[1])) - ) - c = ( - None - if contrast is None - else float(torch.empty(1).uniform_(contrast[0], contrast[1])) - ) - s = ( - None - if saturation is None - else float(torch.empty(1).uniform_(saturation[0], saturation[1])) - ) + b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) return fn_idx, b, c, s, h @@ -1531,11 +1394,11 @@ def _call( return img def __repr__(self): - format_string = self.__class__.__name__ + "(" - format_string += "brightness={0}".format(self.brightness) - format_string += ", contrast={0}".format(self.contrast) - format_string += ", saturation={0}".format(self.saturation) - format_string += ", hue={0})".format(self.hue) + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) return format_string @@ -1593,10 +1456,10 @@ def __init__( ) interpolation = _interpolation_modes_from_int(interpolation) - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) if center is not None: - _check_sequence_input(center, "center", req_sizes=(2,)) + _check_sequence_input(center, "center", req_sizes=(2, )) self.center = center @@ -1618,9 +1481,7 @@ def get_params(self, img, degrees: List[float] = []) -> float: """ if not len(degrees): degrees = self.degrees - angle = float( - torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item() - ) + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) return angle def _call(self, img, angle): @@ -1642,14 +1503,14 @@ def _call(self, img, angle): def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees) - format_string += ", interpolation={0}".format(interpolate_str) - format_string += ", expand={0}".format(self.expand) + format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) + format_string += ', interpolation={0}'.format(interpolate_str) + format_string += ', expand={0}'.format(self.expand) if self.center is not None: - format_string += ", center={0}".format(self.center) + format_string += ', center={0}'.format(self.center) if self.fill is not None: - format_string += ", fill={0}".format(self.fill) - format_string += ")" + format_string += ', fill={0}'.format(self.fill) + format_string += ')' return format_string @@ -1723,17 +1584,17 @@ def __init__( ) fill = fillcolor - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) if translate is not None: - _check_sequence_input(translate, "translate", req_sizes=(2,)) + _check_sequence_input(translate, "translate", req_sizes=(2, )) for t in translate: if not (0.0 <= t <= 1.0): raise ValueError("translation values should be between 0 and 1") self.translate = translate if scale is not None: - _check_sequence_input(scale, "scale", req_sizes=(2,)) + _check_sequence_input(scale, "scale", req_sizes=(2, )) for s in scale: if s <= 0: raise ValueError("scale values should be positive") @@ -1778,9 +1639,7 @@ def get_params( if not len(img_size): img_size = F._get_image_size(img) - angle = float( - torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item() - ) + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) if translate is not None: max_dx = float(translate[0] * img_size[0]) max_dy = float(translate[1] * img_size[1]) @@ -1791,9 +1650,7 @@ def get_params( translations = (0, 0) if scale_ranges is not None: - scale = float( - torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item() - ) + scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) else: scale = 1.0 @@ -1832,20 +1689,20 @@ def _call(self, img, angle, translations, scale, shear): ) def __repr__(self): - s = "{name}(degrees={degrees}" + s = '{name}(degrees={degrees}' if self.translate is not None: - s += ", translate={translate}" + s += ', translate={translate}' if self.scale is not None: - s += ", scale={scale}" + s += ', scale={scale}' if self.shear is not None: - s += ", shear={shear}" + s += ', shear={shear}' if self.interpolation != InterpolationMode.NEAREST: - s += ", interpolation={interpolation}" + s += ', interpolation={interpolation}' if self.fill != 0: - s += ", fill={fill}" - s += ")" + s += ', fill={fill}' + s += ')' d = dict(self.__dict__) - d["interpolation"] = self.interpolation.value + d['interpolation'] = self.interpolation.value return s.format(name=self.__class__.__name__, **d) @@ -1880,9 +1737,7 @@ def _call(self, img): return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) def __repr__(self): - return self.__class__.__name__ + "(num_output_channels={0})".format( - self.num_output_channels - ) + return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) class RandomGrayscale(Transform): @@ -1924,7 +1779,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(p={0})".format(self.p) + return self.__class__.__name__ + '(p={0})'.format(self.p) class RandomErasing(Transform): @@ -1966,9 +1821,7 @@ def __init__( ): super().__init__(reset_auto=reset_auto) if not isinstance(value, (numbers.Number, str, tuple, list)): - raise TypeError( - "Argument value should be either a number or str or a sequence" - ) + raise TypeError("Argument value should be either a number or str or a sequence") if isinstance(value, str) and value != "random": raise ValueError("If value is str, it should be 'random'") if not isinstance(scale, (tuple, list)): @@ -2028,8 +1881,8 @@ def get_params_transform( else: v = torch.tensor(value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() + i = torch.randint(0, img_h - h + 1, size=(1, )).item() + j = torch.randint(0, img_w - w + 1, size=(1, )).item() return i, j, h, w, v # Return original image @@ -2041,9 +1894,7 @@ def get_params(self, img): # cast self.value to script acceptable type if isinstance(self.value, (int, float)): - value = [ - self.value, - ] + value = [self.value, ] elif isinstance(self.value, str): value = None elif isinstance(self.value, tuple): @@ -2076,14 +1927,13 @@ def _call(self, img, r, x, y, h, w, v): return img def __repr__(self): - s = "(p={}, ".format(self.p) - s += "scale={}, ".format(self.scale) - s += "ratio={}, ".format(self.ratio) - s += "value={}, ".format(self.value) - s += "inplace={})".format(self.inplace) + s = '(p={}, '.format(self.p) + s += 'scale={}, '.format(self.scale) + s += 'ratio={}, '.format(self.ratio) + s += 'value={}, '.format(self.value) + s += 'inplace={})'.format(self.inplace) return self.__class__.__name__ + s - class GaussianBlur(Transform): stochastic = True """Blurs image with randomly chosen Gaussian blur. @@ -2104,14 +1954,10 @@ class GaussianBlur(Transform): def __init__(self, kernel_size, sigma=(0.1, 2.0), reset_auto=True): super().__init__(reset_auto=reset_auto) - self.kernel_size = _setup_size( - kernel_size, "Kernel size should be a tuple/list of two integers" - ) + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") for ks in self.kernel_size: if ks <= 0 or ks % 2 == 0: - raise ValueError( - "Kernel size value should be an odd and positive number." - ) + raise ValueError("Kernel size value should be an odd and positive number.") if isinstance(sigma, numbers.Number): if sigma <= 0: @@ -2119,13 +1965,9 @@ def __init__(self, kernel_size, sigma=(0.1, 2.0), reset_auto=True): sigma = (sigma, sigma) elif isinstance(sigma, Sequence) and len(sigma) == 2: if not 0.0 < sigma[0] <= sigma[1]: - raise ValueError( - "sigma values should be positive and of the form (min, max)." - ) + raise ValueError("sigma values should be positive and of the form (min, max).") else: - raise ValueError( - "sigma should be a single number or a list/tuple with length 2." - ) + raise ValueError("sigma should be a single number or a list/tuple with length 2.") self.sigma = sigma @@ -2158,8 +2000,8 @@ def _call(self, img: Tensor, sigma: float) -> Tensor: return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) def __repr__(self): - s = "(kernel_size={}, ".format(self.kernel_size) - s += "sigma={})".format(self.sigma) + s = '(kernel_size={}, '.format(self.kernel_size) + s += 'sigma={})'.format(self.sigma) return self.__class__.__name__ + s @@ -2177,21 +2019,17 @@ def _setup_size(size, error_msg): def _check_sequence_input(x, name, req_sizes): - msg = ( - req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) - ) + msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) if not isinstance(x, Sequence): raise TypeError("{} should be a sequence of length {}.".format(name, msg)) if len(x) not in req_sizes: raise ValueError("{} should be sequence of length {}.".format(name, msg)) -def _setup_angle(x, name, req_sizes=(2,)): +def _setup_angle(x, name, req_sizes=(2, )): if isinstance(x, numbers.Number): if x < 0: - raise ValueError( - "If {} is a single number, it must be positive.".format(name) - ) + raise ValueError("If {} is a single number, it must be positive.".format(name)) x = [-x, x] else: _check_sequence_input(x, name, req_sizes) @@ -2231,7 +2069,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(p={})".format(self.p) + return self.__class__.__name__ + '(p={})'.format(self.p) class RandomPosterize(Transform): @@ -2268,7 +2106,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p) + return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) class RandomSolarize(Transform): @@ -2305,9 +2143,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(threshold={},p={})".format( - self.threshold, self.p - ) + return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) class RandomAdjustSharpness(Transform): @@ -2343,9 +2179,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(sharpness_factor={},p={})".format( - self.sharpness_factor, self.p - ) + return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) class RandomAutocontrast(Transform): @@ -2379,7 +2213,7 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(p={})".format(self.p) + return self.__class__.__name__ + '(p={})'.format(self.p) class RandomEqualize(Transform): @@ -2413,4 +2247,4 @@ def _call(self, img, r): return img def __repr__(self): - return self.__class__.__name__ + "(p={})".format(self.p) + return self.__class__.__name__ + '(p={})'.format(self.p)