Skip to content

Commit bfc8510

Browse files
authored
revamp prototype features (#5283)
1 parent 45c15f5 commit bfc8510

File tree

11 files changed

+320
-469
lines changed

11 files changed

+320
-469
lines changed

test/test_prototype_features.py

Lines changed: 0 additions & 185 deletions
This file was deleted.

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
image_buffer_from_array,
2727
Decompressor,
2828
INFINITE_BUFFER_SIZE,
29-
fromfile,
3029
hint_sharding,
3130
hint_shuffling,
3231
)
3332
from torchvision.prototype.features import Image, Label
33+
from torchvision.prototype.utils._internal import fromfile
3434

3535
__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
3636

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
import gzip
44
import io
55
import lzma
6-
import mmap
76
import os
87
import os.path
98
import pathlib
109
import pickle
11-
import platform
1210
from typing import BinaryIO
1311
from typing import (
1412
Sequence,
@@ -32,6 +30,7 @@
3230
import torch.utils.data
3331
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
3432
from torchdata.datapipes.utils import StreamWrapper
33+
from torchvision.prototype.utils._internal import fromfile
3534

3635

3736
__all__ = [
@@ -46,7 +45,6 @@
4645
"path_accessor",
4746
"path_comparator",
4847
"Decompressor",
49-
"fromfile",
5048
"read_flo",
5149
"hint_sharding",
5250
]
@@ -267,69 +265,6 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[st
267265
return dp
268266

269267

270-
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
271-
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
272-
return bytearray(file.read(-1 if count == -1 else count * item_size))
273-
274-
275-
def fromfile(
276-
file: BinaryIO,
277-
*,
278-
dtype: torch.dtype,
279-
byte_order: str,
280-
count: int = -1,
281-
) -> torch.Tensor:
282-
"""Construct a tensor from a binary file.
283-
284-
.. note::
285-
286-
This function is similar to :func:`numpy.fromfile` with two notable differences:
287-
288-
1. This function only accepts an open binary file, but not a path to it.
289-
2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
290-
concept.
291-
292-
.. note::
293-
294-
If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
295-
long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
296-
297-
Args:
298-
file (IO): Open binary file.
299-
dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
300-
byte_order (str): Byte order of the data. Can be "little" or "big" endian.
301-
count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
302-
"""
303-
byte_order = "<" if byte_order == "little" else ">"
304-
char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
305-
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
306-
np_dtype = byte_order + char + str(item_size)
307-
308-
buffer: Union[memoryview, bytearray]
309-
if platform.system() != "Windows":
310-
# PyTorch does not support tensors with underlying read-only memory. In case
311-
# - the file has a .fileno(),
312-
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
313-
# - the file is seekable
314-
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
315-
# to a mutable location afterwards.
316-
try:
317-
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
318-
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
319-
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
320-
except (PermissionError, io.UnsupportedOperation):
321-
buffer = _read_mutable_buffer_fallback(file, count, item_size)
322-
else:
323-
# On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
324-
# so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
325-
buffer = _read_mutable_buffer_fallback(file, count, item_size)
326-
327-
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
328-
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
329-
# successive .astype() call.
330-
return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
331-
332-
333268
def read_flo(file: BinaryIO) -> torch.Tensor:
334269
if file.read(4) != b"PIEH":
335270
raise ValueError("Magic number incorrect. Invalid .flo file")
torchvision/prototype/features/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from ._bounding_box import BoundingBoxFormat, BoundingBox
2-
from ._feature import Feature, DEFAULT
3-
from ._image import Image, ColorSpace
4-
from ._label import Label
1+
from ._bounding_box import BoundingBox, BoundingBoxFormat
2+
from ._encoded import EncodedData, EncodedImage, EncodedVideo
3+
from ._feature import Feature
4+
from ._image import ColorSpace, Image
5+
from ._label import Label, OneHotLabel
6+
from ._segmentation_mask import SegmentationMask

0 commit comments

Comments
 (0)