|
10 | 10 | from io import (
|
11 | 11 | BufferedIOBase,
|
12 | 12 | BytesIO,
|
| 13 | + FileIO, |
13 | 14 | RawIOBase,
|
14 | 15 | StringIO,
|
15 | 16 | TextIOBase,
|
|
19 | 20 | import os
|
20 | 21 | from pathlib import Path
|
21 | 22 | import re
|
| 23 | +import tarfile |
22 | 24 | from typing import (
|
23 | 25 | IO,
|
24 | 26 | Any,
|
@@ -450,13 +452,18 @@ def file_path_to_url(path: str) -> str:
|
450 | 452 | return urljoin("file:", pathname2url(path))
|
451 | 453 |
|
452 | 454 |
|
453 |
| -_compression_to_extension = { |
454 |
| - "gzip": ".gz", |
455 |
| - "bz2": ".bz2", |
456 |
| - "zip": ".zip", |
457 |
| - "xz": ".xz", |
458 |
| - "zstd": ".zst", |
| 455 | +_extension_to_compression = { |
| 456 | + ".tar": "tar", |
| 457 | + ".tar.gz": "tar", |
| 458 | + ".tar.bz2": "tar", |
| 459 | + ".tar.xz": "tar", |
| 460 | + ".gz": "gzip", |
| 461 | + ".bz2": "bz2", |
| 462 | + ".zip": "zip", |
| 463 | + ".xz": "xz", |
| 464 | + ".zst": "zstd", |
459 | 465 | }
|
| 466 | +_supported_compressions = set(_extension_to_compression.values()) |
460 | 467 |
|
461 | 468 |
|
462 | 469 | def get_compression_method(
|
@@ -532,20 +539,18 @@ def infer_compression(
|
532 | 539 | return None
|
533 | 540 |
|
534 | 541 | # Infer compression from the filename/URL extension
|
535 |
| - for compression, extension in _compression_to_extension.items(): |
| 542 | + for extension, compression in _extension_to_compression.items(): |
536 | 543 | if filepath_or_buffer.lower().endswith(extension):
|
537 | 544 | return compression
|
538 | 545 | return None
|
539 | 546 |
|
540 | 547 | # Compression has been specified. Check that it's valid
|
541 |
| - if compression in _compression_to_extension: |
| 548 | + if compression in _supported_compressions: |
542 | 549 | return compression
|
543 | 550 |
|
544 | 551 | # https://github.com/python/mypy/issues/5492
|
545 | 552 | # Unsupported operand types for + ("List[Optional[str]]" and "List[str]")
|
546 |
| - valid = ["infer", None] + sorted( |
547 |
| - _compression_to_extension |
548 |
| - ) # type: ignore[operator] |
| 553 | + valid = ["infer", None] + sorted(_supported_compressions) # type: ignore[operator] |
549 | 554 | msg = (
|
550 | 555 | f"Unrecognized compression type: {compression}\n"
|
551 | 556 | f"Valid compression types are {valid}"
|
@@ -682,7 +687,7 @@ def get_handle(
|
682 | 687 | ioargs.encoding,
|
683 | 688 | ioargs.mode,
|
684 | 689 | errors,
|
685 |
| - ioargs.compression["method"] not in _compression_to_extension, |
| 690 | + ioargs.compression["method"] not in _supported_compressions, |
686 | 691 | )
|
687 | 692 |
|
688 | 693 | is_path = isinstance(handle, str)
|
@@ -753,6 +758,30 @@ def get_handle(
|
753 | 758 | f"Only one file per ZIP: {zip_names}"
|
754 | 759 | )
|
755 | 760 |
|
| 761 | + # TAR Encoding |
| 762 | + elif compression == "tar": |
| 763 | + if "mode" not in compression_args: |
| 764 | + compression_args["mode"] = ioargs.mode |
| 765 | + if is_path: |
| 766 | + handle = _BytesTarFile.open(name=handle, **compression_args) |
| 767 | + else: |
| 768 | + handle = _BytesTarFile.open(fileobj=handle, **compression_args) |
| 769 | + assert isinstance(handle, _BytesTarFile) |
| 770 | + if handle.mode == "r": |
| 771 | + handles.append(handle) |
| 772 | + files = handle.getnames() |
| 773 | + if len(files) == 1: |
| 774 | + file = handle.extractfile(files[0]) |
| 775 | + assert file is not None |
| 776 | + handle = file |
| 777 | + elif len(files) == 0: |
| 778 | + raise ValueError(f"Zero files found in TAR archive {path_or_buf}") |
| 779 | + else: |
| 780 | + raise ValueError( |
| 781 | + "Multiple files found in TAR archive. " |
| 782 | + f"Only one file per TAR archive: {files}" |
| 783 | + ) |
| 784 | + |
756 | 785 | # XZ Compression
|
757 | 786 | elif compression == "xz":
|
758 | 787 | handle = get_lzma_file()(handle, ioargs.mode)
|
@@ -844,6 +873,116 @@ def get_handle(
|
844 | 873 | )
|
845 | 874 |
|
846 | 875 |
|
| 876 | +# error: Definition of "__exit__" in base class "TarFile" is incompatible with |
| 877 | +# definition in base class "BytesIO" [misc] |
| 878 | +# error: Definition of "__enter__" in base class "TarFile" is incompatible with |
| 879 | +# definition in base class "BytesIO" [misc] |
| 880 | +# error: Definition of "__enter__" in base class "TarFile" is incompatible with |
| 881 | +# definition in base class "BinaryIO" [misc] |
| 882 | +# error: Definition of "__enter__" in base class "TarFile" is incompatible with |
| 883 | +# definition in base class "IO" [misc] |
| 884 | +# error: Definition of "read" in base class "TarFile" is incompatible with |
| 885 | +# definition in base class "BytesIO" [misc] |
| 886 | +# error: Definition of "read" in base class "TarFile" is incompatible with |
| 887 | +# definition in base class "IO" [misc] |
| 888 | +class _BytesTarFile(tarfile.TarFile, BytesIO): # type: ignore[misc] |
| 889 | + """ |
| 890 | + Wrapper for standard library class TarFile and allow the returned file-like |
| 891 | + handle to accept byte strings via `write` method. |
| 892 | +
|
| 893 | + BytesIO provides attributes of file-like object and TarFile.addfile writes |
| 894 | + bytes strings into a member of the archive. |
| 895 | + """ |
| 896 | + |
| 897 | + # GH 17778 |
| 898 | + def __init__( |
| 899 | + self, |
| 900 | + name: str | bytes | os.PathLike[str] | os.PathLike[bytes], |
| 901 | + mode: Literal["r", "a", "w", "x"], |
| 902 | + fileobj: FileIO, |
| 903 | + archive_name: str | None = None, |
| 904 | + **kwargs, |
| 905 | + ): |
| 906 | + self.archive_name = archive_name |
| 907 | + self.multiple_write_buffer: BytesIO | None = None |
| 908 | + self._closing = False |
| 909 | + |
| 910 | + super().__init__(name=name, mode=mode, fileobj=fileobj, **kwargs) |
| 911 | + |
| 912 | + @classmethod |
| 913 | + def open(cls, name=None, mode="r", **kwargs): |
| 914 | + mode = mode.replace("b", "") |
| 915 | + return super().open(name=name, mode=cls.extend_mode(name, mode), **kwargs) |
| 916 | + |
| 917 | + @classmethod |
| 918 | + def extend_mode( |
| 919 | + cls, name: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], mode: str |
| 920 | + ) -> str: |
| 921 | + if mode != "w": |
| 922 | + return mode |
| 923 | + if isinstance(name, (os.PathLike, str)): |
| 924 | + filename = Path(name) |
| 925 | + if filename.suffix == ".gz": |
| 926 | + return mode + ":gz" |
| 927 | + elif filename.suffix == ".xz": |
| 928 | + return mode + ":xz" |
| 929 | + elif filename.suffix == ".bz2": |
| 930 | + return mode + ":bz2" |
| 931 | + return mode |
| 932 | + |
| 933 | + def infer_filename(self): |
| 934 | + """ |
| 935 | + If an explicit archive_name is not given, we still want the file inside the zip |
| 936 | + file not to be named something.tar, because that causes confusion (GH39465). |
| 937 | + """ |
| 938 | + if isinstance(self.name, (os.PathLike, str)): |
| 939 | + # error: Argument 1 to "Path" has |
| 940 | + # incompatible type "Union[str, PathLike[str], PathLike[bytes]]"; |
| 941 | + # expected "Union[str, PathLike[str]]" [arg-type] |
| 942 | + filename = Path(self.name) # type: ignore[arg-type] |
| 943 | + if filename.suffix == ".tar": |
| 944 | + return filename.with_suffix("").name |
| 945 | + if filename.suffix in [".tar.gz", ".tar.bz2", ".tar.xz"]: |
| 946 | + return filename.with_suffix("").with_suffix("").name |
| 947 | + return filename.name |
| 948 | + return None |
| 949 | + |
| 950 | + def write(self, data): |
| 951 | + # buffer multiple write calls, write on flush |
| 952 | + if self.multiple_write_buffer is None: |
| 953 | + self.multiple_write_buffer = BytesIO() |
| 954 | + self.multiple_write_buffer.write(data) |
| 955 | + |
| 956 | + def flush(self) -> None: |
| 957 | + # write to actual handle and close write buffer |
| 958 | + if self.multiple_write_buffer is None or self.multiple_write_buffer.closed: |
| 959 | + return |
| 960 | + |
| 961 | + # TarFile needs a non-empty string |
| 962 | + archive_name = self.archive_name or self.infer_filename() or "tar" |
| 963 | + with self.multiple_write_buffer: |
| 964 | + value = self.multiple_write_buffer.getvalue() |
| 965 | + tarinfo = tarfile.TarInfo(name=archive_name) |
| 966 | + tarinfo.size = len(value) |
| 967 | + self.addfile(tarinfo, BytesIO(value)) |
| 968 | + |
| 969 | + def close(self): |
| 970 | + self.flush() |
| 971 | + super().close() |
| 972 | + |
| 973 | + @property |
| 974 | + def closed(self): |
| 975 | + if self.multiple_write_buffer is None: |
| 976 | + return False |
| 977 | + return self.multiple_write_buffer.closed and super().closed |
| 978 | + |
| 979 | + @closed.setter |
| 980 | + def closed(self, value): |
| 981 | + if not self._closing and value: |
| 982 | + self._closing = True |
| 983 | + self.close() |
| 984 | + |
| 985 | + |
847 | 986 | # error: Definition of "__exit__" in base class "ZipFile" is incompatible with
|
848 | 987 | # definition in base class "BytesIO" [misc]
|
849 | 988 | # error: Definition of "__enter__" in base class "ZipFile" is incompatible with
|
|
0 commit comments