diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index a56441f2967..1cc2d8d4bb7 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -6,6 +6,7 @@ import PIL.Image import torch from torch._C import DisableTorchFunction +from torch.types import _device, _dtype, _size from torchvision.transforms import InterpolationMode @@ -128,6 +129,28 @@ def _F(self) -> ModuleType: _Feature.__F = functional return _Feature.__F + # Add properties for common attributes like shape, dtype, device, ndim etc + # this way we return the result without passing into __torch_function__ + @property + def shape(self) -> _size: # type: ignore[override] + with DisableTorchFunction(): + return super().shape + + @property + def ndim(self) -> int: # type: ignore[override] + with DisableTorchFunction(): + return super().ndim + + @property + def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override] + with DisableTorchFunction(): + return super().device + + @property + def dtype(self) -> _dtype: # type: ignore[override] + with DisableTorchFunction(): + return super().dtype + def horizontal_flip(self) -> _Feature: return self diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py index 9dfff7f964e..26f97549ac5 100644 --- a/torchvision/prototype/features/_video.py +++ b/torchvision/prototype/features/_video.py @@ -15,9 +15,9 @@ class Video(_Feature): @classmethod def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video: - image = tensor.as_subclass(cls) - image.color_space = color_space - return image + video = tensor.as_subclass(cls) + video.color_space = color_space + return video def __new__( cls,