Skip to content

Commit 934ce3b

Browse files
authored
fix MNIST byte flipping (#7081)
* fix MNIST byte flipping * add test * move to utils * remove lazy import
1 parent 372f4fa commit 934ce3b

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

test/test_datasets_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import zipfile
88

99
import pytest
10+
import torch
1011
import torchvision.datasets.utils as utils
12+
from common_utils import assert_equal
1113
from torch._utils_internal import get_file_path_2
1214
from torchvision.datasets.folder import make_dataset
1315
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
@@ -215,6 +217,24 @@ def test_verify_str_arg(self):
215217
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
216218
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
217219

220+
@pytest.mark.parametrize(
    ("dtype", "actual_hex", "expected_hex"),
    [
        (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
        (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
        (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
        (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
    ],
)
def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
    """Check that ``utils._flip_byte_order`` reverses the bytes of each element.

    The same eight input bytes are interpreted with dtypes of 1, 2, 4 and 8 bytes per element;
    ``expected_hex`` spells out the bytes after reversing every element-sized group individually.
    """

    def to_tensor(hex_str):
        # `bytearray` gives `torch.frombuffer` a writable buffer, which avoids the warning it emits for
        # read-only `bytes` objects. The parameter is not called `hex` to keep the builtin usable.
        return torch.frombuffer(bytearray.fromhex(hex_str), dtype=dtype)

    assert_equal(
        utils._flip_byte_order(to_tensor(actual_hex)),
        to_tensor(expected_hex),
    )
237+
218238

219239
@pytest.mark.parametrize(
220240
("kwargs", "expected_error_msg"),

torchvision/datasets/mnist.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
from PIL import Image
1414

15-
from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
15+
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
1616
from .vision import VisionDataset
1717

1818

@@ -519,13 +519,12 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
519519
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
520520
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
521521

522-
num_bytes_per_value = torch.iinfo(torch_type).bits // 8
523-
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
524-
# we need to reverse the bytes before we can read them with torch.frombuffer().
525-
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
526522
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
527-
if needs_byte_reversal:
528-
parsed = parsed.flip(0)
523+
524+
# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
525+
# that is little endian and the dtype has more than one byte, we need to flip them.
526+
if sys.byteorder == "little" and parsed.element_size() > 1:
527+
parsed = _flip_byte_order(parsed)
529528

530529
assert parsed.shape[0] == np.prod(s) or not strict
531530
return parsed.view(*s)

torchvision/datasets/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
520520
data = np.flip(data, axis=1) # flip on h dimension
521521
data = data[:slice_channels, :, :]
522522
return data.astype(np.float32)
523+
524+
525+
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
526+
return (
527+
t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
528+
)

0 commit comments

Comments
 (0)