Skip to content

Commit 8dcb5b8

Browse files
authored
add prototype utilities to read arbitrary numeric binary files (#4882)
* add FloReader datapipe * add NumericBinaryReader * revert unrelated change * cleanup * cleanup * add comment for byte reversal * use numpy after all * appease mypy * use .astype() with copy=False * add docstring and cleanuo * reuse current _read_flo and revert MNIST changes * cleanup * revert demonstration * refactor * cleanup * add support for mutable memory * add test * add comments * catch more exceptions * fix mypy * fix variable names * hardcode flow sizes in test * add fix dtype docstring * expand comment on different reading modes * add comment about files in update mode * add tests for fromfile * cleanup * cleanup
1 parent fa1aa52 commit 8dcb5b8

File tree

3 files changed

+129
-32
lines changed

3 files changed

+129
-32
lines changed

test/test_prototype_datasets_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import sys
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
from datasets_utils import make_fake_flo_file
7+
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
8+
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
9+
10+
11+
@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning")
12+
@pytest.mark.parametrize(
13+
("np_dtype", "torch_dtype", "byte_order"),
14+
[
15+
(">f4", torch.float32, "big"),
16+
("<f8", torch.float64, "little"),
17+
("<i4", torch.int32, "little"),
18+
(">i8", torch.int64, "big"),
19+
("|u1", torch.uint8, sys.byteorder),
20+
],
21+
)
22+
@pytest.mark.parametrize("count", (-1, 2))
23+
@pytest.mark.parametrize("mode", ("rb", "r+b"))
24+
def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode):
25+
path = tmpdir / "data.bin"
26+
rng = np.random.RandomState(0)
27+
rng.randn(5 if count == -1 else count + 1).astype(np_dtype).tofile(path)
28+
29+
for count_ in (-1, count // 2):
30+
expected = torch.from_numpy(np.fromfile(path, dtype=np_dtype, count=count_).astype(np_dtype[1:]))
31+
32+
with open(path, mode) as file:
33+
actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_)
34+
35+
torch.testing.assert_close(actual, expected)
36+
37+
38+
def test_read_flo(tmpdir):
39+
path = tmpdir / "test.flo"
40+
make_fake_flo_file(3, 4, path)
41+
42+
with open(path, "rb") as file:
43+
actual = read_flo(file)
44+
45+
expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))
46+
47+
torch.testing.assert_close(actual, expected)

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import abc
2-
import codecs
32
import functools
43
import io
54
import operator
65
import pathlib
76
import string
8-
import sys
9-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
7+
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO
108

119
import torch
1210
from torchdata.datapipes.iter import (
@@ -30,6 +28,7 @@
3028
image_buffer_from_array,
3129
Decompressor,
3230
INFINITE_BUFFER_SIZE,
31+
fromfile,
3332
)
3433
from torchvision.prototype.features import Image, Label
3534

@@ -50,50 +49,33 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
5049
}
5150

5251
def __init__(
53-
self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int]
52+
self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int]
5453
) -> None:
5554
self.datapipe = datapipe
5655
self.start = start
5756
self.stop = stop
5857

