diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3160770a09d..3618ecc285d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,6 +1,7 @@ import math import numbers import warnings +from types import SimpleNamespace from typing import Any, cast, Dict, List, Optional, Tuple, Union import PIL.Image @@ -83,14 +84,28 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: else: v = torch.tensor(self.value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() + i = int(torch.randint(0, img_h - h + 1, size=(1,))) + j = int(torch.randint(0, img_w - w + 1, size=(1,))) break else: i, j, h, w, v = 0, 0, img_h, img_w, None return dict(i=i, j=j, h=h, w=w, v=v) + @staticmethod + def get_params( + image: torch.Tensor, + scale: Tuple[float, float], + ratio: Tuple[float, float], + value: Optional[List[float]] = None, + ) -> Tuple[int, int, int, int, torch.Tensor]: + self = SimpleNamespace(scale=scale, _log_ratio=torch.log(torch.tensor(ratio)), value=value) + params = RandomErasing._get_params(self, [image]) # type: ignore[arg-type] + v = params["v"] + if v is None: + v = image + return params["i"], params["j"], params["h"], params["w"], v + def _transform( self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] ) -> Union[datapoints.ImageType, datapoints.VideoType]: