|
8 | 8 | import os.path
|
9 | 9 | import pathlib
|
10 | 10 | import pickle
|
| 11 | +import platform |
11 | 12 | from typing import BinaryIO
|
12 | 13 | from typing import (
|
13 | 14 | Sequence,
|
@@ -260,6 +261,11 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
|
260 | 261 | return dp
|
261 | 262 |
|
262 | 263 |
|
| 264 | +def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: |
| 265 | + # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable |
| 266 | + return bytearray(file.read(-1 if count == -1 else count * item_size)) |
| 267 | + |
| 268 | + |
263 | 269 | def fromfile(
|
264 | 270 | file: BinaryIO,
|
265 | 271 | *,
|
@@ -293,20 +299,24 @@ def fromfile(
|
293 | 299 | item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
|
294 | 300 | np_dtype = byte_order + char + str(item_size)
|
295 | 301 |
|
296 |
| - # PyTorch does not support tensors with underlying read-only memory. In case |
297 |
| - # - the file has a .fileno(), |
298 |
| - # - the file was opened for updating, i.e. 'r+b' or 'w+b', |
299 |
| - # - the file is seekable |
300 |
| - # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to |
301 |
| - # a mutable location afterwards. |
302 | 302 | buffer: Union[memoryview, bytearray]
|
303 |
| - try: |
304 |
| - buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] |
305 |
| - # Reading from the memoryview does not advance the file cursor, so we have to do it manually. |
306 |
| - file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) |
307 |
| - except (PermissionError, io.UnsupportedOperation): |
308 |
| - # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable |
309 |
| - buffer = bytearray(file.read(-1 if count == -1 else count * item_size)) |
| 303 | + if platform.system() != "Windows": |
| 304 | + # PyTorch does not support tensors with underlying read-only memory. In case |
| 305 | + # - the file has a .fileno(), |
| 306 | + # - the file was opened for updating, i.e. 'r+b' or 'w+b', |
| 307 | + # - the file is seekable |
| 308 | + # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it |
| 309 | + # to a mutable location afterwards. |
| 310 | + try: |
| 311 | + buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] |
| 312 | + # Reading from the memoryview does not advance the file cursor, so we have to do it manually. |
| 313 | + file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) |
| 314 | + except (PermissionError, io.UnsupportedOperation): |
| 315 | + buffer = _read_mutable_buffer_fallback(file, count, item_size) |
| 316 | + else: |
| 317 | + # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state |
| 318 | + # so no data can be read afterwards. Thus, we simply ignore the possible speed-up. |
| 319 | + buffer = _read_mutable_buffer_fallback(file, count, item_size) |
310 | 320 |
|
311 | 321 | # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
|
312 | 322 | # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
|
|
0 commit comments