59-
@staticmethod
60-
def _decode(input: bytes) -> int:
61-
return int(codecs.encode(input, "hex"), 16)
62-
63-
@staticmethod
64-
def _to_tensor(chunk: bytes, *, dtype: torch.dtype, shape: List[int], reverse_bytes: bool) -> torch.Tensor:
65-
# As is, the chunk is not writeable, because it is read from a file and not from memory. Thus, we copy here to
66-
# avoid the warning that torch.frombuffer would emit otherwise. This also enables inplace operations on the
67-
# contents, which would otherwise fail.
68-
chunk = bytearray(chunk)
69-
if reverse_bytes:
70-
chunk.reverse()
71-
tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
72-
else:
73-
tensor = torch.frombuffer(chunk, dtype=dtype)
74-
return tensor.reshape(shape)
75-
7658
def __iter__(self) -> Iterator[torch.Tensor]:
7759
for _, file in self.datapipe:
78-
magic = self._decode(file.read(4))
60+
read = functools.partial(fromfile, file, byte_order="big")
61+
62+
magic = int(read(dtype=torch.int32, count=1))
7963
dtype = self._DTYPE_MAP[magic // 256]
8064
ndim = magic % 256 - 1
8165

82-
num_samples = self._decode(file.read(4))
83-
shape = [self._decode(file.read(4)) for _ in range(ndim)]
84-
85-
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
86-
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
87-
# we need to reverse the bytes before we can read them with torch.frombuffer().
88-
reverse_bytes = sys.byteorder == "little" and num_bytes_per_value > 1
89-
chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
66+
num_samples = int(read(dtype=torch.int32, count=1))
67+
shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
68+
count = prod(shape) if shape else 1
9069

9170
start = self.start or 0
9271
stop = min(self.stop, num_samples) if self.stop else num_samples
9372

94-
file.seek(start * chunk_size, 1)
73+
if start:
74+
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
75+
file.seek(num_bytes_per_value * count * start, 1)
76+
9577
for _ in range(stop - start):
96-
yield self._to_tensor(file.read(chunk_size), dtype=dtype, shape=shape, reverse_bytes=reverse_bytes)
78+
yield read(dtype=dtype, count=count).reshape(shape)
9779

9880

9981
class _MNISTBase(Dataset):

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import gzip
44
import io
55
import lzma
6+
import mmap
67
import os
78
import os.path
89
import pathlib
910
import pickle
11+
from typing import BinaryIO
1012
from typing import (
1113
Sequence,
1214
Callable,
@@ -24,6 +26,7 @@
2426

2527
import numpy as np
2628
import PIL.Image
29+
import torch
2730
import torch.distributed as dist
2831
import torch.utils.data
2932
from torch.utils.data import IterDataPipe
@@ -43,6 +46,8 @@
4346
"path_accessor",
4447
"path_comparator",
4548
"Decompressor",
49+
"fromfile",
50+
"read_flo",
4651
]
4752

4853
K = TypeVar("K")
@@ -253,3 +258,66 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
253258
# dp = dp.cycle(2)
254259
dp = TakerDataPipe(dp, dataset_size)
255260
return dp
261+
262+
263+
def fromfile(
264+
file: BinaryIO,
265+
*,
266+
dtype: torch.dtype,
267+
byte_order: str,
268+
count: int = -1,
269+
) -> torch.Tensor:
270+
"""Construct a tensor from a binary file.
271+
272+
.. note::
273+
274+
This function is similar to :func:`numpy.fromfile` with two notable differences:
275+
276+
1. This function only accepts an open binary file, but not a path to it.
277+
2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
278+
concept.
279+
280+
.. note::
281+
282+
If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
283+
long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
284+
285+
Args:
286+
file (IO): Open binary file.
287+
dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
288+
byte_order (str): Byte order of the data. Can be "little" or "big" endian.
289+
count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
290+
"""
291+
byte_order = "<" if byte_order == "little" else ">"
292+
char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
293+
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
294+
np_dtype = byte_order + char + str(item_size)
295+
296+
# PyTorch does not support tensors with underlying read-only memory. In case
297+
# - the file has a .fileno(),
298+
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
299+
# - the file is seekable
300+
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to
301+
# a mutable location afterwards.
302+
buffer: Union[memoryview, bytearray]
303+
try:
304+
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
305+
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
306+
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
307+
except (PermissionError, io.UnsupportedOperation):
308+
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
309+
buffer = bytearray(file.read(-1 if count == -1 else count * item_size))
310+
311+
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
312+
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
313+
# successive .astype() call.
314+
return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
315+
316+
317+
def read_flo(file: BinaryIO) -> torch.Tensor:
318+
if file.read(4) != b"PIEH":
319+
raise ValueError("Magic number incorrect. Invalid .flo file")
320+
321+
width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
322+
flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
323+
return flow.reshape((height, width, 2)).permute((2, 0, 1))

0 commit comments

Comments
 (0)