Skip to content

Commit 21946f7

Browse files
committed
use fake self namespace
1 parent 35860ff commit 21946f7

File tree

1 file changed

+15
-36
lines changed

1 file changed

+15
-36
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import numbers
33
import warnings
4+
from types import SimpleNamespace
45
from typing import Any, cast, Dict, List, Optional, Tuple, Union
56

67
import PIL.Image
@@ -12,7 +13,7 @@
1213
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
1314

1415
from ._transform import _RandomApplyTransform
15-
from .utils import get_dimensions, has_any, is_simple_tensor, query_chw, query_spatial_size
16+
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
1617

1718

1819
class RandomErasing(_RandomApplyTransform):
@@ -53,24 +54,19 @@ def __init__(
5354

5455
self._log_ratio = torch.log(torch.tensor(self.ratio))
5556

56-
@staticmethod
57-
def _get_params_internal(
58-
img_c: int,
59-
img_h: int,
60-
img_w: int,
61-
scale: Tuple[float, float],
62-
log_ratio: torch.Tensor,
63-
value: Optional[List[float]] = None,
64-
) -> Tuple[int, int, int, int, Optional[torch.Tensor]]:
65-
if value is not None and not (len(value) in (1, img_c)):
57+
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
58+
img_c, img_h, img_w = query_chw(flat_inputs)
59+
60+
if self.value is not None and not (len(self.value) in (1, img_c)):
6661
raise ValueError(
6762
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
6863
)
6964

7065
area = img_h * img_w
7166

67+
log_ratio = self._log_ratio
7268
for _ in range(10):
73-
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
69+
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
7470
aspect_ratio = torch.exp(
7571
torch.empty(1).uniform_(
7672
log_ratio[0], # type: ignore[arg-type]
@@ -83,34 +79,18 @@ def _get_params_internal(
8379
if not (h < img_h and w < img_w):
8480
continue
8581

86-
if value is None:
82+
if self.value is None:
8783
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
8884
else:
89-
v = torch.tensor(value)[:, None, None]
85+
v = torch.tensor(self.value)[:, None, None]
9086

9187
i = int(torch.randint(0, img_h - h + 1, size=(1,)))
9288
j = int(torch.randint(0, img_w - w + 1, size=(1,)))
9389
break
9490
else:
9591
i, j, h, w, v = 0, 0, img_h, img_w, None
9692

97-
return i, j, h, w, v
98-
99-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
100-
img_c, img_h, img_w = query_chw(flat_inputs)
101-
return dict(
102-
zip(
103-
"ijhwv",
104-
self._get_params_internal(
105-
img_c,
106-
img_h,
107-
img_w,
108-
self.scale,
109-
self._log_ratio,
110-
self.value, # type: ignore[arg-type]
111-
),
112-
)
113-
)
93+
return dict(i=i, j=j, h=h, w=w, v=v)
11494

11595
@staticmethod
11696
def get_params(
@@ -119,13 +99,12 @@ def get_params(
11999
ratio: Tuple[float, float],
120100
value: Optional[List[float]] = None,
121101
) -> Tuple[int, int, int, int, torch.Tensor]:
122-
img_c, img_h, img_w = get_dimensions(image)
123-
i, j, h, w, v = RandomErasing._get_params_internal(
124-
img_c, img_h, img_w, scale, torch.log(torch.tensor(ratio)), value
125-
)
102+
self = SimpleNamespace(scale=scale, _log_ratio=torch.log(torch.tensor(ratio)), value=value)
103+
params = RandomErasing._get_params(self, [image]) # type: ignore[arg-type]
104+
v = params["v"]
126105
if v is None:
127106
v = image
128-
return i, j, h, w, v
107+
return params["i"], params["j"], params["h"], params["w"], v
129108

130109
def _transform(
131110
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]

0 commit comments

Comments
 (0)