diff --git a/Lib/bz2.py b/Lib/bz2.py index ce07ebeb142d92..7447d12fc4b8c8 100644 --- a/Lib/bz2.py +++ b/Lib/bz2.py @@ -226,15 +226,23 @@ def write(self, data): """Write a byte string to the file. Returns the number of uncompressed bytes written, which is - always len(data). Note that due to buffering, the file on disk - may not reflect the data written until close() is called. + always the length of data in bytes. Note that due to buffering, + the file on disk may not reflect the data written until close() + is called. """ with self._lock: self._check_can_write() + if isinstance(data, (bytes, bytearray)): + length = len(data) + else: + # accept any data that supports the buffer protocol + data = memoryview(data) + length = data.nbytes + compressed = self._compressor.compress(data) self._fp.write(compressed) - self._pos += len(data) - return len(data) + self._pos += length + return length def writelines(self, seq): """Write a sequence of byte strings to the file. diff --git a/Lib/lzma.py b/Lib/lzma.py index 0817b872d2019f..0aa30fe87f8c0c 100644 --- a/Lib/lzma.py +++ b/Lib/lzma.py @@ -225,14 +225,22 @@ def write(self, data): """Write a bytes object to the file. Returns the number of uncompressed bytes written, which is - always len(data). Note that due to buffering, the file on disk - may not reflect the data written until close() is called. + always the length of data in bytes. Note that due to buffering, + the file on disk may not reflect the data written until close() + is called. """ self._check_can_write() + if isinstance(data, (bytes, bytearray)): + length = len(data) + else: + # accept any data that supports the buffer protocol + data = memoryview(data) + length = data.nbytes + compressed = self._compressor.compress(data) self._fp.write(compressed) - self._pos += len(data) - return len(data) + self._pos += length + return length def seek(self, offset, whence=io.SEEK_SET): """Change the file position. diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index 1fce9d82d25d6a..c84f70ebb094a6 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -1,6 +1,7 @@ from test import support from test.support import bigmemtest, _4G +import array import unittest from io import BytesIO, DEFAULT_BUFFER_SIZE import os @@ -618,6 +619,14 @@ def test_read_truncated(self): with BZ2File(BytesIO(truncated[:i])) as f: self.assertRaises(EOFError, f.read, 1) + def test_issue44439(self): + q = array.array('Q', [1, 2, 3, 4, 5]) + LENGTH = len(q) * q.itemsize + + with BZ2File(BytesIO(), 'w') as f: + self.assertEqual(f.write(q), LENGTH) + self.assertEqual(f.tell(), LENGTH) + class BZ2CompressorTest(BaseTest): def testCompress(self): diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py index 0f3af27efa909c..ef7dd6325d49f3 100644 --- a/Lib/test/test_lzma.py +++ b/Lib/test/test_lzma.py @@ -1,4 +1,5 @@ import _compression +import array from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE import os import pathlib @@ -1227,6 +1228,14 @@ def test_issue21872(self): self.assertTrue(d2.eof) self.assertEqual(out1 + out2, entire) + def test_issue44439(self): + q = array.array('Q', [1, 2, 3, 4, 5]) + LENGTH = len(q) * q.itemsize + + with LZMAFile(BytesIO(), 'w') as f: + self.assertEqual(f.write(q), LENGTH) + self.assertEqual(f.tell(), LENGTH) + class OpenTestCase(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2021-06-17-15-01-51.bpo-44439.1S7QhT.rst b/Misc/NEWS.d/next/Library/2021-06-17-15-01-51.bpo-44439.1S7QhT.rst new file mode 100644 index 00000000000000..27396683700a83 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-06-17-15-01-51.bpo-44439.1S7QhT.rst @@ -0,0 +1,3 @@ +Fix in :meth:`bz2.BZ2File.write` / :meth:`lzma.LZMAFile.write` methods, when +the input data is an object that supports the buffer protocol, the file length +may be wrong.