From 1f37a1ee8e1f66ad0da58110b40cca2e415b229b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 26 Feb 2022 01:34:47 +0000 Subject: [PATCH 1/9] Replace get_image_size/num_channels with get_image_dims --- torchvision/prototype/transforms/_augment.py | 5 +- .../prototype/transforms/_auto_augment.py | 98 +++++++++---------- torchvision/prototype/transforms/_geometry.py | 2 +- .../transforms/functional/__init__.py | 2 +- .../transforms/functional/_geometry.py | 19 ++-- .../prototype/transforms/functional/_utils.py | 29 ++---- 6 files changed, 72 insertions(+), 83 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index ce198d39b33..3295e9e7f73 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -41,8 +41,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - img_c = F.get_image_num_channels(image) - img_w, img_h = F.get_image_size(image) + img_c, img_h, img_w = F.get_image_dims(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -138,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - W, H = F.get_image_size(image) + _, H, W = F.get_image_dims(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 7eae25a681e..9fdca914aad 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -47,7 +47,7 @@ def dispatch( return input image = query_image(sample) - num_channels = F.get_image_num_channels(image) + num_channels, _, _ = F.get_image_dims(image) fill = self.fill if isinstance(fill, (int, float)): @@ -158,25 +158,25 @@ def transform(input: Any) -> Any: class AutoAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { - "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: 
torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( - lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) .round() .int(), False, ), - "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, image_size: None, False), - "Equalize": (lambda num_bins, image_size: None, False), - "Invert": (lambda num_bins, image_size: None, False), + "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + "Invert": (lambda num_bins, height, width: None, False), } def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None: @@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - image_size = F.get_image_size(image) + _, height, width = F.get_image_dims(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] - magnitudes = magnitudes_fn(10, image_size) + magnitudes = magnitudes_fn(10, height, width) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) if signed and torch.rand(()) <= 0.5: @@ -303,25 +303,25 @@ def forward(self, *inputs: Any) -> Any: class RandAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { - "Identity": (lambda num_bins, image_size: None, False), - "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Identity": (lambda num_bins, height, width: None, False), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, 
num_bins), True), "Posterize": ( - lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) .round() .int(), False, ), - "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, image_size: None, False), - "Equalize": (lambda num_bins, image_size: None, False), + "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), } def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None: @@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - image_size = F.get_image_size(image) + _, height, width = F.get_image_dims(image) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) + magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) if signed and torch.rand(()) <= 0.5: @@ -354,25 +354,25 @@ def forward(self, *inputs: Any) -> Any: class TrivialAugmentWide(_AutoAugmentBase): _AUGMENTATION_SPACE = { - "Identity": (lambda num_bins, image_size: None, False), - "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), - "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), - "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Identity": (lambda num_bins, height, width: None, False), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), "Posterize": ( - lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) + lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) .round() .int(), False, ), - 
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, image_size: None, False), - "Equalize": (lambda num_bins, image_size: None, False), + "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), } def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): @@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - image_size = F.get_image_size(image) + _, height, width = F.get_image_dims(image) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) + magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) if signed and torch.rand(()) <= 0.5: diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4c9d9192ac8..ed8814f06dc 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -109,7 +109,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - width, height = F.get_image_size(image) + _, height, width = F.get_image_dims(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index c487aba7fa2..f32ab15bea1 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,5 +1,5 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import get_image_size, get_image_num_channels # usort: skip +from ._utils import get_image_dims # usort: skip from ._meta_conversion import ( convert_bounding_box_format, convert_image_color_space_tensor, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 76564fdd54d..c190ff5086c 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -5,7 +5,7 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import InterpolationMode -from torchvision.prototype.transforms.functional import get_image_size +from torchvision.prototype.transforms.functional import get_image_dims from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix @@ -40,8 +40,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - old_width, old_height = _FT.get_image_size(image) - num_channels = _FT.get_image_num_channels(image) + num_channels, old_height, old_width = get_image_dims(image) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), @@ -143,7 +142,7 @@ def affine_image_tensor( center_f = [0.0, 0.0] if center is not None: - width, height = get_image_size(img) + _, height, width = get_image_dims(img) # Center values should be in pixel 
coordinates but translated such that (0, 0) corresponds to image center. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -169,7 +168,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - width, height = get_image_size(img) + _, height, width = get_image_dims(img) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -186,7 +185,7 @@ def rotate_image_tensor( ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: - width, height = get_image_size(img) + _, height, width = get_image_dims(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -262,13 +261,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_image_dims(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_tensor(img, padding_ltrb, fill=0) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_image_dims(img) if crop_width == image_width and crop_height == image_height: return img @@ -278,13 +277,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_image_dims(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_pil(img, padding_ltrb, fill=0) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_image_dims(img) if crop_width == image_width and crop_height == image_height: return img diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 07235d63716..69db47c7fb5 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,29 +1,20 @@ -from typing import Tuple, Union, cast +from typing import Tuple, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: +def get_image_dims(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: if isinstance(image, features.Image): + channels = image.num_channels height, width = image.image_size - return width, height elif isinstance(image, torch.Tensor): - return cast(Tuple[int, int], tuple(_FT.get_image_size(image))) - if isinstance(image, PIL.Image.Image): - return cast(Tuple[int, int], tuple(_FP.get_image_size(image))) + channels = 1 if image.ndim == 2 else image.shape[-3] + height, width = image.shape[-2:] + elif 
isinstance(image, PIL.Image.Image): + channels = len(image.getbands()) + width, height = image.size else: - raise TypeError(f"unable to get image size from object of type {type(image).__name__}") - - -def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: - if isinstance(image, features.Image): - return image.num_channels - elif isinstance(image, torch.Tensor): - return _FT.get_image_num_channels(image) - if isinstance(image, PIL.Image.Image): - return cast(int, _FP.get_image_num_channels(image)) - else: - raise TypeError(f"unable to get num channels from object of type {type(image).__name__}") + raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}") + return channels, height, width From ead34b3ea23ae207f3667569fe3ce22c42b3c841 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 26 Feb 2022 01:58:09 +0000 Subject: [PATCH 2/9] Reduce verbosity --- .../prototype/transforms/_auto_augment.py | 90 +++++++++---------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 9fdca914aad..fe8bfb55735 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -158,25 +158,25 @@ def transform(input: Any) -> Any: class AutoAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { - "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), True), - "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), True), - "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), - "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) .round() .int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, height, width: None, False), - "Equalize": 
(lambda num_bins, height, width: None, False), - "Invert": (lambda num_bins, height, width: None, False), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + "Invert": (lambda num_bins, image_size: None, False), } def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None: @@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] - magnitudes = magnitudes_fn(10, height, width) + magnitudes = magnitudes_fn(10, (height, width)) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) if signed and torch.rand(()) <= 0.5: @@ -303,25 +303,25 @@ def forward(self, *inputs: Any) -> Any: class RandAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { - "Identity": (lambda num_bins, height, width: None, False), - "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), True), - "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), True), - "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), - "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Identity": (lambda num_bins, image_size: None, False), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) .round() .int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, height, width: None, False), - "Equalize": (lambda num_bins, height, width: None, False), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), } def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None: @@ -339,7 +339,7 @@ def 
forward(self, *inputs: Any) -> Any: for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) + magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width)) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) if signed and torch.rand(()) <= 0.5: @@ -354,25 +354,25 @@ def forward(self, *inputs: Any) -> Any: class TrivialAugmentWide(_AutoAugmentBase): _AUGMENTATION_SPACE = { - "Identity": (lambda num_bins, height, width: None, False), - "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), - "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), - "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), - "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), - "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True), - "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), - "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), - "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), - "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Identity": (lambda num_bins, image_size: None, False), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) .round() .int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, height, width: None, False), - "Equalize": (lambda num_bins, height, width: None, False), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), } def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): @@ -387,7 +387,7 @@ def forward(self, *inputs: Any) -> Any: transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) + magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width)) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) if signed and torch.rand(()) <= 0.5: From c8f7b14ee778daabe4a6d2494e8d2c6db5bb9e8c Mon Sep 17 00:00:00 2001 From: 
Vasilis Vryniotis Date: Sat, 26 Feb 2022 10:47:28 +0000 Subject: [PATCH 3/9] Fix JIT-scriptability --- torchvision/prototype/transforms/functional/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c190ff5086c..75d5eb9ed37 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -40,7 +40,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - num_channels, old_height, old_width = get_image_dims(image) + num_channels, old_height, old_width = image.shape[-3:] batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), From f7513a4d4ee48f9eecec6c06e43446fcac06272e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 26 Feb 2022 11:24:11 +0000 Subject: [PATCH 4/9] Refactoring --- torchvision/prototype/transforms/_augment.py | 6 +++--- .../prototype/transforms/_auto_augment.py | 10 +++++----- torchvision/prototype/transforms/_geometry.py | 4 ++-- torchvision/prototype/transforms/_utils.py | 16 ++++++++++++++- .../transforms/functional/__init__.py | 1 - .../transforms/functional/_geometry.py | 17 ++++++++-------- .../prototype/transforms/functional/_utils.py | 20 ------------------- torchvision/transforms/functional_pil.py | 11 +++++++++- torchvision/transforms/functional_tensor.py | 7 +++++++ 9 files changed, 50 insertions(+), 42 deletions(-) delete mode 100644 torchvision/prototype/transforms/functional/_utils.py diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3295e9e7f73..b448e75a154 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from ._utils import query_image +from ._utils import query_image, get_image_dims class RandomErasing(Transform): @@ -41,7 +41,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - img_c, img_h, img_w = F.get_image_dims(image) + img_c, img_h, img_w = get_image_dims(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -137,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - _, H, W = F.get_image_dims(image) + _, H, W = get_image_dims(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index fe8bfb55735..cb75f4301ac 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F from torchvision.prototype.utils._internal import apply_recursively -from ._utils import query_image +from ._utils import query_image, get_image_dims K = TypeVar("K") V = TypeVar("V") @@ -47,7 +47,7 @@ def dispatch( return input image = query_image(sample) - num_channels, _, _ = F.get_image_dims(image) + num_channels, *_ = get_image_dims(image) fill = self.fill if isinstance(fill, (int, float)): @@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any: sample = 
inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - _, height, width = F.get_image_dims(image) + _, height, width = get_image_dims(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -334,7 +334,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - _, height, width = F.get_image_dims(image) + _, height, width = get_image_dims(image) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -383,7 +383,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - _, height, width = F.get_image_dims(image) + _, height, width = get_image_dims(image) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index ed8814f06dc..a54c65a67e5 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -8,7 +8,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int -from ._utils import query_image +from ._utils import query_image, get_image_dims class HorizontalFlip(Transform): @@ -109,7 +109,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - _, height, width = F.get_image_dims(image) + _, height, width = get_image_dims(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 24d794a2cb4..16c1ed40a7f 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,9 +1,10 @@ -from typing import Any, Optional, Union +from typing import Any, Optional, Tuple, Union import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: @@ -17,3 +18,16 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Ima return next(query_recursively(fn, sample)) except StopIteration: raise TypeError("No image was found in the sample") + + +def get_image_dims(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: + if isinstance(image, features.Image): + channels = image.num_channels + height, width = image.image_size + elif isinstance(image, torch.Tensor): + channels, height, width = _FT.get_image_dims(image) + elif isinstance(image, PIL.Image.Image): + channels, height, width = _FP.get_image_dims(image) + else: + raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}") + return channels, height, width diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index f32ab15bea1..bf44c9aec7c 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,5 +1,4 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import get_image_dims # usort: 
skip from ._meta_conversion import ( convert_bounding_box_format, convert_image_color_space_tensor, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 75d5eb9ed37..e4af798b481 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -5,7 +5,6 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import InterpolationMode -from torchvision.prototype.transforms.functional import get_image_dims from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix @@ -40,7 +39,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - num_channels, old_height, old_width = image.shape[-3:] + num_channels, old_height, old_width = _FT.get_image_dims(image) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), @@ -142,7 +141,7 @@ def affine_image_tensor( center_f = [0.0, 0.0] if center is not None: - _, height, width = get_image_dims(img) + _, height, width = _FT.get_image_dims(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -168,7 +167,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - _, height, width = get_image_dims(img) + _, height, width = _FP.get_image_dims(img) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -185,7 +184,7 @@ def rotate_image_tensor( ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: - _, height, width = get_image_dims(img) + _, height, width = _FT.get_image_dims(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
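        # E.g. for a 100x60 image with center=(50, 30):
        # center_f = [50 - 100 * 0.5, 30 - 60 * 0.5] == [0.0, 0.0], the image center.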
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -261,13 +260,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = get_image_dims(img) + _, image_height, image_width = _FT.get_image_dims(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_tensor(img, padding_ltrb, fill=0) - _, image_height, image_width = get_image_dims(img) + _, image_height, image_width = _FT.get_image_dims(img) if crop_width == image_width and crop_height == image_height: return img @@ -277,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = get_image_dims(img) + _, image_height, image_width = _FP.get_image_dims(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_pil(img, padding_ltrb, fill=0) - _, image_height, image_width = get_image_dims(img) + _, image_height, image_width = _FP.get_image_dims(img) if crop_width == image_width and crop_height == image_height: return img diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py deleted file mode 100644 index 69db47c7fb5..00000000000 --- a/torchvision/prototype/transforms/functional/_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Tuple, Union - -import PIL.Image -import torch -from torchvision.prototype import features - - -def get_image_dims(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: - if isinstance(image, features.Image): - channels = image.num_channels - height, width = image.image_size - elif isinstance(image, torch.Tensor): - channels = 1 if image.ndim == 2 else image.shape[-3] - height, width = image.shape[-2:] - elif isinstance(image, PIL.Image.Image): - channels = len(image.getbands()) - width, height = image.size - else: - raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}") - return channels, height, width diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 01c321dabfa..0564755f36f 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -20,6 +20,15 @@ def _is_pil_image(img: Any) -> bool: return isinstance(img, Image.Image) +@torch.jit.unused +def get_image_dims(img: Any) -> List[int]: + if _is_pil_image(img): + channels = len(img.getbands()) + width, height = img.size + return [channels, height, width] + raise TypeError(f"Unexpected type {type(img)}") + + @torch.jit.unused def get_image_size(img: Any) -> List[int]: if _is_pil_image(img): @@ -30,7 +39,7 @@ def get_image_size(img: Any) -> List[int]: @torch.jit.unused def get_image_num_channels(img: Any) -> int: if _is_pil_image(img): - return 1 if img.mode == "L" else 3 + return len(img.getbands()) raise TypeError(f"Unexpected type {type(img)}") diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 
fae681b3aa9..7616dcc86f7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -21,6 +21,13 @@ def _assert_threshold(img: Tensor, threshold: float) -> None: raise TypeError("Threshold should be less than bound of img.") +def get_image_dims(img: Tensor) -> List[int]: + _assert_image_tensor(img) + channels = 1 if img.ndim == 2 else img.shape[-3] + height, width = img.shape[-2:] + return [channels, height, width] + + def get_image_size(img: Tensor) -> List[int]: # Returns (w, h) of tensor image _assert_image_tensor(img) From 04a3eabc326c6f2acde8b3a72690b5e1c017dbb1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 26 Feb 2022 11:51:47 +0000 Subject: [PATCH 5/9] More refactoring --- torchvision/prototype/transforms/_augment.py | 6 +++--- .../prototype/transforms/_auto_augment.py | 10 +++++----- torchvision/prototype/transforms/_geometry.py | 4 ++-- torchvision/prototype/transforms/_utils.py | 6 +++--- .../prototype/transforms/functional/_geometry.py | 16 ++++++++-------- .../prototype/transforms/functional/_misc.py | 5 ++++- torchvision/transforms/functional_pil.py | 2 +- torchvision/transforms/functional_tensor.py | 2 +- 8 files changed, 27 insertions(+), 24 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index b448e75a154..5862c6a06dc 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from ._utils import query_image, get_image_dims +from ._utils import query_image, get_image_dimensions class RandomErasing(Transform): @@ -41,7 +41,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - img_c, img_h, img_w = get_image_dims(image) + img_c, img_h, img_w = get_image_dimensions(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -137,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - _, H, W = get_image_dims(image) + _, H, W = get_image_dimensions(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index cb75f4301ac..78cbf958ccd 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F from torchvision.prototype.utils._internal import apply_recursively -from ._utils import query_image, get_image_dims +from ._utils import query_image, get_image_dimensions K = TypeVar("K") V = TypeVar("V") @@ -47,7 +47,7 @@ def dispatch( return input image = query_image(sample) - num_channels, *_ = get_image_dims(image) + num_channels, *_ = get_image_dimensions(image) fill = self.fill if isinstance(fill, (int, float)): @@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - _, height, width = get_image_dims(image) + _, height, width = get_image_dimensions(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -334,7 +334,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) 
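        # query_image() recursively walks the (possibly nested) sample and returns
        # the first PIL image, image tensor, or features.Image it finds (see _utils.py).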
- _, height, width = get_image_dims(image) + _, height, width = get_image_dimensions(image) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -383,7 +383,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - _, height, width = get_image_dims(image) + _, height, width = get_image_dimensions(image) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index a54c65a67e5..c58f26a0e06 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -8,7 +8,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int -from ._utils import query_image, get_image_dims +from ._utils import query_image, get_image_dimensions class HorizontalFlip(Transform): @@ -109,7 +109,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - _, height, width = get_image_dims(image) + _, height, width = get_image_dimensions(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 16c1ed40a7f..b66a732b740 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -20,14 +20,14 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Ima raise TypeError("No image was found in the sample") -def get_image_dims(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: +def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: if isinstance(image, features.Image): channels = image.num_channels height, width = image.image_size elif isinstance(image, torch.Tensor): - channels, height, width = _FT.get_image_dims(image) + channels, height, width = _FT.get_dimensions(image) elif isinstance(image, PIL.Image.Image): - channels, height, width = _FP.get_image_dims(image) + channels, height, width = _FP.get_dimensions(image) else: raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}") return channels, height, width diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index e4af798b481..476de370b6f 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -39,7 +39,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - num_channels, old_height, old_width = _FT.get_image_dims(image) + num_channels, old_height, old_width = _FT.get_dimensions(image) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), @@ -141,7 +141,7 @@ def affine_image_tensor( center_f = [0.0, 0.0] if center is not None: - _, height, width = _FT.get_image_dims(img) + _, height, width = _FT.get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
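        # zip pairs each center coordinate with the matching size: x with width, y with height.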
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -167,7 +167,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - _, height, width = _FP.get_image_dims(img) + _, height, width = _FP.get_dimensions(img) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -184,7 +184,7 @@ def rotate_image_tensor( ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: - _, height, width = _FT.get_image_dims(img) + _, height, width = _FT.get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -260,13 +260,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = _FT.get_image_dims(img) + _, image_height, image_width = _FT.get_dimensions(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_tensor(img, padding_ltrb, fill=0) - _, image_height, image_width = _FT.get_image_dims(img) + _, image_height, image_width = _FT.get_dimensions(img) if crop_width == image_width and crop_height == image_height: return img @@ -276,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = _FP.get_image_dims(img) + _, image_height, image_width = _FP.get_dimensions(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_pil(img, padding_ltrb, fill=0) - _, image_height, image_width = _FP.get_image_dims(img) + _, image_height, image_width = _FP.get_dimensions(img) if crop_width == image_width and crop_height == image_height: return img diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index fd0507cca4d..b72635193c8 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -2,10 +2,13 @@ import PIL.Image import torch -from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP from torchvision.transforms.functional import to_tensor, to_pil_image +get_dimensions_image_tensor = _FT.get_dimensions +get_dimensions_image_pil = _FP.get_dimensions + normalize_image_tensor = _FT.normalize diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 0564755f36f..5e383ff3286 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -21,7 +21,7 @@ def _is_pil_image(img: Any) -> bool: @torch.jit.unused -def get_image_dims(img: Any) -> List[int]: +def get_dimensions(img: Any) -> List[int]: if _is_pil_image(img): channels = len(img.getbands()) width, height = img.size diff --git 
a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 7616dcc86f7..b51a3a4d8b0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -21,7 +21,7 @@ def _assert_threshold(img: Tensor, threshold: float) -> None: raise TypeError("Threshold should be less than bound of img.") -def get_image_dims(img: Tensor) -> List[int]: +def get_dimensions(img: Tensor) -> List[int]: _assert_image_tensor(img) channels = 1 if img.ndim == 2 else img.shape[-3] height, width = img.shape[-2:] From 253e543cc7b90c612554fab5c24e1a3bbe3fad0e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 28 Feb 2022 09:44:29 +0000 Subject: [PATCH 6/9] Replace all _FP/_FT direct calls. --- torchvision/prototype/transforms/__init__.py | 2 +- .../{_meta_conversion.py => _meta.py} | 0 torchvision/prototype/transforms/_utils.py | 6 +++--- .../transforms/functional/__init__.py | 2 +- .../transforms/functional/_geometry.py | 18 +++++++++--------- .../{_meta_conversion.py => _meta.py} | 4 ++++ .../prototype/transforms/functional/_misc.py | 5 +---- 7 files changed, 19 insertions(+), 18 deletions(-) rename torchvision/prototype/transforms/{_meta_conversion.py => _meta.py} (100%) rename torchvision/prototype/transforms/functional/{_meta_conversion.py => _meta.py} (96%) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 73235720d58..73d45097a93 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -8,7 +8,7 @@ from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop -from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace +from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta.py similarity index 100% rename from torchvision/prototype/transforms/_meta_conversion.py rename to torchvision/prototype/transforms/_meta.py diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index b66a732b740..93d73a33fca 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -4,7 +4,7 @@ import torch from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: @@ -25,9 +25,9 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im channels = image.num_channels height, width = image.image_size elif isinstance(image, torch.Tensor): - channels, height, width = _FT.get_dimensions(image) + channels, height, width = get_dimensions_image_tensor(image) elif isinstance(image, PIL.Image.Image): - channels, height, width = _FP.get_dimensions(image) + channels, height, width = 
get_dimensions_image_pil(image) else: raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}") return channels, height, width diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index bf44c9aec7c..e3fe60a7919 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,5 +1,5 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._meta_conversion import ( +from ._meta import ( convert_bounding_box_format, convert_image_color_space_tensor, convert_image_color_space_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 476de370b6f..73c8ac5d0de 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -8,7 +8,7 @@ from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix -from ._meta_conversion import convert_bounding_box_format +from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil horizontal_flip_image_tensor = _FT.hflip @@ -39,7 +39,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - num_channels, old_height, old_width = _FT.get_dimensions(image) + num_channels, old_height, old_width = get_dimensions_image_tensor(image) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), @@ -141,7 +141,7 @@ def affine_image_tensor( center_f = [0.0, 0.0] if center is not None: - _, height, width = _FT.get_dimensions(img) + _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -167,7 +167,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - _, height, width = _FP.get_dimensions(img) + _, height, width = get_dimensions_image_pil(img) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -184,7 +184,7 @@ def rotate_image_tensor( ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: - _, height, width = _FT.get_dimensions(img) + _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -260,13 +260,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = _FT.get_dimensions(img) + _, image_height, image_width = get_dimensions_image_tensor(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_tensor(img, padding_ltrb, fill=0) - _, image_height, image_width = _FT.get_dimensions(img) + _, image_height, image_width = get_dimensions_image_tensor(img) if crop_width == image_width and crop_height == image_height: return img @@ -276,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - _, image_height, image_width = _FP.get_dimensions(img) + _, image_height, image_width = get_dimensions_image_pil(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_pil(img, padding_ltrb, fill=0) - _, image_height, image_width = _FP.get_dimensions(img) + _, image_height, image_width = get_dimensions_image_pil(img) if crop_width == image_width and crop_height == image_height: return img diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta.py similarity index 96% rename from torchvision/prototype/transforms/functional/_meta_conversion.py rename to torchvision/prototype/transforms/functional/_meta.py index b260beaa361..6ecb5aff257 100644 --- a/torchvision/prototype/transforms/functional/_meta_conversion.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -4,6 +4,10 @@ from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +get_dimensions_image_tensor = _FT.get_dimensions +get_dimensions_image_pil = _FP.get_dimensions + + def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: xyxy = xywh.clone() xyxy[..., 2:] += xyxy[..., :2] diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index b72635193c8..fd0507cca4d 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -2,13 +2,10 @@ import PIL.Image import torch -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import to_tensor, to_pil_image -get_dimensions_image_tensor = _FT.get_dimensions -get_dimensions_image_pil = _FP.get_dimensions - normalize_image_tensor = _FT.normalize From 957a6ac009921047c22c03b51a72b87c7f6d0029 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 28 Feb 2022 10:20:55 +0000 Subject: [PATCH 7/9] Remove usages of get_image_size and get_image_num_channels from code-base. 
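A minimal sketch of the call-site pattern this commit applies across the
code-base; the tensor shape is illustrative, and it assumes the earlier
patches in this series are already applied:

    import torch
    from torchvision.transforms import functional as F

    img = torch.rand(3, 224, 224)  # CHW image tensor

    # Old API, two calls, with get_image_size returning [width, height]:
    # w, h = F.get_image_size(img)
    # c = F.get_image_num_channels(img)

    # New API, one call returning [channels, height, width]:
    c, h, w = F.get_dimensions(img)
    assert (c, h, w) == (3, 224, 224)

Note that the axis order flips from the old (width, height) convention to
(height, width), which is why the hunks below also reorder the unpacked
variables at each call site.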
--- docs/source/transforms.rst | 1 + references/classification/transforms.py | 2 +- references/detection/transforms.py | 10 +++--- test/test_functional_tensor.py | 4 ++- torchvision/prototype/transforms/_utils.py | 1 + torchvision/transforms/autoaugment.py | 30 +++++++++-------- torchvision/transforms/functional.py | 37 +++++++++++++++------ torchvision/transforms/functional_tensor.py | 14 ++++---- torchvision/transforms/transforms.py | 20 ++++++----- 9 files changed, 73 insertions(+), 46 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index cae53728b96..5909b68966b 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -270,6 +270,7 @@ you can use a functional transform to build transform classes with custom behavi erase five_crop gaussian_blur + get_dimensions get_image_num_channels get_image_size hflip diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 892b4e7e6c0..e72cd67fbfd 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -141,7 +141,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: # Implemented as on cutmix paper, page 12 (with minor corrections on typos). lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - W, H = F.get_image_size(batch) + _, H, W = F.get_dimensions(batch) r_x = torch.randint(W, (1,)) r_y = torch.randint(H, (1,)) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 342d491ceb1..3a9bad78c25 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -34,7 +34,7 @@ def forward( if torch.rand(1) < self.p: image = F.hflip(image) if target is not None: - width, _ = F.get_image_size(image) + _, _, width = F.get_dimensions(image) target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]] if "masks" in target: target["masks"] = target["masks"].flip(-1) @@ -107,7 +107,7 @@ def forward( elif image.ndimension() == 2: image = image.unsqueeze(0) - orig_w, orig_h = F.get_image_size(image) + _, orig_h, orig_w = F.get_dimensions(image) while True: # sample an option @@ -192,7 +192,7 @@ def forward( if torch.rand(1) >= self.p: return image, target - orig_w, orig_h = F.get_image_size(image) + _, orig_h, orig_w = F.get_dimensions(image) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) @@ -270,7 +270,7 @@ def forward( image = self._contrast(image) if r[6] < self.p: - channels = F.get_image_num_channels(image) + channels, _, _ = F.get_dimensions(image) permutation = torch.randperm(channels) is_pil = F._is_pil_image(image) @@ -317,7 +317,7 @@ def forward( elif image.ndimension() == 2: image = image.unsqueeze(0) - orig_width, orig_height = F.get_image_size(image) + _, orig_height, orig_width = F.get_dimensions(image) r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) new_width = int(self.target_size[1] * r) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 0ac559565b7..3bdf0cfe34e 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels]) +@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions]) def test_image_sizes(device, fn): script_F = 
torch.jit.script(fn) @@ -1020,7 +1020,9 @@ def test_resized_crop(device, mode): @pytest.mark.parametrize( "func, args", [ + (F_t.get_dimensions, ()), (F_t.get_image_size, ()), + (F_t.get_image_num_channels, ()), (F_t.vflip, ()), (F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)), diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 93d73a33fca..d8677d451c8 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -4,6 +4,7 @@ import torch from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively + from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index d820e5126a1..b535e960611 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -225,8 +225,8 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), "Color": (torch.linspace(0.0, 0.9, num_bins), True), @@ -260,15 +260,16 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: AutoAugmented image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] transform_id, probs, signs = self.get_params(len(self.policies)) - op_meta = self._augmentation_space(10, F.get_image_size(img)) + op_meta = self._augmentation_space(10, (height, width)) for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]): if probs[i] <= p: magnitudes, signed = op_meta[op_name] @@ -323,8 +324,8 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, "Identity": (torch.tensor(0.0), False), "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), "Color": (torch.linspace(0.0, 0.9, num_bins), True), @@ -344,13 +345,14 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: Transformed image. 
""" fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] - op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) + op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width)) for _ in range(self.num_ops): op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] @@ -429,9 +431,10 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: Transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] @@ -508,8 +511,8 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), - "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), @@ -547,16 +550,17 @@ def forward(self, orig_img: Tensor) -> Tensor: PIL Image or Tensor: Transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(orig_img) if isinstance(orig_img, Tensor): img = orig_img if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] else: img = self._pil_to_tensor(orig_img) - op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) + op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width)) orig_dims = list(img.shape) batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 41c6ceada03..3ed1f9bfb48 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -59,6 +59,23 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode: _is_pil_image = F_pil._is_pil_image +def get_dimensions(img: Tensor) -> List[int]: + """Returns the dimensions of an image as [channels, height, width]. + + Args: + img (PIL Image or Tensor): The image to be checked. + + Returns: + List[int]: The image dimensions. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(get_dimensions) + if isinstance(img, torch.Tensor): + return F_t.get_dimensions(img) + + return F_pil.get_dimensions(img) + + def get_image_size(img: Tensor) -> List[int]: """Returns the size of an image as [width, height]. 
@@ -512,7 +529,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: output_size = (output_size[0], output_size[0]) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions(img) crop_height, crop_width = output_size if crop_width > image_width or crop_height > image_height: @@ -523,7 +540,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, ] img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0 - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions(img) if crop_width == image_width and crop_height == image_height: return img @@ -721,7 +738,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten if len(size) != 2: raise ValueError("Please provide only two dimensions (h, w) for size.") - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions(img) crop_height, crop_width = size if crop_width > image_width or crop_height > image_height: msg = "Requested crop size {} is bigger than input size {}" @@ -1047,9 +1064,9 @@ def rotate( center_f = [0.0, 0.0] if center is not None: - img_size = get_image_size(img) + _, height, width = get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. @@ -1167,22 +1184,22 @@ def affine( if center is not None and not isinstance(center, (list, tuple)): raise TypeError("Argument center should be a sequence") - img_size = get_image_size(img) + _, height, width = get_dimensions(img) if not isinstance(img, torch.Tensor): - # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # center = (width * 0.5 + 0.5, height * 0.5 + 0.5) # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - center = [img_size[0] * 0.5, img_size[1] * 0.5] + center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) center_f = [0.0, 0.0] if center is not None: - img_size = get_image_size(img) + _, height, width = get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
- center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index b51a3a4d8b0..21d52c62a2b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -62,7 +62,7 @@ def _max_value(dtype: torch.dtype) -> float: def _assert_channels(img: Tensor, permitted: List[int]) -> None: - c = get_image_num_channels(img) + c = get_dimensions(img)[0] if c not in permitted: raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}") @@ -134,7 +134,7 @@ def hflip(img: Tensor) -> Tensor: def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: _assert_image_tensor(img) - w, h = get_image_size(img) + _, h, w = get_dimensions(img) right = left + width bottom = top + height @@ -182,7 +182,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: _assert_image_tensor(img) _assert_channels(img, [3, 1]) - c = get_image_num_channels(img) + c = get_dimensions(img)[0] dtype = img.dtype if torch.is_floating_point(img) else torch.float32 if c == 3: mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) @@ -202,7 +202,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: _assert_image_tensor(img) _assert_channels(img, [1, 3]) - if get_image_num_channels(img) == 1: # Match PIL behaviour + if get_dimensions(img)[0] == 1: # Match PIL behaviour return img orig_dtype = img.dtype @@ -229,7 +229,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: _assert_channels(img, [1, 3]) - if get_image_num_channels(img) == 1: # Match PIL behaviour + if get_dimensions(img)[0] == 1: # Match PIL behaviour return img return _blend(img, rgb_to_grayscale(img), saturation_factor) @@ -458,7 +458,7 @@ def resize( if antialias and interpolation not in ["bilinear", "bicubic"]: raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") - w, h = get_image_size(img) + _, h, w = get_dimensions(img) if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge short, long = (w, h) if w <= h else (h, w) @@ -525,7 +525,7 @@ def _assert_grid_transform_inputs( warnings.warn("Argument fill should be either int, float, tuple or list") # Check fill - num_channels = get_image_num_channels(img) + num_channels = get_dimensions(img)[0] if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): msg = ( "The number of elements in 'fill' cannot broadcast to match the number of " diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 9fc79c1d8cc..37556cd4984 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -628,7 +628,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
""" - w, h = F.get_image_size(img) + _, h, w = F.get_dimensions(img) th, tw = output_size if h + 1 < th or w + 1 < tw: @@ -663,7 +663,7 @@ def forward(self, img): if self.padding is not None: img = F.pad(img, self.padding, self.fill, self.padding_mode) - width, height = F.get_image_size(img) + _, height, width = F.get_dimensions(img) # pad the width if needed if self.pad_if_needed and width < self.size[1]: padding = [self.size[1] - width, 0] @@ -793,14 +793,14 @@ def forward(self, img): """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels else: fill = [float(f) for f in fill] if torch.rand(1) < self.p: - width, height = F.get_image_size(img) startpoints, endpoints = self.get_params(width, height, self.distortion_scale) return F.perspective(img, startpoints, endpoints, self.interpolation, fill) return img @@ -910,7 +910,7 @@ def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ - width, height = F.get_image_size(img) + _, height, width = F.get_dimensions(img) area = height * width log_ratio = torch.log(torch.tensor(ratio)) @@ -1339,9 +1339,10 @@ def forward(self, img): PIL Image or Tensor: Rotated image. """ fill = self.fill + channels, _, _ = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels else: fill = [float(f) for f in fill] angle = self.get_params(self.degrees) @@ -1519,13 +1520,14 @@ def forward(self, img): PIL Image or Tensor: Affine transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels else: fill = [float(f) for f in fill] - img_size = F.get_image_size(img) + img_size = [width, height] # flip for keeping BC on get_params call ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) @@ -1608,7 +1610,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly grayscaled image. """ - num_output_channels = F.get_image_num_channels(img) + num_output_channels, _, _ = F.get_dimensions(img) if torch.rand(1) < self.p: return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) return img From 8c35243b9450664a6b6a94359f29b900141ecb1f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 28 Feb 2022 10:46:20 +0000 Subject: [PATCH 8/9] Fix JIT issues --- torchvision/prototype/transforms/functional/_geometry.py | 4 ++-- torchvision/transforms/autoaugment.py | 6 +++--- torchvision/transforms/functional.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 73c8ac5d0de..080fe5da891 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -143,7 +143,7 @@ def affine_image_tensor( if center is not None: _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
- center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) @@ -186,7 +186,7 @@ def rotate_image_tensor( if center is not None: _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index b535e960611..357e5bf250e 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -220,7 +220,7 @@ def _get_policies( else: raise ValueError(f"The provided policy {policy} is not recognized.") - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: return { # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), @@ -318,7 +318,7 @@ def __init__( self.interpolation = interpolation self.fill = fill - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: return { # op_name: (magnitudes, signed) "Identity": (torch.tensor(0.0), False), @@ -506,7 +506,7 @@ def __init__( self.interpolation = interpolation self.fill = fill - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: s = { # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 3ed1f9bfb48..b2fc3f44f55 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1066,7 +1066,7 @@ def rotate( if center is not None: _, height, width = get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. @@ -1199,7 +1199,7 @@ def affine( if center is not None: _, height, width = get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) From 971d6e53e7dc7bb1100702482e1ac53a46e10335 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 28 Feb 2022 11:33:08 +0000 Subject: [PATCH 9/9] Adding missing assertion. 
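A short sketch of the behavioral change (assuming _assert_image_tensor keeps
its current TypeError for inputs that are not valid image tensors):

    import torch
    from torchvision.transforms import functional_tensor as F_t

    print(F_t.get_image_num_channels(torch.rand(3, 8, 8)))  # 3, unchanged

    # A 1-D tensor is not a valid image; it now trips the shared assertion
    # up front, consistent with get_image_size and get_dimensions, instead
    # of falling through to this function's own ndim check:
    try:
        F_t.get_image_num_channels(torch.rand(8))
    except TypeError as exc:
        print(exc)  # Tensor is not a torch image.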
--- torchvision/transforms/functional_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 21d52c62a2b..18b2c721f4e 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -35,6 +35,7 @@ def get_image_size(img: Tensor) -> List[int]: def get_image_num_channels(img: Tensor) -> int: + _assert_image_tensor(img) if img.ndim == 2: return 1 elif img.ndim > 2:
        return img.shape[-3]

    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
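With the series applied, a quick end-to-end check of the unified API (a
hedged sketch; the image sizes are arbitrary):

    import torch
    import PIL.Image
    from torchvision.transforms import functional as F

    tensor_img = torch.rand(3, 32, 48)        # CHW tensor
    pil_img = PIL.Image.new("RGB", (48, 32))  # PIL sizes are (width, height)

    # Both dispatch paths agree on [channels, height, width]:
    assert F.get_dimensions(tensor_img) == [3, 32, 48]
    assert F.get_dimensions(pil_img) == [3, 32, 48]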