diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index ffabb91471c..69238760be5 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -234,7 +234,18 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
     def _transform(
         self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+    ) -> Tuple[
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+    ]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 543a3bbc725..0a50c956f8e 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1964,8 +1964,6 @@ def five_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(five_crop)

-    # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
-    #  `ten_crop`
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return five_crop_image_tensor(inpt, size)
     elif isinstance(inpt, datapoints.Image):
@@ -1983,40 +1981,90 @@ def five_crop(
     )


-def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
-    tl, tr, bl, br, center = five_crop_image_tensor(image, size)
+def ten_crop_image_tensor(
+    image: torch.Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    non_flipped = five_crop_image_tensor(image, size)

     if vertical_flip:
         image = vertical_flip_image_tensor(image)
     else:
         image = horizontal_flip_image_tensor(image)

-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(image, size)
+    flipped = five_crop_image_tensor(image, size)

-    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+    return non_flipped + flipped


 @torch.jit.unused
-def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
-    tl, tr, bl, br, center = five_crop_image_pil(image, size)
+def ten_crop_image_pil(
+    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+]:
+    non_flipped = five_crop_image_pil(image, size)

     if vertical_flip:
         image = vertical_flip_image_pil(image)
     else:
         image = horizontal_flip_image_pil(image)

-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(image, size)
-
-    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
-
-
-def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    flipped = five_crop_image_pil(image, size)
+
+    return non_flipped + flipped
+
+
+def ten_crop_video(
+    video: torch.Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
     return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)


 def ten_crop(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
-) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+) -> Tuple[
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(ten_crop)
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index beeb02cd915..c5b2a71d0d7 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -827,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
     return tl, tr, bl, br, center


-def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
+def ten_crop(
+    img: Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Generate ten cropped images from the given image.
     Crop the given image into four corners and the central crop plus the flipped version of
     these (horizontal flipping is used by default).
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index d7858353be9..90cb0374eee 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -1049,7 +1049,7 @@ class TenCrop(torch.nn.Module):
     Example:
          >>> transform = Compose([
-         >>>    TenCrop(size), # this is a list of PIL Images
+         >>>    TenCrop(size), # this is a tuple of PIL Images
          >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
          >>> ])
          >>> #In your test loop you can do the following:
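For reference, a minimal usage sketch (not part of the patch) of the stable `torchvision.transforms.functional.ten_crop` after this change; the input tensor and crop size below are made-up values, but the stacking pattern mirrors the `TenCrop` docstring example above:

```python
import torch
from torchvision.transforms import functional as F

# Made-up example input: a single 3-channel 256x256 image tensor.
img = torch.rand(3, 256, 256)

# ten_crop is now annotated as returning a fixed-size tuple of ten crops
# (four corners + center, plus their flipped counterparts).
crops = F.ten_crop(img, [224, 224])
assert isinstance(crops, tuple) and len(crops) == 10

# Downstream code that iterated over the old list still works, e.g. batching:
batch = torch.stack([crop for crop in crops])  # shape: [10, 3, 224, 224]
```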