diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index ed47fc1addf..e67bfe8bf88 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1844,22 +1844,12 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): def test_midlevel_normalize_output_type(): inpt = torch.rand(1, 3, 32, 32) - output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) + output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) assert isinstance(output, torch.Tensor) torch.testing.assert_close(inpt - 0.5, output) - inpt = make_segmentation_mask() - output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) - assert isinstance(output, features.SegmentationMask) - torch.testing.assert_close(inpt, output) - - inpt = make_bounding_box(format="XYXY") - output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) - assert isinstance(output, features.BoundingBox) - torch.testing.assert_close(inpt, output) - inpt = make_image(color_space=features.ColorSpace.RGB) - output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)) + output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) assert isinstance(output, torch.Tensor) torch.testing.assert_close(inpt - 0.5, output) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 5c5e1e89da7..37ca7857674 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Callable, Dict, List, Sequence, Type, Union +from typing import Any, Callable, Dict, Sequence, Type, Union import PIL.Image @@ -10,6 +10,8 @@ from torchvision.prototype.transforms._utils import query_bounding_box from torchvision.transforms.transforms import _setup_size +from ._utils import is_simple_tensor + class Identity(Transform): def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -91,10 +93,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class Normalize(Transform): - def __init__(self, mean: List[float], std: List[float]): + _transformed_types = (PIL.Image.Image, features.Image, is_simple_tensor) + + def __init__(self, mean: Sequence[float], std: Sequence[float]): super().__init__() - self.mean = mean - self.std = std + self.mean = list(mean) + self.std = list(std) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.normalize(inpt, mean=self.mean, std=self.std) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index c20c5760d4a..77956902bd5 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -14,11 +14,11 @@ normalize_image_tensor = _FT.normalize -def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType: - if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image): - return inpt - elif isinstance(inpt, PIL.Image.Image): - raise TypeError("Unsupported input type") +def normalize( + inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False +) -> DType: + if not isinstance(inpt, torch.Tensor): + raise TypeError(f"img should be Tensor Image. Got {type(inpt)}") else: # Image instance after normalization is not Image anymore due to unknown data range # Thus we return Tensor for input Image