From dec9d45a1e1451103c07c50e3a9425c421314fee Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 16 Jan 2023 22:11:02 +0100 Subject: [PATCH 1/2] reinstate get_params for RandomErasing --- torchvision/prototype/transforms/_augment.py | 60 ++++++++++++++++---- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3160770a09d..d3969e0fb75 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -12,7 +12,7 @@ from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform from ._transform import _RandomApplyTransform -from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size +from .utils import get_dimensions, has_any, is_simple_tensor, query_chw, query_spatial_size class RandomErasing(_RandomApplyTransform): @@ -53,19 +53,24 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - img_c, img_h, img_w = query_chw(flat_inputs) - - if self.value is not None and not (len(self.value) in (1, img_c)): + @staticmethod + def _get_params_internal( + img_c: int, + img_h: int, + img_w: int, + scale: Tuple[float, float], + log_ratio: torch.Tensor, + value: Optional[List[float]] = None, + ) -> Tuple[int, int, int, int, Optional[torch.Tensor]]: + if value is not None and not (len(value) in (1, img_c)): raise ValueError( f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" ) area = img_h * img_w - log_ratio = self._log_ratio for _ in range(10): - erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] @@ -78,18 +83,49 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: if not (h < img_h and w < img_w): continue - if self.value is None: + if value is None: v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() else: - v = torch.tensor(self.value)[:, None, None] + v = torch.tensor(value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() + i = int(torch.randint(0, img_h - h + 1, size=(1,))) + j = int(torch.randint(0, img_w - w + 1, size=(1,))) break else: i, j, h, w, v = 0, 0, img_h, img_w, None - return dict(i=i, j=j, h=h, w=w, v=v) + return i, j, h, w, v + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + img_c, img_h, img_w = query_chw(flat_inputs) + return dict( + zip( + "ijhwv", + self._get_params_internal( + img_c, + img_h, + img_w, + self.scale, + self._log_ratio, + self.value, # type: ignore[arg-type] + ), + ) + ) + + @staticmethod + def get_params( + image: torch.Tensor, + scale: Tuple[float, float], + ratio: Tuple[float, float], + value: Optional[List[float]] = None, + ) -> Tuple[int, int, int, int, torch.Tensor]: + img_c, img_h, img_w = get_dimensions(image) + i, j, h, w, v = RandomErasing._get_params_internal( + img_c, img_h, img_w, scale, torch.log(torch.tensor(ratio)), value + ) + if v is None: + v = image + return i, j, h, w, v def _transform( self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] From 21946f7f2a76113d9d4157bc6db496a0cbf2ef56 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jan 2023 08:31:23 +0100 Subject: [PATCH 2/2] use fake self namespace --- torchvision/prototype/transforms/_augment.py | 51 ++++++-------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index d3969e0fb75..3618ecc285d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,6 +1,7 @@ import math import numbers import warnings +from types import SimpleNamespace from typing import Any, cast, Dict, List, Optional, Tuple, Union import PIL.Image @@ -12,7 +13,7 @@ from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform from ._transform import _RandomApplyTransform -from .utils import get_dimensions, has_any, is_simple_tensor, query_chw, query_spatial_size +from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size class RandomErasing(_RandomApplyTransform): @@ -53,24 +54,19 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) - @staticmethod - def _get_params_internal( - img_c: int, - img_h: int, - img_w: int, - scale: Tuple[float, float], - log_ratio: torch.Tensor, - value: Optional[List[float]] = None, - ) -> Tuple[int, int, int, int, Optional[torch.Tensor]]: - if value is not None and not (len(value) in (1, img_c)): + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + img_c, img_h, img_w = query_chw(flat_inputs) + + if self.value is not None and not (len(self.value) in (1, img_c)): raise ValueError( f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" ) area = img_h * img_w + log_ratio = self._log_ratio for _ in range(10): - erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] @@ -83,10 +79,10 @@ def _get_params_internal( if not (h < img_h and w < img_w): continue - if value is None: + if self.value is None: v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() else: - v = torch.tensor(value)[:, None, None] + v = torch.tensor(self.value)[:, None, None] i = int(torch.randint(0, img_h - h + 1, size=(1,))) j = int(torch.randint(0, img_w - w + 1, size=(1,))) @@ -94,23 +90,7 @@ def _get_params_internal( else: i, j, h, w, v = 0, 0, img_h, img_w, None - return i, j, h, w, v - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - img_c, img_h, img_w = query_chw(flat_inputs) - return dict( - zip( - "ijhwv", - self._get_params_internal( - img_c, - img_h, - img_w, - self.scale, - self._log_ratio, - self.value, # type: ignore[arg-type] - ), - ) - ) + return dict(i=i, j=j, h=h, w=w, v=v) @staticmethod def get_params( @@ -119,13 +99,12 @@ def get_params( ratio: Tuple[float, float], value: Optional[List[float]] = None, ) -> Tuple[int, int, int, int, torch.Tensor]: - img_c, img_h, img_w = get_dimensions(image) - i, j, h, w, v = RandomErasing._get_params_internal( - img_c, img_h, img_w, scale, torch.log(torch.tensor(ratio)), value - ) + self = SimpleNamespace(scale=scale, _log_ratio=torch.log(torch.tensor(ratio)), value=value) + params = RandomErasing._get_params(self, [image]) # type: ignore[arg-type] + v = params["v"] if v is None: v = image - return i, j, h, w, v + return params["i"], params["j"], params["h"], params["w"], v def _transform( self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]