[proto] Use the proper _transformed_types in all Transforms and eliminate unnecessary dispatching #6494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from Aug 25, 2022
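The whole PR hangs on one mechanism: the base `Transform` dispatches `_transform()` only to inputs matching the class-level `_transformed_types` tuple (plain types, or predicate callables such as `is_simple_tensor`), and passes everything else through untouched. That is what lets every `if isinstance(...): ... else: return inpt` block below collapse to a single call. A minimal sketch of that gating — the names mirror the prototype API, but the bodies are simplified stand-ins, not torchvision's actual implementation:

```python
# A minimal sketch of type-gated dispatch; simplified stand-ins, not
# torchvision's real base class.
from typing import Any, Callable, Dict, Tuple, Type, Union

TypeOrPredicate = Union[Type, Callable[[Any], bool]]


def _matches(value: Any, check: TypeOrPredicate) -> bool:
    # A plain type is tested with isinstance(); a callable acts as a
    # predicate -- the role is_simple_tensor plays in the real code.
    if isinstance(check, type):
        return isinstance(value, check)
    return check(value)


class Transform:
    # Subclasses declare what they handle instead of re-checking types
    # inside _transform().
    _transformed_types: Tuple[TypeOrPredicate, ...] = (object,)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def __call__(self, inpt: Any) -> Any:
        # Non-matching inputs pass through untouched; this is where the
        # deleted "else: return inpt" branches now live, once, for everyone.
        if any(_matches(inpt, check) for check in self._transformed_types):
            return self._transform(inpt, params={})
        return inpt


class Upper(Transform):
    _transformed_types = (str,)

    def _transform(self, inpt: str, params: Dict[str, Any]) -> str:
        return inpt.upper()


assert Upper()("abc") == "ABC"  # str is dispatched to _transform
assert Upper()(42) == 42        # int passes through unchanged
```

With the gate in the base class, a subclass only declares what it handles; unsupported inputs survive a pipeline unchanged.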
test/test_prototype_transforms.py (6 changes: 3 additions & 3 deletions)
@@ -1042,10 +1042,10 @@ def test__transform(self, inpt_type, mocker):
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImageTensor()
         transform(inpt)
-        if inpt_type in (features.BoundingBox, str, int):
+        if inpt_type in (features.BoundingBox, features.Image, str, int):
             assert fn.call_count == 0
         else:
-            fn.assert_called_once_with(inpt, copy=transform.copy)
+            fn.assert_called_once_with(inpt)
 
 
 class TestToImagePIL:
@@ -1059,7 +1059,7 @@ def test__transform(self, inpt_type, mocker):
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImagePIL()
         transform(inpt)
-        if inpt_type in (features.BoundingBox, str, int):
+        if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)
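These tests lean on `mocker.MagicMock(spec=inpt_type)`: a spec'd mock passes `isinstance()` checks against its spec class, so the transform's type gating treats it like a real input of that type. A standalone illustration with `unittest.mock` and a hypothetical stand-in class:

```python
# A spec'd MagicMock reports the spec class to isinstance(), which is why the
# dispatch tests above can probe type gating without building real images.
# BoundingBoxLike is a hypothetical stand-in, not a torchvision class.
from unittest import mock


class BoundingBoxLike:
    pass


inpt = mock.MagicMock(spec=BoundingBoxLike)
assert isinstance(inpt, BoundingBoxLike)  # the mock passes the spec's check
assert not isinstance(inpt, str)          # and no one else's
```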
test/test_prototype_transforms_functional.py (19 changes: 5 additions & 14 deletions)
@@ -1867,39 +1867,30 @@ def test_midlevel_normalize_output_type():
 @pytest.mark.parametrize(
     "inpt",
     [
         torch.randint(0, 256, size=(3, 32, 32)),
         127 * np.ones((32, 32, 3), dtype="uint8"),
         PIL.Image.new("RGB", (32, 32), 122),
     ],
 )
-@pytest.mark.parametrize("copy", [True, False])
-def test_to_image_tensor(inpt, copy):
-    output = F.to_image_tensor(inpt, copy=copy)
+def test_to_image_tensor(inpt):
+    output = F.to_image_tensor(inpt)
     assert isinstance(output, torch.Tensor)
 
     assert np.asarray(inpt).sum() == output.sum().item()
 
-    if isinstance(inpt, PIL.Image.Image) and not copy:
+    if isinstance(inpt, PIL.Image.Image):
         # we can't check this option
         # as PIL -> numpy is always copying
         return
 
-    if isinstance(inpt, PIL.Image.Image):
-        inpt.putpixel((0, 0), 11)
-    else:
-        inpt[0, 0, 0] = 11
-    if copy:
-        assert output[0, 0, 0] != 11
-    else:
-        assert output[0, 0, 0] == 11
+    inpt[0, 0, 0] = 11
+    assert output[0, 0, 0] == 11
 
 
 @pytest.mark.parametrize(
     "inpt",
     [
         torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
         127 * np.ones((32, 32, 3), dtype="uint8"),
         PIL.Image.new("RGB", (32, 32), 122),
     ],
 )
 @pytest.mark.parametrize("mode", [None, "RGB"])
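The surviving assertions depend on `torch.from_numpy()` returning a view that shares memory with the source array (while the PIL path always copies, hence the early `return`). A quick standalone check of that premise:

```python
# The mutated-pixel assertion works because torch.from_numpy() returns a view
# over the same buffer as the source array.
import numpy as np
import torch

arr = 127 * np.ones((32, 32, 3), dtype="uint8")
out = torch.from_numpy(arr)

arr[0, 0, 0] = 11
assert out[0, 0, 0] == 11  # shared memory: the write shows up in the tensor
```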
torchvision/prototype/transforms/_deprecated.py (32 changes: 11 additions & 21 deletions)
@@ -3,6 +3,7 @@
 
 import numpy as np
 import PIL.Image
+import torch
 import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
 from torchvision.prototype.features import ColorSpace
@@ -15,9 +16,7 @@
 
 
 class ToTensor(Transform):
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (PIL.Image.Image, np.ndarray)
 
     def __init__(self) -> None:
         warnings.warn(
@@ -26,32 +25,26 @@ def __init__(self) -> None:
         )
         super().__init__()
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
-            return _F.to_tensor(inpt)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+        return _F.to_tensor(inpt)
 
 
 class PILToTensor(Transform):
+    _transformed_types = (PIL.Image.Image,)
+
     def __init__(self) -> None:
         warnings.warn(
             "The transform `PILToTensor()` is deprecated and will be removed in a future release. "
             "Instead, please use `transforms.ToImageTensor()`."
         )
         super().__init__()
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, PIL.Image.Image):
-            return _F.pil_to_tensor(inpt)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+        return _F.pil_to_tensor(inpt)
 
 
 class ToPILImage(Transform):
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features.Image, np.ndarray)
 
     def __init__(self, mode: Optional[str] = None) -> None:
         warnings.warn(
@@ -61,11 +54,8 @@ def __init__(self, mode: Optional[str] = None) -> None:
         super().__init__()
         self.mode = mode
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
-            return _F.to_pil_image(inpt, mode=self.mode)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
+        return _F.to_pil_image(inpt, mode=self.mode)
 
 
 class Grayscale(Transform):
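With `_transformed_types` doing the gating, `ToTensor._transform` is now a bare call into the `to_tensor` kernel. A quick check of that kernel's semantics (uint8 HWC array in, float CHW tensor in [0, 1] out) — this uses the stable `torchvision.transforms.functional` namespace, on the assumption that the `_F` alias in this file points at the same stable API the deprecation warnings steer users toward:

```python
# The kernel ToTensor._transform now calls unconditionally, checked against
# the stable torchvision.transforms.functional namespace.
import numpy as np
import torchvision.transforms.functional as _F

arr = 255 * np.ones((4, 4, 3), dtype="uint8")
t = _F.to_tensor(arr)

assert t.shape == (3, 4, 4)   # HWC ndarray becomes a CHW tensor
assert float(t.max()) == 1.0  # uint8 values are rescaled into [0, 1]
```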
torchvision/prototype/transforms/_meta.py (21 changes: 8 additions & 13 deletions)
@@ -11,37 +11,32 @@
 
 
 class ConvertBoundingBoxFormat(Transform):
+    _transformed_types = (features.BoundingBox,)
+
     def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
         super().__init__()
         if isinstance(format, str):
             format = features.BoundingBoxFormat[format]
         self.format = format
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.BoundingBox):
-            output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
-            return features.BoundingBox.new_like(inpt, output, format=params["format"])
-        else:
-            return inpt
+        output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
+        return features.BoundingBox.new_like(inpt, output, format=params["format"])
 
 
 class ConvertImageDtype(Transform):
+    _transformed_types = (is_simple_tensor, features.Image)
+
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = convert_image_dtype(inpt, dtype=self.dtype)
-            return features.Image.new_like(inpt, output, dtype=self.dtype)
-        elif is_simple_tensor(inpt):
-            return convert_image_dtype(inpt, dtype=self.dtype)
-        else:
-            return inpt
+        output = convert_image_dtype(inpt, dtype=self.dtype)
+        return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
 
 
 class ConvertColorSpace(Transform):
+    # F.convert_color_space does NOT handle `_Feature`'s in general
+    _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
 
     def __init__(
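The folded `ConvertImageDtype._transform` works because `convert_image_dtype()` accepts any image tensor; only the re-wrapping into `features.Image` depends on the input kind. Its value-scaling semantics, checked against the stable API:

```python
# convert_image_dtype() rescales values when changing dtype, so uint8's max
# maps to 1.0 in float32; the collapsed _transform relies on this working for
# plain tensors and Image features alike.
import torch
from torchvision.transforms.functional import convert_image_dtype

img_u8 = torch.full((3, 2, 2), 255, dtype=torch.uint8)
img_f32 = convert_image_dtype(img_u8, dtype=torch.float32)

assert img_f32.dtype == torch.float32
assert float(img_f32.max()) == 1.0  # 255 (uint8 max) maps to 1.0
```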
torchvision/prototype/transforms/_type_conversion.py (39 changes: 12 additions & 27 deletions)
@@ -11,12 +11,11 @@
 
 
 class DecodeImage(Transform):
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.EncodedImage):
-            output = F.decode_image_with_pil(inpt)
-            return features.Image(output)
-        else:
-            return inpt
+    _transformed_types = (features.EncodedImage,)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
+        output = F.decode_image_with_pil(inpt)
+        return features.Image(output)
 
 
 class LabelToOneHot(Transform):
@@ -41,33 +40,19 @@ def extra_repr(self) -> str:
 
 
 class ToImageTensor(Transform):
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
 
-    def __init__(self, *, copy: bool = False) -> None:
-        super().__init__()
-        self.copy = copy
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
-            output = F.to_image_tensor(inpt, copy=self.copy)
-            return features.Image(output)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
+        output = F.to_image_tensor(inpt)
+        return features.Image(output)
 
 
 class ToImagePIL(Transform):
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features.Image, np.ndarray)
 
     def __init__(self, *, mode: Optional[str] = None) -> None:
         super().__init__()
         self.mode = mode
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
-            return F.to_image_pil(inpt, mode=self.mode)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
+        return F.to_image_pil(inpt, mode=self.mode)
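A note on why `is_simple_tensor` appears in these tuples as a callable rather than `torch.Tensor` as a type: feature classes subclass `Tensor`, so a bare `isinstance` check could not exclude them, while a predicate can. A sketch with a hypothetical `Feature` stand-in (not the real `features._Feature`):

```python
# Feature classes subclass torch.Tensor, so listing torch.Tensor in
# _transformed_types would also capture every feature; a predicate can
# exclude the subclasses. Feature here is a hypothetical stand-in.
import torch


class Feature(torch.Tensor):
    pass


def is_simple_tensor(inpt: object) -> bool:
    # True for plain tensors only, not for feature subclasses.
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Feature)


plain = torch.zeros(3, 4, 4)
feat = torch.zeros(3, 4, 4).as_subclass(Feature)

assert is_simple_tensor(plain)
assert not is_simple_tensor(feat)
```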
torchvision/prototype/transforms/functional/_type_conversion.py (23 changes: 4 additions & 19 deletions)
@@ -1,5 +1,5 @@
 import unittest.mock
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Tuple, Union
 
 import numpy as np
 import PIL.Image
@@ -21,26 +21,11 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
     return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]
 
 
-def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
+def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
     if isinstance(image, np.ndarray):
-        image = torch.from_numpy(image)
+        return torch.from_numpy(image)
 
-    if isinstance(image, torch.Tensor):
-        if copy:
-            return image.clone()
-        else:
-            return image
-
     return _F.pil_to_tensor(image)
 
 
-def to_image_pil(
-    image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
-) -> PIL.Image.Image:
-    if isinstance(image, PIL.Image.Image):
-        if mode != image.mode:
-            return image.convert(mode)
-        else:
-            return image
-
-    return _F.to_pil_image(image, mode=mode)
+to_image_pil = _F.to_pil_image
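After this change the functional kernels are thin: `to_image_tensor` only converts ndarray and PIL inputs, and `to_image_pil` becomes a bare alias for the stable `to_pil_image`. A small round-trip exercise of the two remaining code paths, using the stable kernels they delegate to:

```python
# Round-trip through the stable kernels the prototype functions delegate to:
# pil_to_tensor for the PIL path, to_pil_image for the alias.
import PIL.Image
import torch
import torchvision.transforms.functional as _F

pil = PIL.Image.new("RGB", (32, 32), 122)
t = _F.pil_to_tensor(pil)      # PIL path of to_image_tensor: always a copy
assert t.shape == (3, 32, 32) and t.dtype == torch.uint8
assert int(t[0, 0, 0]) == 122  # pixel values survive the conversion

pil2 = _F.to_pil_image(t)      # the kernel to_image_pil now aliases
assert pil2.mode == "RGB" and pil2.size == (32, 32)
```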