diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py
deleted file mode 100644
index 147243286d4..00000000000
--- a/test/test_prototype_features.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import functools
-import itertools
-
-import pytest
-import torch
-from torch.testing import make_tensor as _make_tensor, assert_close
-from torchvision.prototype import features
-from torchvision.prototype.utils._internal import sequence_to_str
-
-
-make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)
-
-
-def make_image(**kwargs):
-    data = make_tensor((3, *torch.randint(16, 33, (2,)).tolist()))
-    return features.Image(data, **kwargs)
-
-
-def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
-    if isinstance(format, str):
-        format = features.BoundingBoxFormat[format]
-
-    height, width = image_size
-
-    if format == features.BoundingBoxFormat.XYXY:
-        x1 = torch.randint(0, width // 2, ())
-        y1 = torch.randint(0, height // 2, ())
-        x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1
-        y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1
-        parts = (x1, y1, x2, y2)
-    elif format == features.BoundingBoxFormat.XYWH:
-        x = torch.randint(0, width // 2, ())
-        y = torch.randint(0, height // 2, ())
-        w = torch.randint(1, width - int(x), ())
-        h = torch.randint(1, height - int(y), ())
-        parts = (x, y, w, h)
-    elif format == features.BoundingBoxFormat.CXCYWH:
-        cx = torch.randint(1, width - 1, ())
-        cy = torch.randint(1, height - 1, ())
-        w = torch.randint(1, min(int(cx), width - int(cx)), ())
-        h = torch.randint(1, min(int(cy), height - int(cy)), ())
-        parts = (cx, cy, w, h)
-    else:  # format == features.BoundingBoxFormat._SENTINEL:
-        parts = make_tensor((4,)).unbind()
-
-    return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size)
-
-
-MAKE_DATA_MAP = {
-    features.Image: make_image,
-    features.BoundingBox: make_bounding_box,
-}
-
-
-def make_feature(feature_type, **meta_data):
-    maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data))
-    return maker(**meta_data)
-
-
-class TestCommon:
-    FEATURE_TYPES, NON_DEFAULT_META_DATA = zip(
-        *(
-            (features.Image, dict(color_space=features.ColorSpace._SENTINEL)),
-            (features.Label, dict(category="category")),
-            (features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))),
-        )
-    )
-    feature_types = pytest.mark.parametrize(
-        "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__
-    )
-    features = pytest.mark.parametrize(
-        "feature",
-        [
-            pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__)
-            for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA)
-        ],
-    )
-
-    def test_consistency(self):
-        builtin_feature_types = {
-            name
-            for name, feature_type in features.__dict__.items()
-            if not name.startswith("_")
-            and isinstance(feature_type, type)
-            and issubclass(feature_type, features.Feature)
-            and feature_type is not features.Feature
-        }
-        untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES}
-        if untested_feature_types:
-            raise AssertionError(
-                f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} "
-                f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. "
-                f"Please add it/them to `TestCommon.FEATURE_TYPES`."
-            )
-
-    @features
-    def test_meta_data_attribute_access(self, feature):
-        for name, value in feature._meta_data.items():
-            assert getattr(feature, name) == feature._meta_data[name]
-
-    @feature_types
-    def test_torch_function(self, feature_type):
-        input = make_feature(feature_type)
-        # This can be any Tensor operation besides clone
-        output = input + 1
-
-        assert type(output) is torch.Tensor
-        assert_close(output, input + 1)
-
-    @feature_types
-    def test_clone(self, feature_type):
-        input = make_feature(feature_type)
-        output = input.clone()
-
-        assert type(output) is feature_type
-        assert_close(output, input)
-        assert output._meta_data == input._meta_data
-
-    @features
-    def test_serialization(self, tmpdir, feature):
-        file = tmpdir / "test_serialization.pt"
-
-        torch.save(feature, str(file))
-        loaded_feature = torch.load(str(file))
-
-        assert isinstance(loaded_feature, type(feature))
-        assert_close(loaded_feature, feature)
-        assert loaded_feature._meta_data == feature._meta_data
-
-    @features
-    def test_repr(self, feature):
-        assert type(feature).__name__ in repr(feature)
-
-
-class TestBoundingBox:
-    @pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2))
-    def test_cycle_consistency(self, format, intermediate_format):
-        input = make_bounding_box(format=format)
-        output = input.convert(intermediate_format).convert(format)
-        assert_close(input, output)
-
-
-# For now, tensor subclasses with additional meta data do not work with torchscript.
-# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037.
-@pytest.mark.xfail
-class TestJit:
-    def test_bounding_box(self):
-        def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
-            old_height, old_width = input.image_size
-            new_height, new_width = size
-
-            height_scale = new_height / old_height
-            width_scale = new_width / old_width
-
-            old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()
-
-            new_x1 = old_x1 * width_scale
-            new_y1 = old_y1 * height_scale
-
-            new_x2 = old_x2 * width_scale
-            new_y2 = old_y2 * height_scale
-
-            return features.BoundingBox.from_parts(
-                new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist())
-            )
-
-        def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox:
-            x, y, w, h = input.convert("xywh").to_parts()
-            x = input.image_size[1] - (x + w)
-            return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")
-
-        def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
-            return horizontal_flip(resize(input, size)).convert("xyxy")
-
-        image_size = (8, 6)
-        input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size)
-        size = torch.tensor((4, 12))
-        expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size)
-
-        actual_eager = compose(input, size)
-        assert_close(actual_eager, expected)
-
-        sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5)))
-        actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size)
-        assert_close(actual_jit, expected)
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 0d7fe36a3fd..ee9b2a65b51 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -26,11 +26,11 @@
     image_buffer_from_array,
     Decompressor,
     INFINITE_BUFFER_SIZE,
-    fromfile,
     hint_sharding,
     hint_shuffling,
 )
 from torchvision.prototype.features import Image, Label
+from torchvision.prototype.utils._internal import fromfile
 
 
 __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index e21e8ffd25f..7106ea44a44 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -3,12 +3,10 @@
 import gzip
 import io
 import lzma
-import mmap
 import os
 import os.path
 import pathlib
 import pickle
-import platform
 from typing import BinaryIO
 from typing import (
     Sequence,
@@ -32,6 +30,7 @@
 import torch.utils.data
 from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
+from torchvision.prototype.utils._internal import fromfile
 
 
 __all__ = [
@@ -46,7 +45,6 @@
     "path_accessor",
     "path_comparator",
     "Decompressor",
-    "fromfile",
     "read_flo",
     "hint_sharding",
 ]
@@ -267,69 +265,6 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[st
     return dp
 
 
-def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
-    # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
-    return bytearray(file.read(-1 if count == -1 else count * item_size))
-
-
-def fromfile(
-    file: BinaryIO,
-    *,
-    dtype: torch.dtype,
-    byte_order: str,
-    count: int = -1,
-) -> torch.Tensor:
-    """Construct a tensor from a binary file.
-
-    .. note::
-
-        This function is similar to :func:`numpy.fromfile` with two notable differences:
-
-        1. This function only accepts an open binary file, but not a path to it.
-        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
-           concept.
-
-    .. note::
-
-        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
-        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
-
-    Args:
-        file (IO): Open binary file.
-        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
-        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
-        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
-    """
-    byte_order = "<" if byte_order == "little" else ">"
-    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
-    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
-    np_dtype = byte_order + char + str(item_size)
-
-    buffer: Union[memoryview, bytearray]
-    if platform.system() != "Windows":
-        # PyTorch does not support tensors with underlying read-only memory. In case
-        # - the file has a .fileno(),
-        # - the file was opened for updating, i.e. 'r+b' or 'w+b',
-        # - the file is seekable
-        # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
-        # to a mutable location afterwards.
-        try:
-            buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
-            # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
-            file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
-        except (PermissionError, io.UnsupportedOperation):
-            buffer = _read_mutable_buffer_fallback(file, count, item_size)
-    else:
-        # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
-        # so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
-        buffer = _read_mutable_buffer_fallback(file, count, item_size)
-
-    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
-    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
-    # successive .astype() call.
-    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
-
-
 def read_flo(file: BinaryIO) -> torch.Tensor:
     if file.read(4) != b"PIEH":
         raise ValueError("Magic number incorrect. Invalid .flo file")
diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py
index 4d77d3a5ce3..dd9982a04db 100644
--- a/torchvision/prototype/features/__init__.py
+++ b/torchvision/prototype/features/__init__.py
@@ -1,4 +1,6 @@
-from ._bounding_box import BoundingBoxFormat, BoundingBox
-from ._feature import Feature, DEFAULT
-from ._image import Image, ColorSpace
-from ._label import Label
+from ._bounding_box import BoundingBox, BoundingBoxFormat
+from ._encoded import EncodedData, EncodedImage, EncodedVideo
+from ._feature import Feature
+from ._image import ColorSpace, Image
+from ._label import Label, OneHotLabel
+from ._segmentation_mask import SegmentationMask
diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 64ba449ae76..2d0685c2088 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -1,83 +1,17 @@
-import enum
-import functools
-from typing import Callable, Union, Tuple, Dict, Any, Optional, cast
+from typing import Any, Tuple, Union, Optional
 
 import torch
 from torchvision.prototype.utils._internal import StrEnum
 
-from ._feature import Feature, DEFAULT
+from ._feature import Feature
 
 
 class BoundingBoxFormat(StrEnum):
     # this is just for test purposes
     _SENTINEL = -1
-    XYXY = enum.auto()
-    XYWH = enum.auto()
-    CXCYWH = enum.auto()
-
-
-def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    return input.unbind(dim=-1)  # type: ignore[return-value]
-
-
-def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
-    return torch.stack((a, b, c, d), dim=-1)
-
-
-def format_converter_wrapper(
-    part_converter: Callable[
-        [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
-        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
-    ]
-):
-    def wrapper(input: torch.Tensor) -> torch.Tensor:
-        return from_parts(*part_converter(*to_parts(input)))
-
-    return wrapper
-
-
-@format_converter_wrapper
-def xywh_to_xyxy(
-    x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    x1 = x
-    y1 = y
-    x2 = x + w
-    y2 = y + h
-    return x1, y1, x2, y2
-
-
-@format_converter_wrapper
-def xyxy_to_xywh(
-    x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    x = x1
-    y = y1
-    w = x2 - x1
-    h = y2 - y1
-    return x, y, w, h
-
-
-@format_converter_wrapper
-def cxcywh_to_xyxy(
-    cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    x1 = cx - 0.5 * w
-    y1 = cy - 0.5 * h
-    x2 = cx + 0.5 * w
-    y2 = cy + 0.5 * h
-    return x1, y1, x2, y2
-
-
-@format_converter_wrapper
-def xyxy_to_cxcywh(
-    x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    cx = (x1 + x2) / 2
-    cy = (y1 + y2) / 2
-    w = x2 - x1
-    h = y2 - y1
-    return cx, cy, w, h
+    XYXY = StrEnum.auto()
+    XYWH = StrEnum.auto()
+    CXCYWH = StrEnum.auto()
 
 
 class BoundingBox(Feature):
@@ -85,71 +19,20 @@ class BoundingBox(Feature):
     format: BoundingBoxFormat
     image_size: Tuple[int, int]
 
-    @classmethod
-    def _parse_meta_data(
-        cls,
-        format: Union[str, BoundingBoxFormat] = DEFAULT,  # type: ignore[assignment]
-        image_size: Optional[Tuple[int, int]] = DEFAULT,  # type: ignore[assignment]
-    ) -> Dict[str, Tuple[Any, Any]]:
-        if isinstance(format, str):
-            format = BoundingBoxFormat[format]
-        format_fallback = BoundingBoxFormat.XYXY
-        return dict(
-            format=(format, format_fallback),
-            image_size=(image_size, functools.partial(cls.guess_image_size, format=format_fallback)),
-        )
-
-    _TO_XYXY_MAP = {
-        BoundingBoxFormat.XYWH: xywh_to_xyxy,
-        BoundingBoxFormat.CXCYWH: cxcywh_to_xyxy,
-    }
-    _FROM_XYXY_MAP = {
-        BoundingBoxFormat.XYWH: xyxy_to_xywh,
-        BoundingBoxFormat.CXCYWH: xyxy_to_cxcywh,
-    }
-
-    @classmethod
-    def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]:
-        if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH):
-            if format != BoundingBoxFormat.XYXY:
-                data = cls._TO_XYXY_MAP[format](data)
-            data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data)
-        *_, w, h = to_parts(data)
-        if data.dtype.is_floating_point:
-            w = w.ceil()
-            h = h.ceil()
-        return int(h.max()), int(w.max())
-
-    @classmethod
-    def from_parts(
+    def __new__(
         cls,
-        a,
-        b,
-        c,
-        d,
+        data: Any,
         *,
-        like: Optional["BoundingBox"] = None,
-        format: Union[str, BoundingBoxFormat] = DEFAULT,  # type: ignore[assignment]
-        image_size: Optional[Tuple[int, int]] = DEFAULT,  # type: ignore[assignment]
-    ) -> "BoundingBox":
-        return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format)
-
-    def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        return to_parts(self)
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        format: Union[BoundingBoxFormat, str],
+        image_size: Tuple[int, int],
+    ):
+        bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
 
-    def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
         if isinstance(format, str):
             format = BoundingBoxFormat[format]
-        if format == self.format:
-            return cast(BoundingBox, self.clone())
-
-        data = self
-
-        if self.format != BoundingBoxFormat.XYXY:
-            data = self._TO_XYXY_MAP[self.format](data)
-
-        if format != BoundingBoxFormat.XYXY:
-            data = self._FROM_XYXY_MAP[format](data)
+        bounding_box._metadata.update(dict(format=format, image_size=image_size))
 
-        return BoundingBox(data, like=self, format=format)
+        return bounding_box
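With this rewrite, `format` and `image_size` become required keyword-only arguments of the `BoundingBox` constructor, the values are stored in `_metadata` and surfaced as read-only properties, and the `from_parts`/`to_parts`/`convert` helpers are gone. A minimal usage sketch (the coordinate values are illustrative):

```python
import torch
from torchvision.prototype import features

# `format` accepts the enum or its string name; StrEnumMeta.__getitem__
# upper-cases string keys, so "xyxy" resolves to BoundingBoxFormat.XYXY
bbox = features.BoundingBox([2, 4, 8, 6], format="xyxy", image_size=(10, 10))

assert bbox.format is features.BoundingBoxFormat.XYXY  # metadata property
assert bbox.image_size == (10, 10)
```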
diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py
new file mode 100644
index 00000000000..9160b5e36e1
--- /dev/null
+++ b/torchvision/prototype/features/_encoded.py
@@ -0,0 +1,42 @@
+import os
+import sys
+from typing import BinaryIO, Tuple, Type, TypeVar, Union
+
+import PIL.Image
+import torch
+from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
+
+from ._feature import Feature
+
+D = TypeVar("D", bound="EncodedData")
+
+
+class EncodedData(Feature):
+    @classmethod
+    def _to_tensor(cls, data, *, dtype, device):
+        # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
+        return super()._to_tensor(data, dtype=dtype, device=device)
+
+    @classmethod
+    def from_file(cls: Type[D], file: BinaryIO) -> D:
+        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
+
+    @classmethod
+    def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D:
+        with open(path, "rb") as file:
+            return cls.from_file(file)
+
+
+class EncodedImage(EncodedData):
+    # TODO: Use @functools.cached_property if we can depend on Python 3.8
+    @property
+    def image_size(self) -> Tuple[int, int]:
+        if not hasattr(self, "_image_size"):
+            with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
+                self._image_size = image.height, image.width
+
+        return self._image_size
+
+
+class EncodedVideo(EncodedData):
+    pass
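`EncodedData` stores the raw, undecoded bytes of a file as a flat `uint8` tensor, and `EncodedImage.image_size` is computed lazily by letting PIL read just the header through a `ReadOnlyTensorBuffer`. A sketch of the intended use (the file name is hypothetical):

```python
from torchvision.prototype import features

# reads the raw bytes into a uint8 tensor; nothing is decoded yet
encoded = features.EncodedImage.from_path("example.jpg")  # hypothetical path

# decoding the header happens on first access and is cached in _image_size
height, width = encoded.image_size
```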
diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index cd52f1f80ad..085b39204dd 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -1,22 +1,18 @@
-from typing import Tuple, cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence
+from typing import Any, Callable, cast, Dict, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar
 
 import torch
 from torch._C import _TensorBase, DisableTorchFunction
-from torchvision.prototype.utils._internal import add_suggestion
 
 
 F = TypeVar("F", bound="Feature")
 
 
-DEFAULT = object()
-
-
 class Feature(torch.Tensor):
     _META_ATTRS: Set[str] = set()
-    _meta_data: Dict[str, Any]
+    _metadata: Dict[str, Any]
 
     def __init_subclass__(cls):
-        # In order to help static type checkers, we require subclasses of `Feature` add the meta data attributes
+        # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes
         # as static class annotations:
         #
         # >>> class Foo(Feature):
@@ -38,63 +34,28 @@ def __init_subclass__(cls):
             meta_attrs.update(super_cls._META_ATTRS)
         cls._META_ATTRS = meta_attrs
 
-        for attr in meta_attrs:
-            setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr]))
-
-    def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs):
-        unknown_meta_attrs = kwargs.keys() - cls._META_ATTRS
-        if unknown_meta_attrs:
-            unknown_meta_attr = sorted(unknown_meta_attrs)[0]
-            raise TypeError(
-                add_suggestion(
-                    f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.",
-                    word=unknown_meta_attr,
-                    possibilities=sorted(cls._META_ATTRS),
-                )
-            )
-
-        if like is not None:
-            dtype = dtype or like.dtype
-            device = device or like.device
-        data = cls._to_tensor(data, dtype=dtype, device=device)
-        requires_grad = False
-        self = torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)
-
-        meta_data = dict()
-        for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items():
-            if explicit is not DEFAULT:
-                value = explicit
-            elif like is not None:
-                value = getattr(like, attr)
-            else:
-                value = fallback(data) if callable(fallback) else fallback
-            meta_data[attr] = value
-        self._meta_data = meta_data
-
-        return self
+        for name in meta_attrs:
+            setattr(cls, name, property(lambda self, name=name: self._metadata[name]))
+
+    def __new__(cls, data, *, dtype=None, device=None):
+        feature = torch.Tensor._make_subclass(
+            cast(_TensorBase, cls),
+            cls._to_tensor(data, dtype=dtype, device=device),
+            # requires_grad
+            False,
+        )
+        feature._metadata = dict()
+        return feature
 
     @classmethod
-    def _to_tensor(cls, data, *, dtype, device):
+    def _to_tensor(cls, data: Any, *, dtype, device) -> torch.Tensor:
         return torch.as_tensor(data, dtype=dtype, device=device)
 
     @classmethod
-    def _parse_meta_data(cls) -> Dict[str, Tuple[Any, Any]]:
-        return dict()
-
-    @classmethod
-    def __torch_function__(
-        cls,
-        func: Callable[..., torch.Tensor],
-        types: Tuple[Type[torch.Tensor], ...],
-        args: Sequence[Any] = (),
-        kwargs: Optional[Mapping[str, Any]] = None,
-    ) -> torch.Tensor:
-        with DisableTorchFunction():
-            output = func(*args, **(kwargs or dict()))
-        if func is not torch.Tensor.clone:
-            return output
-
-        return cls(output, like=args[0])
+    def new_like(cls, other, data, *, dtype=None, device=None, **metadata):
+        for name in cls._META_ATTRS:
+            metadata.setdefault(name, getattr(other, name))
+        return cls(data, dtype=dtype or other.dtype, device=device or other.device, **metadata)
 
     def __repr__(self):
         return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__)
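The reworked `Feature` base class now has a dumb constructor: `__new__` only converts the data to a tensor and initializes an empty `_metadata` dict, while `__init_subclass__` turns each annotated metadata attribute into a read-only property and `new_like` copies metadata over from an existing instance. A sketch of how a subclass plugs into this machinery (the `Foo`/`bar` names are made up, mirroring the docstring example):

```python
from typing import Any, Optional

import torch
from torchvision.prototype.features import Feature


class Foo(Feature):
    bar: str  # collected into _META_ATTRS and exposed as a property

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        bar: str = "baz",
    ):
        foo = super().__new__(cls, data, dtype=dtype, device=device)
        foo._metadata.update(dict(bar=bar))
        return foo


foo = Foo(torch.rand(3), bar="qux")
# new_like fills dtype, device, and any metadata not passed explicitly from `foo`
other = Foo.new_like(foo, torch.zeros(3))
assert other.bar == "qux"
```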
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 3d0b3d0c0af..93a9b517235 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -1,9 +1,12 @@
-from typing import Dict, Any, Union, Tuple
+import warnings
+from typing import Any, Optional, Union, Tuple, cast
 
 import torch
 from torchvision.prototype.utils._internal import StrEnum
+from torchvision.transforms.functional import to_pil_image
+from torchvision.utils import make_grid
 
-from ._feature import Feature, DEFAULT
+from ._feature import Feature
 
 
 class ColorSpace(StrEnum):
@@ -18,23 +21,43 @@ class Image(Feature):
     color_spaces = ColorSpace
     color_space: ColorSpace
 
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        color_space: Optional[Union[ColorSpace, str]] = None,
+    ):
+        image = super().__new__(cls, data, dtype=dtype, device=device)
+
+        if color_space is None:
+            color_space = cls.guess_color_space(image)
+            if color_space == ColorSpace.OTHER:
+                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
+        elif isinstance(color_space, str):
+            color_space = ColorSpace[color_space]
+
+        image._metadata.update(dict(color_space=color_space))
+
+        return image
+
     @classmethod
     def _to_tensor(cls, data, *, dtype, device):
-        tensor = torch.as_tensor(data, dtype=dtype, device=device)
-        if tensor.ndim == 2:
+        tensor = super()._to_tensor(data, dtype=dtype, device=device)
+        if tensor.ndim < 2:
+            raise ValueError(f"Image data has to have at least two dimensions, but got {tensor.ndim}")
+        elif tensor.ndim == 2:
             tensor = tensor.unsqueeze(0)
-        elif tensor.ndim != 3:
-            raise ValueError("Only single images with 2 or 3 dimensions are allowed.")
         return tensor
 
-    @classmethod
-    def _parse_meta_data(
-        cls,
-        color_space: Union[str, ColorSpace] = DEFAULT,  # type: ignore[assignment]
-    ) -> Dict[str, Tuple[Any, Any]]:
-        if isinstance(color_space, str):
-            color_space = ColorSpace[color_space]
-        return dict(color_space=(color_space, cls.guess_color_space))
+    @property
+    def image_size(self) -> Tuple[int, int]:
+        return cast(Tuple[int, int], self.shape[-2:])
+
+    @property
+    def num_channels(self) -> int:
+        return self.shape[-3]
 
     @staticmethod
     def guess_color_space(data: torch.Tensor) -> ColorSpace:
@@ -50,3 +73,6 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
             return ColorSpace.RGB
         else:
             return ColorSpace.OTHER
+
+    def show(self) -> None:
+        to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
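`Image` now guesses its color space from the data when none is passed (warning if the guess falls through to `OTHER`), and `image_size`/`num_channels` are derived from the shape instead of being stored. A small sketch:

```python
import torch
from torchvision.prototype import features

image = features.Image(torch.rand(3, 16, 16))  # three channels are guessed as RGB
assert image.color_space is features.ColorSpace.RGB
assert image.image_size == (16, 16)
assert image.num_channels == 3

gray = features.Image(torch.rand(16, 16))  # 2d data is promoted to (1, 16, 16)
assert gray.num_channels == 1
```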
diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py
index ebdc6bbbc26..3ce1da647e7 100644
--- a/torchvision/prototype/features/_label.py
+++ b/torchvision/prototype/features/_label.py
@@ -1,14 +1,58 @@
-from typing import Dict, Any, Optional, Tuple
+from typing import Any, Optional, Sequence
 
-from ._feature import Feature, DEFAULT
+import torch
+from torchvision.prototype.utils._internal import apply_recursively
+
+from ._feature import Feature
 
 
 class Label(Feature):
-    category: Optional[str]
+    categories: Optional[Sequence[str]]
+
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        like: Optional["Label"] = None,
+        categories: Optional[Sequence[str]] = None,
+    ):
+        label = super().__new__(cls, data, dtype=dtype, device=device)
+
+        label._metadata.update(dict(categories=categories))
+
+        return label
 
     @classmethod
-    def _parse_meta_data(
+    def from_category(cls, category: str, *, categories: Sequence[str]):
+        categories = list(categories)
+        return cls(categories.index(category), categories=categories)
+
+    def to_categories(self):
+        if not self.categories:
+            raise RuntimeError("Label does not have categories attached")
+
+        return apply_recursively(lambda idx: self.categories[idx], self.tolist())
+
+
+class OneHotLabel(Feature):
+    categories: Optional[Sequence[str]]
+
+    def __new__(
         cls,
-        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
-    ) -> Dict[str, Tuple[Any, Any]]:
-        return dict(category=(category, None))
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        like: Optional["OneHotLabel"] = None,
+        categories: Optional[Sequence[str]] = None,
+    ):
+        one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
+
+        if categories is not None and len(categories) != one_hot_label.shape[-1]:
+            raise ValueError("The number of categories does not match the size of the last dimension")
+
+        one_hot_label._metadata.update(dict(categories=categories))
+
+        return one_hot_label
diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py
new file mode 100644
index 00000000000..d9d8354f022
--- /dev/null
+++ b/torchvision/prototype/features/_segmentation_mask.py
@@ -0,0 +1,5 @@
+from ._feature import Feature
+
+
+class SegmentationMask(Feature):
+    pass
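`Label` replaces the old single `category` string with an optional `categories` sequence, plus `from_category`/`to_categories` to map between names and indices; `OneHotLabel` checks that the last dimension matches the number of categories. A sketch:

```python
import torch
from torchvision.prototype import features

label = features.Label.from_category("dog", categories=["cat", "dog"])
assert int(label) == 1
assert label.to_categories() == "dog"

# the last dimension (here 2) has to match len(categories)
one_hot = features.OneHotLabel(torch.tensor([0, 1]), categories=["cat", "dog"])
```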
diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py
index 9468dcf08a9..fe75c19eb75 100644
--- a/torchvision/prototype/utils/_internal.py
+++ b/torchvision/prototype/utils/_internal.py
@@ -3,11 +3,33 @@
 import enum
 import functools
 import inspect
+import io
+import mmap
 import os
 import os.path
+import platform
 import textwrap
 import warnings
-from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    cast,
+    Collection,
+    Iterable,
+    Iterator,
+    Mapping,
+    NoReturn,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    List,
+    Dict,
+)
+
+import numpy as np
+import torch
 
 
 __all__ = [
@@ -14,12 +36,17 @@
     "StrEnum",
     "sequence_to_str",
     "add_suggestion",
     "make_repr",
     "FrozenBunch",
     "kwonly_to_pos_or_kw",
+    "fromfile",
+    "ReadOnlyTensorBuffer",
+    "apply_recursively",
 ]
 
 
 class StrEnumMeta(enum.EnumMeta):
+    auto = enum.auto
+
     def __getitem__(self, item):
         return super().__getitem__(item.upper() if isinstance(item, str) else item)
@@ -186,3 +213,119 @@ def wrapper(*args: Any, **kwargs: Any) -> D:
         return fn(*args, **kwargs)
 
     return wrapper
+
+
+def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
+    # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
+    return bytearray(file.read(-1 if count == -1 else count * item_size))
+
+
+def fromfile(
+    file: BinaryIO,
+    *,
+    dtype: torch.dtype,
+    byte_order: str,
+    count: int = -1,
+) -> torch.Tensor:
+    """Construct a tensor from a binary file.
+
+    .. note::
+
+        This function is similar to :func:`numpy.fromfile` with two notable differences:
+
+        1. This function only accepts an open binary file, but not a path to it.
+        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``s do not support that
+           concept.
+
+    .. note::
+
+        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
+        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
+
+    Args:
+        file (IO): Open binary file.
+        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
+        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
+        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
+    """
+    byte_order = "<" if byte_order == "little" else ">"
+    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
+    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+    np_dtype = byte_order + char + str(item_size)
+
+    buffer: Union[memoryview, bytearray]
+    if platform.system() != "Windows":
+        # PyTorch does not support tensors with underlying read-only memory. In case
+        # - the file has a .fileno(),
+        # - the file was opened for updating, i.e. 'r+b' or 'w+b',
+        # - the file is seekable
+        # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
+        # to a mutable location afterwards.
+        try:
+            buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
+            # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
+            file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
+        except (AttributeError, PermissionError, io.UnsupportedOperation):
+            buffer = _read_mutable_buffer_fallback(file, count, item_size)
+    else:
+        # On Windows, just trying to call mmap.mmap() on a file that does not support it may corrupt the internal
+        # state, so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
+        buffer = _read_mutable_buffer_fallback(file, count, item_size)
+
+    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
+    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
+    # successive .astype() call.
+    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
+
+
+class ReadOnlyTensorBuffer:
+    def __init__(self, tensor: torch.Tensor) -> None:
+        self._memory = memoryview(tensor.numpy())
+        self._cursor: int = 0
+
+    def tell(self) -> int:
+        return self._cursor
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        if whence == io.SEEK_SET:
+            self._cursor = offset
+        elif whence == io.SEEK_CUR:
+            self._cursor += offset
+        elif whence == io.SEEK_END:
+            self._cursor = len(self._memory) + offset
+        else:
+            raise ValueError(
+                f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, "
+                f"but got {repr(whence)} instead"
+            )
+        return self.tell()
+
+    def read(self, size: int = -1) -> bytes:
+        cursor = self.tell()
+        offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
+        return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
+
+
+def apply_recursively(fn: Callable, obj: Any) -> Any:
+    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
+    # "a" == "a"[0][0]...
+    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
+        sequence: List[Any] = []
+        for item in obj:
+            result = apply_recursively(fn, item)
+            if isinstance(result, collections.abc.Sequence) and hasattr(result, "__inline__"):
+                sequence.extend(result)
+            else:
+                sequence.append(result)
+        return sequence
+    elif isinstance(obj, collections.abc.Mapping):
+        mapping: Dict[Any, Any] = {}
+        for name, item in obj.items():
+            result = apply_recursively(fn, item)
+            if isinstance(result, collections.abc.Mapping) and hasattr(result, "__inline__"):
+                mapping.update(result)
+            else:
+                mapping[name] = result
+        return mapping
+    else:
+        return fn(obj)
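Moving `fromfile` next to the new `ReadOnlyTensorBuffer` gives both directions of the file/tensor round trip in one module: `fromfile` reads raw bytes into a tensor (mmap-backed where the platform and the file allow it), and `ReadOnlyTensorBuffer` exposes a tensor back through a minimal file-like interface. A sketch, assuming a native-endian `float32` payload:

```python
import io
import sys

import torch
from torch.testing import assert_close
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer

# BytesIO has no usable fileno(), so fromfile takes the plain-read fallback path
payload = torch.arange(4, dtype=torch.float32)
file = io.BytesIO(payload.numpy().tobytes())
restored = fromfile(file, dtype=torch.float32, byte_order=sys.byteorder)
assert_close(restored, payload)

# wrap a uint8 tensor so consumers expecting a file (e.g. PIL, read_flo) can read it
buffer = ReadOnlyTensorBuffer(torch.tensor([80, 73, 69, 72], dtype=torch.uint8))
assert buffer.read(4) == b"PIEH"  # the .flo magic number, byte for byte
```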