Skip to content

Commit 69c2a08

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] [proto] Fix kernel passthrough and types of Normalize (#6490)
Summary: * Fix pass-through and supported types of Normalize * update error message on kernel * Fix linter. * Fix the tests. * Update type. * Update type. * Remove unnecessary tests for bboxes and masks. Reviewed By: NicolasHug Differential Revision: D39131017 fbshipit-source-id: d7847f4974b083395022471ba33dff0dbf7c9c55
1 parent 0ace8b5 commit 69c2a08

File tree

3 files changed

+15
-21
lines changed

3 files changed

+15
-21
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,22 +1844,12 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
18441844

18451845
def test_midlevel_normalize_output_type():
18461846
inpt = torch.rand(1, 3, 32, 32)
1847-
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
1847+
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
18481848
assert isinstance(output, torch.Tensor)
18491849
torch.testing.assert_close(inpt - 0.5, output)
18501850

1851-
inpt = make_segmentation_mask()
1852-
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
1853-
assert isinstance(output, features.SegmentationMask)
1854-
torch.testing.assert_close(inpt, output)
1855-
1856-
inpt = make_bounding_box(format="XYXY")
1857-
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
1858-
assert isinstance(output, features.BoundingBox)
1859-
torch.testing.assert_close(inpt, output)
1860-
18611851
inpt = make_image(color_space=features.ColorSpace.RGB)
1862-
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
1852+
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
18631853
assert isinstance(output, torch.Tensor)
18641854
torch.testing.assert_close(inpt - 0.5, output)
18651855

torchvision/prototype/transforms/_misc.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from typing import Any, Callable, Dict, List, Sequence, Type, Union
2+
from typing import Any, Callable, Dict, Sequence, Type, Union
33

44
import PIL.Image
55

@@ -10,6 +10,8 @@
1010
from torchvision.prototype.transforms._utils import query_bounding_box
1111
from torchvision.transforms.transforms import _setup_size
1212

13+
from ._utils import is_simple_tensor
14+
1315

1416
class Identity(Transform):
1517
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -91,10 +93,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
9193

9294

9395
class Normalize(Transform):
94-
def __init__(self, mean: List[float], std: List[float]):
96+
_transformed_types = (PIL.Image.Image, features.Image, is_simple_tensor)
97+
98+
def __init__(self, mean: Sequence[float], std: Sequence[float]):
9599
super().__init__()
96-
self.mean = mean
97-
self.std = std
100+
self.mean = list(mean)
101+
self.std = list(std)
98102

99103
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
100104
return F.normalize(inpt, mean=self.mean, std=self.std)

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
normalize_image_tensor = _FT.normalize
1515

1616

17-
def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
18-
if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
19-
return inpt
20-
elif isinstance(inpt, PIL.Image.Image):
21-
raise TypeError("Unsupported input type")
17+
def normalize(
18+
inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
19+
) -> DType:
20+
if not isinstance(inpt, torch.Tensor):
21+
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
2222
else:
2323
# Image instance after normalization is not Image anymore due to unknown data range
2424
# Thus we return Tensor for input Image

0 commit comments

Comments
 (0)