From 4767fef93d69ed3c0de4773f171dee2f744360ef Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 Oct 2022 10:11:43 +0000
Subject: [PATCH 01/11] [proto] Reduce number of calls of __torch_function__

---
 torchvision/prototype/features/_feature.py | 25 ++++++++++++++++++++++
 torchvision/prototype/features/_image.py   |  7 +++---
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 9c0cece15be..9501ec7851d 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -50,6 +50,9 @@ def new_like(
         requires_grad: Optional[bool] = None,
         **kwargs: Any,
     ) -> F:
+        # Quick fix: Feature -> Tensor => won't go to __torch_function__
+        other = other.as_subclass(torch.Tensor)
+
         return cls(
             data,
             dtype=dtype if dtype is not None else other.dtype,
@@ -137,6 +140,28 @@ def _F(self) -> ModuleType:
             _Feature.__F = functional
         return _Feature.__F
 
+    # Add properties for common attributes like shape, dtype, device, ndim etc
+    # this way we return the result without passing into __torch_function__
+    @property
+    def shape(self):
+        return self.as_subclass(torch.Tensor).shape
+
+    @property
+    def ndim(self):
+        return self.as_subclass(torch.Tensor).ndim
+
+    @property
+    def device(self):
+        return self.as_subclass(torch.Tensor).device
+
+    @property
+    def dtype(self):
+        return self.as_subclass(torch.Tensor).dtype
+
+    @property
+    def requires_grad(self):
+        return self.as_subclass(torch.Tensor).requires_grad
+
     def horizontal_flip(self) -> _Feature:
         return self
 
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 21126c7f254..0a1fcf7eefd 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -97,9 +97,10 @@ def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[overr
     def new_like(
         cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
     ) -> Image:
-        return super().new_like(
-            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
-        )
+        # Question: Is it safe to assume data to be a tensor ?
+        out = data.as_subclass(Image)
+        out.color_space = color_space if color_space is not None else other.color_space
+        return out
 
     @property
     def image_size(self) -> Tuple[int, int]:
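The motivation for the unwrapping in patch 01, as a self-contained sketch (the Counting class is illustrative, not part of the patch): on a torch.Tensor subclass, even plain attribute reads such as .shape are routed through __torch_function__, so converting back to torch.Tensor first skips that dispatch.

import torch

class Counting(torch.Tensor):
    calls = 0

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.calls += 1
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.rand(3, 4).as_subclass(Counting)
_ = t.shape                            # attribute read dispatches through __torch_function__
print(Counting.calls)                  # 1
_ = t.as_subclass(torch.Tensor).shape  # the unwrap itself dispatches once...
print(Counting.calls)                  # 2: ...but the metadata read no longer does
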
From 5ba3c0d58ff1d9acf498b60f47101651b602a655 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 5 Oct 2022 11:05:24 +0000
Subject: [PATCH 02/11] Use DisableTorchFunction and super

---
 torchvision/prototype/features/_feature.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 9501ec7851d..926a8bd50a3 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -144,23 +144,28 @@ def _F(self) -> ModuleType:
     # this way we return the result without passing into __torch_function__
     @property
     def shape(self):
-        return self.as_subclass(torch.Tensor).shape
+        with DisableTorchFunction():
+            return super().shape
 
     @property
     def ndim(self):
-        return self.as_subclass(torch.Tensor).ndim
+        with DisableTorchFunction():
+            return super().ndim
 
     @property
     def device(self):
-        return self.as_subclass(torch.Tensor).device
+        with DisableTorchFunction():
+            return super().device
 
     @property
     def dtype(self):
-        return self.as_subclass(torch.Tensor).dtype
+        with DisableTorchFunction():
+            return super().dtype
 
     @property
     def requires_grad(self):
-        return self.as_subclass(torch.Tensor).requires_grad
+        with DisableTorchFunction():
+            return super().requires_grad
 
     def horizontal_flip(self) -> _Feature:
         return self
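What the DisableTorchFunction version buys: inside the context manager, subclass dispatch is switched off, so super().shape resolves like a plain tensor attribute read. Note that torch._C.DisableTorchFunction is a private API, which is part of the trade-off. A minimal sketch of the same pattern outside _Feature:

import torch
from torch._C import DisableTorchFunction

class Quiet(torch.Tensor):
    @property
    def shape(self):
        with DisableTorchFunction():  # no __torch_function__ round-trip below
            return super().shape

x = torch.rand(2, 3).as_subclass(Quiet)
print(x.shape)  # torch.Size([2, 3])
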
From 6435c095aaf48c8e157082f6d39be054eb0b32e1 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 5 Oct 2022 11:10:28 +0000
Subject: [PATCH 03/11] Use self._tensor

---
 torchvision/prototype/features/_feature.py | 30 +++++++++-------------
 torchvision/prototype/features/_image.py   |  4 +++
 2 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 926a8bd50a3..746061f20a2 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -29,15 +29,14 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> F:
-        return (
-            torch.as_tensor(  # type: ignore[return-value]
-                data,
-                dtype=dtype,  # type: ignore[arg-type]
-                device=device,  # type: ignore[arg-type]
-            )
-            .as_subclass(cls)  # type: ignore[arg-type]
-            .requires_grad_(requires_grad)
+        tensor = torch.as_tensor(  # type: ignore[return-value]
+            data,
+            dtype=dtype,  # type: ignore[arg-type]
+            device=device,  # type: ignore[arg-type]
         )
+        output = tensor.as_subclass(cls).requires_grad_(requires_grad)  # type: ignore[arg-type]
+        output._tensor = tensor
+        return output
 
     @classmethod
     def new_like(
@@ -144,28 +143,23 @@ def _F(self) -> ModuleType:
     # this way we return the result without passing into __torch_function__
     @property
     def shape(self):
-        with DisableTorchFunction():
-            return super().shape
+        return self._tensor.shape
 
     @property
     def ndim(self):
-        with DisableTorchFunction():
-            return super().ndim
+        return self._tensor.ndim
 
     @property
     def device(self):
-        with DisableTorchFunction():
-            return super().device
+        return self._tensor.device
 
     @property
     def dtype(self):
-        with DisableTorchFunction():
-            return super().dtype
+        return self._tensor.dtype
 
     @property
     def requires_grad(self):
-        with DisableTorchFunction():
-            return super().requires_grad
+        return self._tensor.requires_grad
 
     def horizontal_flip(self) -> _Feature:
         return self
 
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 0a1fcf7eefd..eaa7a167271 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -100,7 +100,11 @@ def new_like(
         # Question: Is it safe to assume data to be a tensor ?
         out = data.as_subclass(Image)
         out.color_space = color_space if color_space is not None else other.color_space
+        out._tensor = data
         return out
+        # return super().new_like(
+        #     other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
+        # )
 
     @property
     def image_size(self) -> Tuple[int, int]:
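The _tensor trick in context: as_subclass shares the underlying data, so the attribute caches a plain-torch.Tensor view of the same storage, and metadata reads become ordinary Python attribute access with no dispatch at all. A condensed, runnable sketch of the idea (names mirror the patch):

import torch

class Feature(torch.Tensor):
    def __new__(cls, data, *, dtype=None, device=None, requires_grad=False):
        tensor = torch.as_tensor(data, dtype=dtype, device=device)
        output = tensor.as_subclass(cls).requires_grad_(requires_grad)
        output._tensor = tensor  # plain-Tensor view of the same data, not a copy
        return output

    @property
    def shape(self):
        return self._tensor.shape  # plain attribute read, no __torch_function__

f = Feature([[1, 2], [3, 4]])
print(f.shape)  # torch.Size([2, 2])
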
From b7694e6287ffe9406114b1a4d4c2b5c6ddacb12d Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 5 Oct 2022 16:58:20 +0000
Subject: [PATCH 04/11] Fixes mypy and color space handling

---
 torchvision/prototype/features/_feature.py | 29 ++++++++++---------
 torchvision/prototype/features/_image.py   | 33 ++++++++++++----------
 2 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 746061f20a2..d0c73e46cf8 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -6,6 +6,7 @@
 import PIL.Image
 import torch
 from torch._C import DisableTorchFunction
+from torch.types import _device, _dtype, _size
 from torchvision.transforms import InterpolationMode
 
@@ -29,14 +30,14 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> F:
-        tensor = torch.as_tensor(  # type: ignore[return-value]
+        tensor = torch.as_tensor(
             data,
             dtype=dtype,  # type: ignore[arg-type]
             device=device,  # type: ignore[arg-type]
         )
         output = tensor.as_subclass(cls).requires_grad_(requires_grad)  # type: ignore[arg-type]
-        output._tensor = tensor
+        output._tensor = tensor  # type: ignore[attr-defined]
-        return output
+        return output  # type: ignore[return-value]
 
     @classmethod
     def new_like(
@@ -50,7 +51,7 @@ def new_like(
         **kwargs: Any,
     ) -> F:
         # Quick fix: Feature -> Tensor => won't go to __torch_function__
-        other = other.as_subclass(torch.Tensor)
+        other = other.as_subclass(torch.Tensor)  # type: ignore[arg-type]
 
         return cls(
             data,
             dtype=dtype if dtype is not None else other.dtype,
@@ -142,24 +143,24 @@ def _F(self) -> ModuleType:
     # Add properties for common attributes like shape, dtype, device, ndim etc
     # this way we return the result without passing into __torch_function__
     @property
-    def shape(self):
-        return self._tensor.shape
+    def shape(self) -> _size:  # type: ignore[override]
+        return self._tensor.shape  # type: ignore[attr-defined, no-any-return]
 
     @property
-    def ndim(self):
-        return self._tensor.ndim
+    def ndim(self) -> int:  # type: ignore[override]
+        return self._tensor.ndim  # type: ignore[attr-defined, no-any-return]
 
     @property
-    def device(self):
-        return self._tensor.device
+    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
+        return self._tensor.device  # type: ignore[attr-defined, no-any-return]
 
     @property
-    def dtype(self):
-        return self._tensor.dtype
+    def dtype(self) -> _dtype:  # type: ignore[override]
+        return self._tensor.dtype  # type: ignore[attr-defined, no-any-return]
 
     @property
-    def requires_grad(self):
-        return self._tensor.requires_grad
+    def requires_grad(self) -> bool:  # type: ignore[override]
+        return self._tensor.requires_grad  # type: ignore[attr-defined, no-any-return]
 
     def horizontal_flip(self) -> _Feature:
         return self
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index eaa7a167271..9357d01ab6e 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -59,6 +59,18 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
         return ColorSpace.OTHER
 
 
+def _setup_color_space(color_space: Union[None, ColorSpace, str], shape: List[int]) -> ColorSpace:
+    if color_space is None:
+        color_space = ColorSpace.from_tensor_shape(shape)
+        if color_space == ColorSpace.OTHER:
+            warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
+        return color_space
+    elif isinstance(color_space, str):
+        return ColorSpace.from_str(color_space.upper())
+
+    raise ValueError
+
+
 class Image(_Feature):
     color_space: ColorSpace
 
@@ -78,15 +90,7 @@ def __new__(
             data = data.unsqueeze(0)
         image = super().__new__(cls, data, requires_grad=requires_grad)
 
-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(image.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-        image.color_space = color_space
+        image.color_space = _setup_color_space(color_space, list(image.shape))
 
         return image
 
@@ -98,13 +102,12 @@ def new_like(
         cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
     ) -> Image:
         # Question: Is it safe to assume data to be a tensor ?
-        out = data.as_subclass(Image)
-        out.color_space = color_space if color_space is not None else other.color_space
-        out._tensor = data
+        out: Image = data.as_subclass(Image)
+        out.color_space = _setup_color_space(
+            color_space if color_space is not None else other.color_space, list(data.shape)
+        )
+        out._tensor = data  # type: ignore[attr-defined]
         return out
-        # return super().new_like(
-        #     other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
-        # )
 
     @property
     def image_size(self) -> Tuple[int, int]:
From 628826c9901075755834904b44469fbad13848f7 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 6 Oct 2022 12:06:51 +0000
Subject: [PATCH 05/11] revert Image.new_like

---
 torchvision/prototype/features/_image.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 9357d01ab6e..ffdff2694a6 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -60,7 +60,9 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
 
 
 def _setup_color_space(color_space: Union[None, ColorSpace, str], shape: List[int]) -> ColorSpace:
-    if color_space is None:
+    if isinstance(color_space, ColorSpace):
+        return color_space
+    elif color_space is None:
         color_space = ColorSpace.from_tensor_shape(shape)
         if color_space == ColorSpace.OTHER:
             warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
@@ -68,7 +70,7 @@ def _setup_color_space(color_space: Union[None, ColorSpace, str], shape: List[in
     elif isinstance(color_space, str):
         return ColorSpace.from_str(color_space.upper())
 
-    raise ValueError
+    raise ValueError(f"Unsupported color space '{color_space}'")
 
 
 class Image(_Feature):
@@ -101,13 +103,9 @@ def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[overr
     def new_like(
         cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
     ) -> Image:
-        # Question: Is it safe to assume data to be a tensor ?
-        out: Image = data.as_subclass(Image)
-        out.color_space = _setup_color_space(
-            color_space if color_space is not None else other.color_space, list(data.shape)
+        return super().new_like(
+            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
         )
-        out._tensor = data  # type: ignore[attr-defined]
-        return out
 
     @property
     def image_size(self) -> Tuple[int, int]:
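With the two fixes above, _setup_color_space normalizes all the accepted inputs. Roughly, assuming the prototype ColorSpace enum with an RGB member (a non-executable sketch, for illustration only):

_setup_color_space(ColorSpace.RGB, [3, 32, 32])  # passed through unchanged
_setup_color_space("rgb", [3, 32, 32])           # ColorSpace.from_str("RGB")
_setup_color_space(None, [3, 32, 32])            # guessed via ColorSpace.from_tensor_shape
_setup_color_space(42, [3, 32, 32])              # ValueError: Unsupported color space '42'
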
From 345790be347cf19fe4b15ffc6c4f7c2b23f3b621 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 7 Oct 2022 10:00:39 +0000
Subject: [PATCH 06/11] WIP

---
 torchvision/prototype/features/_feature.py |  1 -
 torchvision/prototype/features/_image.py   | 22 +++++++++++++++++++---
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index d0c73e46cf8..992831e855b 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -52,7 +52,6 @@ def new_like(
     ) -> F:
         # Quick fix: Feature -> Tensor => won't go to __torch_function__
         other = other.as_subclass(torch.Tensor)  # type: ignore[arg-type]
-
         return cls(
             data,
             dtype=dtype if dtype is not None else other.dtype,
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index ffdff2694a6..88b173fd302 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -103,9 +103,25 @@ def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[overr
     def new_like(
         cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
     ) -> Image:
-        return super().new_like(
-            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
-        )
+        # Question: Is it safe to assume data to be a tensor ?
+        out = data.as_subclass(Image)
+        out.color_space = color_space if color_space is not None else other.color_space
+        out._tensor = data  # type: ignore[attr-defined]
+        return out
+
+    @classmethod
+    def _wrap(tensor, color_space=None):
+        image = tensor.as_subclass(Image)  # type: ignore[arg-type]
+        image.color_space = _setup_color_space(color_space, list(image.shape))
+        return image
+
+    # @classmethod
+    # def new_like(
+    #     cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
+    # ) -> Image:
+    #     return super().new_like(
+    #         other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
+    #     )
 
     @property
     def image_size(self) -> Tuple[int, int]:
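Two loose ends in this WIP state: the new _wrap is decorated with @classmethod but omits the cls parameter that the decorator injects as the first argument, and the recurring "is data a tensor?" question can be answered with an explicit guard. A corrected, hedged sketch (the guard and the stand-in class are illustrative, not part of the patch; the real code also calls _setup_color_space):

import torch

class Image(torch.Tensor):  # stand-in for the prototype Image feature
    @classmethod
    def _wrap(cls, tensor, color_space=None):  # cls is supplied by @classmethod
        if not isinstance(tensor, torch.Tensor):  # defensive answer to the question above
            tensor = torch.as_tensor(tensor)
        image = tensor.as_subclass(cls)
        image.color_space = color_space  # the real code: _setup_color_space(...)
        return image
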
From db0eef5fc794be9bb8180f7edd0596d92728e0e0 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 10 Oct 2022 10:36:52 +0000
Subject: [PATCH 07/11] Perf opt with ref to tensor and properties

---
 torchvision/prototype/features/_bounding_box.py |  1 +
 torchvision/prototype/features/_encoded.py      |  4 +++-
 torchvision/prototype/features/_feature.py      |  8 ++++++--
 torchvision/prototype/features/_image.py        | 15 +--------------
 torchvision/prototype/features/_label.py        |  1 +
 torchvision/prototype/features/_mask.py         |  4 +++-
 torchvision/prototype/features/_video.py        |  7 ++++---
 7 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 7b69af5f9bb..4204f12aad2 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -24,6 +24,7 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: T
         bounding_box = tensor.as_subclass(cls)
         bounding_box.format = format
         bounding_box.image_size = image_size
+        bounding_box._tensor = tensor
         return bounding_box
 
     def __new__(
diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py
index 4b963986b4f..8344ea850ae 100644
--- a/torchvision/prototype/features/_encoded.py
+++ b/torchvision/prototype/features/_encoded.py
@@ -16,7 +16,9 @@ class EncodedData(_Feature):
     @classmethod
     def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
-        return tensor.as_subclass(cls)
+        output = tensor.as_subclass(cls)
+        output._tensor = tensor
+        return output
 
     def __new__(
         cls,
diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index e4c7b88ed5b..0f342113b59 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -43,14 +43,18 @@ def __new__(
         requires_grad: bool = False,
     ) -> _Feature:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return tensor.as_subclass(_Feature)
+        output = tensor.as_subclass(_Feature)
+        output._tensor = tensor
+        return output
 
     @classmethod
     def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F:
         # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
         # this method should be made abstract
         # raise NotImplementedError
-        return tensor.as_subclass(cls)
+        output = tensor.as_subclass(cls)
+        output._tensor = tensor
+        return output
 
     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index c0c9326ec27..1936aae3f6c 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -59,20 +59,6 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
         return ColorSpace.OTHER
 
 
-def _setup_color_space(color_space: Union[None, ColorSpace, str], shape: List[int]) -> ColorSpace:
-    if isinstance(color_space, ColorSpace):
-        return color_space
-    elif color_space is None:
-        color_space = ColorSpace.from_tensor_shape(shape)
-        if color_space == ColorSpace.OTHER:
-            warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        return color_space
-    elif isinstance(color_space, str):
-        return ColorSpace.from_str(color_space.upper())
-
-    raise ValueError(f"Unsupported color space '{color_space}'")
-
-
 class Image(_Feature):
     color_space: ColorSpace
 
@@ -80,6 +66,7 @@ class Image(_Feature):
     def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
         image = tensor.as_subclass(cls)
         image.color_space = color_space
+        image._tensor = tensor
         return image
 
     def __new__(
diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py
index 9c2bcfc0fb1..7743ef4970a 100644
--- a/torchvision/prototype/features/_label.py
+++ b/torchvision/prototype/features/_label.py
@@ -18,6 +18,7 @@ class _LabelBase(_Feature):
     def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
         label_base = tensor.as_subclass(cls)
         label_base.categories = categories
+        label_base._tensor = tensor
         return label_base
 
     def __new__(
diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py
index 7b49ce8e85e..fb316f5b054 100644
--- a/torchvision/prototype/features/_mask.py
+++ b/torchvision/prototype/features/_mask.py
@@ -11,7 +11,9 @@ class Mask(_Feature):
     @classmethod
     def _wrap(cls, tensor: torch.Tensor) -> Mask:
-        return tensor.as_subclass(cls)
+        output = tensor.as_subclass(cls)
+        output._tensor = tensor
+        return output
 
     def __new__(
         cls,
diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py
index a58027243cf..dd663de8384 100644
--- a/torchvision/prototype/features/_video.py
+++ b/torchvision/prototype/features/_video.py
@@ -15,9 +15,10 @@ class Video(_Feature):
     @classmethod
     def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
-        image = tensor.as_subclass(cls)
-        image.color_space = color_space
-        return image
+        video = tensor.as_subclass(cls)
+        video.color_space = color_space
+        video._tensor = tensor
+        return video
 
     def __new__(
         cls,
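The price of the _tensor reference is visible in how many call sites this patch has to touch: every wrapper in every feature class must remember to set it, or metadata access fails at a distance. A small, self-contained illustration of that failure mode (the class is a stand-in, not the real _Feature):

import torch

class Feature(torch.Tensor):
    @property
    def shape(self):
        return self._tensor.shape  # relies on every wrapping path having set _tensor

feat = torch.rand(2, 2).as_subclass(Feature)  # a path that skipped the bookkeeping
try:
    feat.shape
except AttributeError as e:
    print(e)  # 'Feature' object has no attribute '_tensor'
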
From c8b3ac85c2dc55092b86c10f4732ba11407a90fd Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 10 Oct 2022 12:35:28 +0000
Subject: [PATCH 08/11] Removed requires_grad property

---
 torchvision/prototype/features/_feature.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 0f342113b59..695a802154b 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -151,10 +151,6 @@ def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override
     def dtype(self) -> _dtype:  # type: ignore[override]
         return self._tensor.dtype  # type: ignore[attr-defined, no-any-return]
 
-    @property
-    def requires_grad(self) -> bool:  # type: ignore[override]
-        return self._tensor.requires_grad  # type: ignore[attr-defined, no-any-return]
-
     def horizontal_flip(self) -> _Feature:
         return self
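The patch gives no rationale, but requires_grad is special among these attributes: it is writable on tensors, and a getter-only property shadows that. A hedged illustration of the conflict (plausibly related to this removal; the history does not say explicitly):

import torch

class Feature(torch.Tensor):
    @property
    def requires_grad(self):  # read-only: shadows the writable tensor attribute
        return False

f = torch.rand(2).as_subclass(Feature)
try:
    f.requires_grad = True  # fine on a plain tensor, rejected by the property
except AttributeError as e:
    print(e)
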
From 38f8e21242830fed46ddf31287edb67c1abd124a Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 14 Oct 2022 15:40:24 +0000
Subject: [PATCH 09/11] Use _tensor ref

---
 torchvision/prototype/features/_bounding_box.py |  1 +
 torchvision/prototype/features/_encoded.py      |  4 +++-
 torchvision/prototype/features/_feature.py      | 17 ++++++++---------
 torchvision/prototype/features/_image.py        |  1 +
 torchvision/prototype/features/_label.py        |  1 +
 torchvision/prototype/features/_mask.py         |  4 +++-
 torchvision/prototype/features/_video.py        |  1 +
 7 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 18c607d4d16..db9af91c89f 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -24,6 +24,7 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size:
         bounding_box = tensor.as_subclass(cls)
         bounding_box.format = format
         bounding_box.spatial_size = spatial_size
+        bounding_box._tensor = tensor
         return bounding_box
 
     def __new__(
diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py
index ffa347a3ef6..390263d3aea 100644
--- a/torchvision/prototype/features/_encoded.py
+++ b/torchvision/prototype/features/_encoded.py
@@ -16,7 +16,9 @@ class EncodedData(_Feature):
     @classmethod
     def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
-        return tensor.as_subclass(cls)
+        output = tensor.as_subclass(cls)
+        output._tensor = tensor
+        return output
 
     def __new__(
         cls,
diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index ec6db649025..834fea426d4 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -44,6 +44,7 @@ def __new__(
     ) -> _Feature:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         output = tensor.as_subclass(_Feature)
+        output._tensor = tensor
         return output
 
     @classmethod
@@ -108,7 +109,9 @@ def __torch_function__(
             # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
             # will retain the input type. Thus, we need to unwrap here.
             if isinstance(output, cls):
-                return output.as_subclass(torch.Tensor)
+                tensor = output.as_subclass(torch.Tensor)
+                output._tensor = tensor
+                return tensor
 
             return output
 
@@ -134,23 +137,19 @@ def _F(self) -> ModuleType:
     # this way we return the result without passing into __torch_function__
     @property
     def shape(self) -> _size:  # type: ignore[override]
-        with DisableTorchFunction():
-            return super().shape
+        return self._tensor.shape
 
     @property
     def ndim(self) -> int:  # type: ignore[override]
-        with DisableTorchFunction():
-            return super().ndim
+        return self._tensor.ndim
 
     @property
     def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
-        with DisableTorchFunction():
-            return super().device
+        return self._tensor.device
 
     @property
     def dtype(self) -> _dtype:  # type: ignore[override]
-        with DisableTorchFunction():
-            return super().dtype
+        return self._tensor.dtype
 
     def horizontal_flip(self) -> _Feature:
         return self
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index e9128b94be0..224f1978611 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -64,6 +64,7 @@ class Image(_Feature):
     def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
         image = tensor.as_subclass(cls)
         image.color_space = color_space
+        image._tensor = tensor
         return image
 
     def __new__(
diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py
index 9c2bcfc0fb1..7743ef4970a 100644
--- a/torchvision/prototype/features/_label.py
+++ b/torchvision/prototype/features/_label.py
@@ -18,6 +18,7 @@ class _LabelBase(_Feature):
     def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
         label_base = tensor.as_subclass(cls)
         label_base.categories = categories
+        label_base._tensor = tensor
         return label_base
 
     def __new__(
diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py
index 2da10195e80..35f9f67a5bd 100644
--- a/torchvision/prototype/features/_mask.py
+++ b/torchvision/prototype/features/_mask.py
@@ -11,7 +11,9 @@ class Mask(_Feature):
     @classmethod
     def _wrap(cls, tensor: torch.Tensor) -> Mask:
-        return tensor.as_subclass(cls)
+        output = tensor.as_subclass(cls)
+        output._tensor = tensor
+        return output
 
     def __new__(
         cls,
diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py
index 26f97549ac5..54c20b2d8b4 100644
--- a/torchvision/prototype/features/_video.py
+++ b/torchvision/prototype/features/_video.py
@@ -17,6 +17,7 @@ class Video(_Feature):
     def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
         video = tensor.as_subclass(cls)
         video.color_space = color_space
+        video._tensor = tensor
         return video
 
     def __new__(
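For reference, the unwrapping branch touched above is what keeps in-place ops from leaking the feature type; a minimal model of the same __torch_function__ policy (an illustrative class, not the real _Feature):

import torch

class Feature(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        output = super().__torch_function__(func, types, args, kwargs or {})
        if isinstance(output, cls):                  # in-place ops retain the input type,
            return output.as_subclass(torch.Tensor)  # so strip it before returning
        return output

f = torch.rand(2).as_subclass(Feature)
print(type(f.add_(1)))  # <class 'torch.Tensor'>
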
From 8441dbc247418aedc27bb5f015d0bacd89b479da Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Fri, 14 Oct 2022 15:59:46 +0000
Subject: [PATCH 10/11] Revert "Use _tensor ref"

This reverts commit 38f8e21242830fed46ddf31287edb67c1abd124a.
---
 torchvision/prototype/features/_bounding_box.py |  1 -
 torchvision/prototype/features/_encoded.py      |  4 +---
 torchvision/prototype/features/_feature.py      | 17 +++++++++--------
 torchvision/prototype/features/_image.py        |  1 -
 torchvision/prototype/features/_label.py        |  1 -
 torchvision/prototype/features/_mask.py         |  4 +---
 torchvision/prototype/features/_video.py        |  1 -
 7 files changed, 11 insertions(+), 18 deletions(-)

diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index db9af91c89f..18c607d4d16 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -24,7 +24,6 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size:
         bounding_box = tensor.as_subclass(cls)
         bounding_box.format = format
         bounding_box.spatial_size = spatial_size
-        bounding_box._tensor = tensor
         return bounding_box
 
     def __new__(
diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py
index 390263d3aea..ffa347a3ef6 100644
--- a/torchvision/prototype/features/_encoded.py
+++ b/torchvision/prototype/features/_encoded.py
@@ -16,9 +16,7 @@ class EncodedData(_Feature):
     @classmethod
     def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
-        output = tensor.as_subclass(cls)
-        output._tensor = tensor
-        return output
+        return tensor.as_subclass(cls)
 
     def __new__(
         cls,
diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 834fea426d4..ec6db649025 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -44,7 +44,6 @@ def __new__(
     ) -> _Feature:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         output = tensor.as_subclass(_Feature)
-        output._tensor = tensor
         return output
 
     @classmethod
@@ -109,9 +108,7 @@ def __torch_function__(
             # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
             # will retain the input type. Thus, we need to unwrap here.
             if isinstance(output, cls):
-                tensor = output.as_subclass(torch.Tensor)
-                output._tensor = tensor
-                return tensor
+                return output.as_subclass(torch.Tensor)
 
             return output
 
@@ -137,19 +134,23 @@ def _F(self) -> ModuleType:
     # this way we return the result without passing into __torch_function__
     @property
     def shape(self) -> _size:  # type: ignore[override]
-        return self._tensor.shape
+        with DisableTorchFunction():
+            return super().shape
 
     @property
     def ndim(self) -> int:  # type: ignore[override]
-        return self._tensor.ndim
+        with DisableTorchFunction():
+            return super().ndim
 
     @property
     def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
-        return self._tensor.device
+        with DisableTorchFunction():
+            return super().device
 
     @property
     def dtype(self) -> _dtype:  # type: ignore[override]
-        return self._tensor.dtype
+        with DisableTorchFunction():
+            return super().dtype
 
     def horizontal_flip(self) -> _Feature:
         return self
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 224f1978611..e9128b94be0 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -64,7 +64,6 @@ class Image(_Feature):
     def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
         image = tensor.as_subclass(cls)
         image.color_space = color_space
-        image._tensor = tensor
         return image
 
     def __new__(
diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py
index 7743ef4970a..9c2bcfc0fb1 100644
--- a/torchvision/prototype/features/_label.py
+++ b/torchvision/prototype/features/_label.py
@@ -18,7 +18,6 @@ def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
         label_base = tensor.as_subclass(cls)
         label_base.categories = categories
-        label_base._tensor = tensor
         return label_base
 
     def __new__(
diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py
index 35f9f67a5bd..2da10195e80 100644
--- a/torchvision/prototype/features/_mask.py
+++ b/torchvision/prototype/features/_mask.py
@@ -11,9 +11,7 @@ class Mask(_Feature):
     @classmethod
     def _wrap(cls, tensor: torch.Tensor) -> Mask:
-        output = tensor.as_subclass(cls)
-        output._tensor = tensor
-        return output
+        return tensor.as_subclass(cls)
 
     def __new__(
         cls,
diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py
index 54c20b2d8b4..26f97549ac5 100644
--- a/torchvision/prototype/features/_video.py
+++ b/torchvision/prototype/features/_video.py
@@ -17,7 +17,6 @@ class Video(_Feature):
     def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
         video = tensor.as_subclass(cls)
         video.color_space = color_space
-        video._tensor = tensor
         return video
 
     def __new__(
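After the revert, the DisableTorchFunction form is back. Plausibly the deciding factor (the history does not say explicitly) is that it needs no per-wrapper bookkeeping: any way of producing the subclass instance just works.

import torch
from torch._C import DisableTorchFunction

class Feature(torch.Tensor):
    @property
    def shape(self):
        with DisableTorchFunction():
            return super().shape

print(torch.zeros(5).as_subclass(Feature).shape)  # no _tensor attribute required
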
From 7e8eb46c75a32c35450dc6705c78045526cb6402 Mon Sep 17 00:00:00 2001
From: vfdev
Date: Mon, 17 Oct 2022 09:31:02 +0200
Subject: [PATCH 11/11] Update torchvision/prototype/features/_feature.py

Co-authored-by: Philip Meier
---
 torchvision/prototype/features/_feature.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index ec6db649025..1cc2d8d4bb7 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -43,8 +43,7 @@ def __new__(
         requires_grad: bool = False,
     ) -> _Feature:
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        output = tensor.as_subclass(_Feature)
-        return output
+        return tensor.as_subclass(_Feature)
 
     @classmethod
     def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F:
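Where the series lands, condensed from the hunks above: typed metadata properties guarded by DisableTorchFunction, with requires_grad left to the base class. A self-contained sketch of the final _Feature property block:

import torch
from typing import Any
from torch._C import DisableTorchFunction
from torch.types import _device, _dtype, _size

class _Feature(torch.Tensor):
    # Metadata reads skip the __torch_function__ round-trip entirely:
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunction():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunction():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunction():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunction():
            return super().dtype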