From b2245ecd5abc5200798c70b4c0b8d82623708f05 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 21 Oct 2022 14:50:00 +0300 Subject: [PATCH 1/8] CLN: StataReader: refactor repeated struct.unpack/read calls to helpers --- pandas/io/stata.py | 140 ++++++++++++++++++++++----------------------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index eacc036f2740d..4eb497ab43aa8 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1198,9 +1198,42 @@ def _set_encoding(self) -> None: else: self._encoding = "utf-8" + def _read_int8(self) -> int: + return struct.unpack("b", self.path_or_buf.read(1))[0] + + def _read_uint8(self) -> int: + return struct.unpack("B", self.path_or_buf.read(1))[0] + + def _read_uint16(self) -> int: + return struct.unpack(f"{self.byteorder}H", self.path_or_buf.read(2))[0] + + def _read_uint32(self) -> int: + return struct.unpack(f"{self.byteorder}I", self.path_or_buf.read(4))[0] + + def _read_uint64(self) -> int: + return struct.unpack(f"{self.byteorder}Q", self.path_or_buf.read(8))[0] + + def _read_int16(self) -> int: + return struct.unpack(f"{self.byteorder}h", self.path_or_buf.read(2))[0] + + def _read_int32(self) -> int: + return struct.unpack(f"{self.byteorder}i", self.path_or_buf.read(4))[0] + + def _read_int64(self) -> int: + return struct.unpack(f"{self.byteorder}q", self.path_or_buf.read(8))[0] + + def _read_char8(self) -> bytes: + return struct.unpack("c", self.path_or_buf.read(1))[0] + + def _read_int16_count(self, count: int) -> tuple[int, ...]: + return struct.unpack( + f"{self.byteorder}{'h' * count}", + self.path_or_buf.read(2 * count), + ) + def _read_header(self) -> None: - first_char = self.path_or_buf.read(1) - if struct.unpack("c", first_char)[0] == b"<": + first_char = self._read_char8() + if first_char == b"<": self._read_new_header() else: self._read_old_header(first_char) @@ -1220,11 +1253,9 @@ def _read_new_header(self) -> None: self.path_or_buf.read(21) # 
self.byteorder = ">" if self.path_or_buf.read(3) == b"MSF" else "<" self.path_or_buf.read(15) # - nvar_type = "H" if self.format_version <= 118 else "I" - nvar_size = 2 if self.format_version <= 118 else 4 - self.nvar = struct.unpack( - self.byteorder + nvar_type, self.path_or_buf.read(nvar_size) - )[0] + self.nvar = ( + self._read_uint16() if self.format_version <= 118 else self._read_uint32() + ) self.path_or_buf.read(7) # self.nobs = self._get_nobs() @@ -1236,35 +1267,19 @@ def _read_new_header(self) -> None: self.path_or_buf.read(8) # 0x0000000000000000 self.path_or_buf.read(8) # position of - self._seek_vartypes = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16 - ) - self._seek_varnames = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10 - ) - self._seek_sortlist = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10 - ) - self._seek_formats = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9 - ) - self._seek_value_label_names = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19 - ) + self._seek_vartypes = self._read_int64() + 16 + self._seek_varnames = self._read_int64() + 10 + self._seek_sortlist = self._read_int64() + 10 + self._seek_formats = self._read_int64() + 9 + self._seek_value_label_names = self._read_int64() + 19 # Requires version-specific treatment self._seek_variable_labels = self._get_seek_variable_labels() self.path_or_buf.read(8) # - self.data_location = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6 - ) - self.seek_strls = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7 - ) - self.seek_value_labels = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14 - ) + self.data_location = self._read_int64() + 6 + self.seek_strls = self._read_int64() + 7 + self.seek_value_labels = self._read_int64() + 14 self.typlist, self.dtyplist = 
self._get_dtypes(self._seek_vartypes) @@ -1272,10 +1287,7 @@ def _read_new_header(self) -> None: self.varlist = self._get_varlist() self.path_or_buf.seek(self._seek_sortlist) - self.srtlist = struct.unpack( - self.byteorder + ("h" * (self.nvar + 1)), - self.path_or_buf.read(2 * (self.nvar + 1)), - )[:-1] + self.srtlist = self._read_int16_count(self.nvar + 1)[:-1] self.path_or_buf.seek(self._seek_formats) self.fmtlist = self._get_fmtlist() @@ -1291,10 +1303,7 @@ def _get_dtypes( self, seek_vartypes: int ) -> tuple[list[int | str], list[str | np.dtype]]: self.path_or_buf.seek(seek_vartypes) - raw_typlist = [ - struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] - for _ in range(self.nvar) - ] + raw_typlist = [self._read_uint16() for _ in range(self.nvar)] def f(typ: int) -> int | str: if typ <= 2045: @@ -1363,16 +1372,16 @@ def _get_variable_labels(self) -> list[str]: def _get_nobs(self) -> int: if self.format_version >= 118: - return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0] + return self._read_uint64() else: - return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] + return self._read_uint32() def _get_data_label(self) -> str: if self.format_version >= 118: - strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] + strlen = self._read_uint16() return self._decode(self.path_or_buf.read(strlen)) elif self.format_version == 117: - strlen = struct.unpack("b", self.path_or_buf.read(1))[0] + strlen = self._read_int8() return self._decode(self.path_or_buf.read(strlen)) elif self.format_version > 105: return self._decode(self.path_or_buf.read(81)) @@ -1381,10 +1390,10 @@ def _get_data_label(self) -> str: def _get_time_stamp(self) -> str: if self.format_version >= 118: - strlen = struct.unpack("b", self.path_or_buf.read(1))[0] + strlen = self._read_int8() return self.path_or_buf.read(strlen).decode("utf-8") elif self.format_version == 117: - strlen = struct.unpack("b", self.path_or_buf.read(1))[0] + 
strlen = self._read_int8() return self._decode(self.path_or_buf.read(strlen)) elif self.format_version > 104: return self._decode(self.path_or_buf.read(18)) @@ -1399,22 +1408,20 @@ def _get_seek_variable_labels(self) -> int: # variable, 20 for the closing tag and 17 for the opening tag return self._seek_value_label_names + (33 * self.nvar) + 20 + 17 elif self.format_version >= 118: - return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17 + return self._read_int64() + 17 else: raise ValueError() def _read_old_header(self, first_char: bytes) -> None: - self.format_version = struct.unpack("b", first_char)[0] + self.format_version = int(first_char[0]) if self.format_version not in [104, 105, 108, 111, 113, 114, 115]: raise ValueError(_version_error.format(version=self.format_version)) self._set_encoding() - self.byteorder = ( - ">" if struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 else "<" - ) - self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0] + self.byteorder = (">" if self._read_int8() == 0x1 else "<") + self.filetype = self._read_int8() self.path_or_buf.read(1) # unused - self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] + self.nvar = self._read_uint16() self.nobs = self._get_nobs() self._data_label = self._get_data_label() @@ -1423,7 +1430,7 @@ def _read_old_header(self, first_char: bytes) -> None: # descriptors if self.format_version > 108: - typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)] + typlist = [int(c) for c in self.path_or_buf.read(self.nvar)] else: buf = self.path_or_buf.read(self.nvar) typlistb = np.frombuffer(buf, dtype=np.uint8) @@ -1453,10 +1460,7 @@ def _read_old_header(self, first_char: bytes) -> None: self.varlist = [ self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar) ] - self.srtlist = struct.unpack( - self.byteorder + ("h" * (self.nvar + 1)), - self.path_or_buf.read(2 * (self.nvar + 1)), - )[:-1] + self.srtlist = 
self._read_int16_count(self.nvar + 1)[:-1] self.fmtlist = self._get_fmtlist() @@ -1471,17 +1475,11 @@ def _read_old_header(self, first_char: bytes) -> None: if self.format_version > 104: while True: - data_type = struct.unpack( - self.byteorder + "b", self.path_or_buf.read(1) - )[0] + data_type = self._read_int8() if self.format_version > 108: - data_len = struct.unpack( - self.byteorder + "i", self.path_or_buf.read(4) - )[0] + data_len = self._read_int32() else: - data_len = struct.unpack( - self.byteorder + "h", self.path_or_buf.read(2) - )[0] + data_len = self._read_int16() if data_type == 0: break self.path_or_buf.read(data_len) @@ -1565,8 +1563,8 @@ def _read_value_labels(self) -> None: labname = self._decode(self.path_or_buf.read(129)) self.path_or_buf.read(3) # padding - n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] - txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] + n = self._read_uint32() + txtlen = self._read_uint32() off = np.frombuffer( self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n ) @@ -1594,7 +1592,7 @@ def _read_strls(self) -> None: break if self.format_version == 117: - v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0] + v_o = self._read_uint64() else: buf = self.path_or_buf.read(12) # Only tested on little endian file on little endian machine. 
@@ -1605,8 +1603,8 @@ def _read_strls(self) -> None: # This path may not be correct, impossible to test buf = buf[0:v_size] + buf[(4 + v_size) :] v_o = struct.unpack("Q", buf)[0] - typ = struct.unpack("B", self.path_or_buf.read(1))[0] - length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] + typ = self._read_uint8() + length = self._read_uint32() va = self.path_or_buf.read(length) if typ == 130: decoded_va = va[0:-1].decode(self._encoding) From e409db9cb440435eeaa4b10916ed0526f9b8f8d7 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 21 Oct 2022 15:04:07 +0300 Subject: [PATCH 2/8] CLN: StataReader: replace string concatenations with f-strings --- pandas/io/stata.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 4eb497ab43aa8..2ae5278c7a6bc 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1496,9 +1496,9 @@ def _setup_dtype(self) -> np.dtype: for i, typ in enumerate(self.typlist): if typ in self.NUMPY_TYPE_MAP: typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP - dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) + dtypes.append((f"s{i}", f"{self.byteorder}{self.NUMPY_TYPE_MAP[typ]}")) else: - dtypes.append(("s" + str(i), "S" + str(typ))) + dtypes.append((f"s{i}", f"S{typ}")) self._dtype = np.dtype(dtypes) return self._dtype @@ -1566,10 +1566,10 @@ def _read_value_labels(self) -> None: n = self._read_uint32() txtlen = self._read_uint32() off = np.frombuffer( - self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n + self.path_or_buf.read(4 * n), dtype=f"{self.byteorder}i4", count=n ) val = np.frombuffer( - self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n + self.path_or_buf.read(4 * n), dtype=f"{self.byteorder}i4", count=n ) ii = np.argsort(off) off = off[ii] From b4db2b48a25cbbb6a07013ebc70daeaf57fc2454 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 21 Oct 2022 15:08:02 +0300 Subject: [PATCH 3/8] 
CLN: StataReader: prefix internal state with underscore --- pandas/io/stata.py | 390 +++++++++++++++++++++++++-------------------- 1 file changed, 217 insertions(+), 173 deletions(-) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 2ae5278c7a6bc..22385cf0877a8 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1129,7 +1129,7 @@ def __init__( storage_options: StorageOptions = None, ) -> None: super().__init__() - self.col_sizes: list[int] = [] + self._col_sizes: list[int] = [] # Arguments to the reader (can be temporarily overridden in # calls to read). @@ -1167,7 +1167,7 @@ def __init__( compression=compression, ) as handles: # Copy to BytesIO, and ensure no encoding - self.path_or_buf = BytesIO(handles.handle.read()) + self._path_or_buf = BytesIO(handles.handle.read()) self._read_header() self._setup_dtype() @@ -1187,48 +1187,48 @@ def __exit__( def close(self) -> None: """close the handle if its open""" - self.path_or_buf.close() + self._path_or_buf.close() def _set_encoding(self) -> None: """ Set string encoding which depends on file version """ - if self.format_version < 118: + if self._format_version < 118: self._encoding = "latin-1" else: self._encoding = "utf-8" def _read_int8(self) -> int: - return struct.unpack("b", self.path_or_buf.read(1))[0] + return struct.unpack("b", self._path_or_buf.read(1))[0] def _read_uint8(self) -> int: - return struct.unpack("B", self.path_or_buf.read(1))[0] + return struct.unpack("B", self._path_or_buf.read(1))[0] def _read_uint16(self) -> int: - return struct.unpack(f"{self.byteorder}H", self.path_or_buf.read(2))[0] + return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0] def _read_uint32(self) -> int: - return struct.unpack(f"{self.byteorder}I", self.path_or_buf.read(4))[0] + return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0] def _read_uint64(self) -> int: - return struct.unpack(f"{self.byteorder}Q", self.path_or_buf.read(8))[0] + return 
struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0] def _read_int16(self) -> int: - return struct.unpack(f"{self.byteorder}h", self.path_or_buf.read(2))[0] + return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0] def _read_int32(self) -> int: - return struct.unpack(f"{self.byteorder}i", self.path_or_buf.read(4))[0] + return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0] def _read_int64(self) -> int: - return struct.unpack(f"{self.byteorder}q", self.path_or_buf.read(8))[0] + return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0] def _read_char8(self) -> bytes: - return struct.unpack("c", self.path_or_buf.read(1))[0] + return struct.unpack("c", self._path_or_buf.read(1))[0] def _read_int16_count(self, count: int) -> tuple[int, ...]: return struct.unpack( - f"{self.byteorder}{'h' * count}", - self.path_or_buf.read(2 * count), + f"{self._byteorder}{'h' * count}", + self._path_or_buf.read(2 * count), ) def _read_header(self) -> None: @@ -1238,34 +1238,34 @@ def _read_header(self) -> None: else: self._read_old_header(first_char) - self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0 + self._has_string_data = len([x for x in self._typlist if type(x) is int]) > 0 # calculate size of a data record - self.col_sizes = [self._calcsize(typ) for typ in self.typlist] + self._col_sizes = [self._calcsize(typ) for typ in self._typlist] def _read_new_header(self) -> None: # The first part of the header is common to 117 - 119. - self.path_or_buf.read(27) # stata_dta>
- self.format_version = int(self.path_or_buf.read(3)) - if self.format_version not in [117, 118, 119]: - raise ValueError(_version_error.format(version=self.format_version)) + self._path_or_buf.read(27) # stata_dta>
+ self._format_version = int(self._path_or_buf.read(3)) + if self._format_version not in [117, 118, 119]: + raise ValueError(_version_error.format(version=self._format_version)) self._set_encoding() - self.path_or_buf.read(21) # - self.byteorder = ">" if self.path_or_buf.read(3) == b"MSF" else "<" - self.path_or_buf.read(15) # - self.nvar = ( - self._read_uint16() if self.format_version <= 118 else self._read_uint32() + self._path_or_buf.read(21) # + self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<" + self._path_or_buf.read(15) # + self._nvar = ( + self._read_uint16() if self._format_version <= 118 else self._read_uint32() ) - self.path_or_buf.read(7) # + self._path_or_buf.read(7) # - self.nobs = self._get_nobs() - self.path_or_buf.read(11) #
- self.path_or_buf.read(8) # 0x0000000000000000 - self.path_or_buf.read(8) # position of + self._path_or_buf.read(19) # + self._time_stamp = self._get_time_stamp() + self._path_or_buf.read(26) #
+ self._path_or_buf.read(8) # 0x0000000000000000 + self._path_or_buf.read(8) # position of self._seek_vartypes = self._read_int64() + 16 self._seek_varnames = self._read_int64() + 10 @@ -1276,34 +1276,34 @@ def _read_new_header(self) -> None: # Requires version-specific treatment self._seek_variable_labels = self._get_seek_variable_labels() - self.path_or_buf.read(8) # - self.data_location = self._read_int64() + 6 - self.seek_strls = self._read_int64() + 7 - self.seek_value_labels = self._read_int64() + 14 + self._path_or_buf.read(8) # + self._data_location = self._read_int64() + 6 + self._seek_strls = self._read_int64() + 7 + self._seek_value_labels = self._read_int64() + 14 - self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes) + self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes) - self.path_or_buf.seek(self._seek_varnames) - self.varlist = self._get_varlist() + self._path_or_buf.seek(self._seek_varnames) + self._varlist = self._get_varlist() - self.path_or_buf.seek(self._seek_sortlist) - self.srtlist = self._read_int16_count(self.nvar + 1)[:-1] + self._path_or_buf.seek(self._seek_sortlist) + self._srtlist = self._read_int16_count(self._nvar + 1)[:-1] - self.path_or_buf.seek(self._seek_formats) - self.fmtlist = self._get_fmtlist() + self._path_or_buf.seek(self._seek_formats) + self._fmtlist = self._get_fmtlist() - self.path_or_buf.seek(self._seek_value_label_names) - self.lbllist = self._get_lbllist() + self._path_or_buf.seek(self._seek_value_label_names) + self._lbllist = self._get_lbllist() - self.path_or_buf.seek(self._seek_variable_labels) + self._path_or_buf.seek(self._seek_variable_labels) self._variable_labels = self._get_variable_labels() # Get data type information, works for versions 117-119. 
def _get_dtypes( self, seek_vartypes: int ) -> tuple[list[int | str], list[str | np.dtype]]: - self.path_or_buf.seek(seek_vartypes) - raw_typlist = [self._read_uint16() for _ in range(self.nvar)] + self._path_or_buf.seek(seek_vartypes) + raw_typlist = [self._read_uint16() for _ in range(self._nvar)] def f(typ: int) -> int | str: if typ <= 2045: @@ -1329,110 +1329,110 @@ def g(typ: int) -> str | np.dtype: def _get_varlist(self) -> list[str]: # 33 in order formats, 129 in formats 118 and 119 - b = 33 if self.format_version < 118 else 129 - return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] + b = 33 if self._format_version < 118 else 129 + return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)] # Returns the format list def _get_fmtlist(self) -> list[str]: - if self.format_version >= 118: + if self._format_version >= 118: b = 57 - elif self.format_version > 113: + elif self._format_version > 113: b = 49 - elif self.format_version > 104: + elif self._format_version > 104: b = 12 else: b = 7 - return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] + return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)] # Returns the label list def _get_lbllist(self) -> list[str]: - if self.format_version >= 118: + if self._format_version >= 118: b = 129 - elif self.format_version > 108: + elif self._format_version > 108: b = 33 else: b = 9 - return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] + return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)] def _get_variable_labels(self) -> list[str]: - if self.format_version >= 118: + if self._format_version >= 118: vlblist = [ - self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar) + self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar) ] - elif self.format_version > 105: + elif self._format_version > 105: vlblist = [ - self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar) + 
self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar) ] else: vlblist = [ - self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar) + self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar) ] return vlblist def _get_nobs(self) -> int: - if self.format_version >= 118: + if self._format_version >= 118: return self._read_uint64() else: return self._read_uint32() def _get_data_label(self) -> str: - if self.format_version >= 118: + if self._format_version >= 118: strlen = self._read_uint16() - return self._decode(self.path_or_buf.read(strlen)) - elif self.format_version == 117: + return self._decode(self._path_or_buf.read(strlen)) + elif self._format_version == 117: strlen = self._read_int8() - return self._decode(self.path_or_buf.read(strlen)) - elif self.format_version > 105: - return self._decode(self.path_or_buf.read(81)) + return self._decode(self._path_or_buf.read(strlen)) + elif self._format_version > 105: + return self._decode(self._path_or_buf.read(81)) else: - return self._decode(self.path_or_buf.read(32)) + return self._decode(self._path_or_buf.read(32)) def _get_time_stamp(self) -> str: - if self.format_version >= 118: + if self._format_version >= 118: strlen = self._read_int8() - return self.path_or_buf.read(strlen).decode("utf-8") - elif self.format_version == 117: + return self._path_or_buf.read(strlen).decode("utf-8") + elif self._format_version == 117: strlen = self._read_int8() - return self._decode(self.path_or_buf.read(strlen)) - elif self.format_version > 104: - return self._decode(self.path_or_buf.read(18)) + return self._decode(self._path_or_buf.read(strlen)) + elif self._format_version > 104: + return self._decode(self._path_or_buf.read(18)) else: raise ValueError() def _get_seek_variable_labels(self) -> int: - if self.format_version == 117: - self.path_or_buf.read(8) # , throw away + if self._format_version == 117: + self._path_or_buf.read(8) # , throw away # Stata 117 data files do not follow the described 
format. This is # a work around that uses the previous label, 33 bytes for each # variable, 20 for the closing tag and 17 for the opening tag - return self._seek_value_label_names + (33 * self.nvar) + 20 + 17 - elif self.format_version >= 118: + return self._seek_value_label_names + (33 * self._nvar) + 20 + 17 + elif self._format_version >= 118: return self._read_int64() + 17 else: raise ValueError() def _read_old_header(self, first_char: bytes) -> None: - self.format_version = int(first_char[0]) - if self.format_version not in [104, 105, 108, 111, 113, 114, 115]: - raise ValueError(_version_error.format(version=self.format_version)) + self._format_version = int(first_char[0]) + if self._format_version not in [104, 105, 108, 111, 113, 114, 115]: + raise ValueError(_version_error.format(version=self._format_version)) self._set_encoding() - self.byteorder = (">" if self._read_int8() == 0x1 else "<") - self.filetype = self._read_int8() - self.path_or_buf.read(1) # unused + self._byteorder = ">" if self._read_int8() == 0x1 else "<" + self._filetype = self._read_int8() + self._path_or_buf.read(1) # unused - self.nvar = self._read_uint16() - self.nobs = self._get_nobs() + self._nvar = self._read_uint16() + self._nobs = self._get_nobs() self._data_label = self._get_data_label() - self.time_stamp = self._get_time_stamp() + self._time_stamp = self._get_time_stamp() # descriptors - if self.format_version > 108: - typlist = [int(c) for c in self.path_or_buf.read(self.nvar)] + if self._format_version > 108: + typlist = [int(c) for c in self._path_or_buf.read(self._nvar)] else: - buf = self.path_or_buf.read(self.nvar) + buf = self._path_or_buf.read(self._nvar) typlistb = np.frombuffer(buf, dtype=np.uint8) typlist = [] for tp in typlistb: @@ -1442,29 +1442,29 @@ def _read_old_header(self, first_char: bytes) -> None: typlist.append(tp - 127) # bytes try: - self.typlist = [self.TYPE_MAP[typ] for typ in typlist] + self._typlist = [self.TYPE_MAP[typ] for typ in typlist] except 
ValueError as err: invalid_types = ",".join([str(x) for x in typlist]) raise ValueError(f"cannot convert stata types [{invalid_types}]") from err try: - self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist] + self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist] except ValueError as err: invalid_dtypes = ",".join([str(x) for x in typlist]) raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err - if self.format_version > 108: - self.varlist = [ - self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar) + if self._format_version > 108: + self._varlist = [ + self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar) ] else: - self.varlist = [ - self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar) + self._varlist = [ + self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar) ] - self.srtlist = self._read_int16_count(self.nvar + 1)[:-1] + self._srtlist = self._read_int16_count(self._nvar + 1)[:-1] - self.fmtlist = self._get_fmtlist() + self._fmtlist = self._get_fmtlist() - self.lbllist = self._get_lbllist() + self._lbllist = self._get_lbllist() self._variable_labels = self._get_variable_labels() @@ -1473,19 +1473,19 @@ def _read_old_header(self, first_char: bytes) -> None: # the size of the next read, which you discard. You then continue # like this until you read 5 bytes of zeros. 
- if self.format_version > 104: + if self._format_version > 104: while True: data_type = self._read_int8() - if self.format_version > 108: + if self._format_version > 108: data_len = self._read_int32() else: data_len = self._read_int16() if data_type == 0: break - self.path_or_buf.read(data_len) + self._path_or_buf.read(data_len) # necessary data to continue parsing - self.data_location = self.path_or_buf.tell() + self._data_location = self._path_or_buf.tell() def _setup_dtype(self) -> np.dtype: """Map between numpy and state dtypes""" @@ -1493,10 +1493,10 @@ def _setup_dtype(self) -> np.dtype: return self._dtype dtypes = [] # Convert struct data types to numpy data type - for i, typ in enumerate(self.typlist): + for i, typ in enumerate(self._typlist): if typ in self.NUMPY_TYPE_MAP: typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP - dtypes.append((f"s{i}", f"{self.byteorder}{self.NUMPY_TYPE_MAP[typ]}")) + dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}")) else: dtypes.append((f"s{i}", f"S{typ}")) self._dtype = np.dtype(dtypes) @@ -1506,7 +1506,7 @@ def _setup_dtype(self) -> np.dtype: def _calcsize(self, fmt: int | str) -> int: if isinstance(fmt, int): return fmt - return struct.calcsize(self.byteorder + fmt) + return struct.calcsize(self._byteorder + fmt) def _decode(self, s: bytes) -> str: # have bytes not strings, so must decode @@ -1533,71 +1533,73 @@ def _read_value_labels(self) -> None: if self._value_labels_read: # Don't read twice return - if self.format_version <= 108: + if self._format_version <= 108: # Value labels are not supported in version 108 and earlier. 
self._value_labels_read = True - self.value_label_dict: dict[str, dict[float, str]] = {} + self._value_label_dict: dict[str, dict[float, str]] = {} return - if self.format_version >= 117: - self.path_or_buf.seek(self.seek_value_labels) + if self._format_version >= 117: + self._path_or_buf.seek(self._seek_value_labels) else: assert self._dtype is not None - offset = self.nobs * self._dtype.itemsize - self.path_or_buf.seek(self.data_location + offset) + offset = self._nobs * self._dtype.itemsize + self._path_or_buf.seek(self._data_location + offset) self._value_labels_read = True - self.value_label_dict = {} + self._value_label_dict = {} while True: - if self.format_version >= 117: - if self.path_or_buf.read(5) == b" + if self._format_version >= 117: + if self._path_or_buf.read(5) == b" break # end of value label table - slength = self.path_or_buf.read(4) + slength = self._path_or_buf.read(4) if not slength: break # end of value label table (format < 117) - if self.format_version <= 117: - labname = self._decode(self.path_or_buf.read(33)) + if self._format_version <= 117: + labname = self._decode(self._path_or_buf.read(33)) else: - labname = self._decode(self.path_or_buf.read(129)) - self.path_or_buf.read(3) # padding + labname = self._decode(self._path_or_buf.read(129)) + self._path_or_buf.read(3) # padding n = self._read_uint32() txtlen = self._read_uint32() off = np.frombuffer( - self.path_or_buf.read(4 * n), dtype=f"{self.byteorder}i4", count=n + self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n ) val = np.frombuffer( - self.path_or_buf.read(4 * n), dtype=f"{self.byteorder}i4", count=n + self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n ) ii = np.argsort(off) off = off[ii] val = val[ii] - txt = self.path_or_buf.read(txtlen) - self.value_label_dict[labname] = {} + txt = self._path_or_buf.read(txtlen) + self._value_label_dict[labname] = {} for i in range(n): end = off[i + 1] if i < n - 1 else txtlen - 
self.value_label_dict[labname][val[i]] = self._decode(txt[off[i] : end]) - if self.format_version >= 117: - self.path_or_buf.read(6) # + self._value_label_dict[labname][val[i]] = self._decode( + txt[off[i] : end] + ) + if self._format_version >= 117: + self._path_or_buf.read(6) # self._value_labels_read = True def _read_strls(self) -> None: - self.path_or_buf.seek(self.seek_strls) + self._path_or_buf.seek(self._seek_strls) # Wrap v_o in a string to allow uint64 values as keys on 32bit OS self.GSO = {"0": ""} while True: - if self.path_or_buf.read(3) != b"GSO": + if self._path_or_buf.read(3) != b"GSO": break - if self.format_version == 117: + if self._format_version == 117: v_o = self._read_uint64() else: - buf = self.path_or_buf.read(12) + buf = self._path_or_buf.read(12) # Only tested on little endian file on little endian machine. - v_size = 2 if self.format_version == 118 else 3 - if self.byteorder == "<": + v_size = 2 if self._format_version == 118 else 3 + if self._byteorder == "<": buf = buf[0:v_size] + buf[4 : (12 - v_size)] else: # This path may not be correct, impossible to test @@ -1605,7 +1607,7 @@ def _read_strls(self) -> None: v_o = struct.unpack("Q", buf)[0] typ = self._read_uint8() length = self._read_uint32() - va = self.path_or_buf.read(length) + va = self._path_or_buf.read(length) if typ == 130: decoded_va = va[0:-1].decode(self._encoding) else: @@ -1650,11 +1652,11 @@ def read( # Handle empty file or chunk. If reading incrementally raise # StopIteration. If reading the whole thing return an empty # data frame. 
- if (self.nobs == 0) and (nrows is None): + if (self._nobs == 0) and (nrows is None): self._can_read_value_labels = True self._data_read = True self.close() - return DataFrame(columns=self.varlist) + return DataFrame(columns=self._varlist) # Handle options if convert_dates is None: @@ -1673,16 +1675,16 @@ def read( index_col = self._index_col if nrows is None: - nrows = self.nobs + nrows = self._nobs - if (self.format_version >= 117) and (not self._value_labels_read): + if (self._format_version >= 117) and (not self._value_labels_read): self._can_read_value_labels = True self._read_strls() # Read data assert self._dtype is not None dtype = self._dtype - max_read_len = (self.nobs - self._lines_read) * dtype.itemsize + max_read_len = (self._nobs - self._lines_read) * dtype.itemsize read_len = nrows * dtype.itemsize read_len = min(read_len, max_read_len) if read_len <= 0: @@ -1693,28 +1695,28 @@ def read( self.close() raise StopIteration offset = self._lines_read * dtype.itemsize - self.path_or_buf.seek(self.data_location + offset) - read_lines = min(nrows, self.nobs - self._lines_read) + self._path_or_buf.seek(self._data_location + offset) + read_lines = min(nrows, self._nobs - self._lines_read) raw_data = np.frombuffer( - self.path_or_buf.read(read_len), dtype=dtype, count=read_lines + self._path_or_buf.read(read_len), dtype=dtype, count=read_lines ) self._lines_read += read_lines - if self._lines_read == self.nobs: + if self._lines_read == self._nobs: self._can_read_value_labels = True self._data_read = True # if necessary, swap the byte order to native here - if self.byteorder != self._native_byteorder: + if self._byteorder != self._native_byteorder: raw_data = raw_data.byteswap().newbyteorder() if convert_categoricals: self._read_value_labels() if len(raw_data) == 0: - data = DataFrame(columns=self.varlist) + data = DataFrame(columns=self._varlist) else: data = DataFrame.from_records(raw_data) - data.columns = Index(self.varlist) + data.columns = 
Index(self._varlist) # If index is not specified, use actual row number rather than # restarting at 0 for each chunk. @@ -1730,25 +1732,25 @@ def read( raise # Decode strings - for col, typ in zip(data, self.typlist): + for col, typ in zip(data, self._typlist): if type(typ) is int: data[col] = data[col].apply(self._decode, convert_dtype=True) data = self._insert_strls(data) - cols_ = np.where([dtyp is not None for dtyp in self.dtyplist])[0] + cols_ = np.where([dtyp is not None for dtyp in self._dtyplist])[0] # Convert columns (if needed) to match input type ix = data.index requires_type_conversion = False data_formatted = [] for i in cols_: - if self.dtyplist[i] is not None: + if self._dtyplist[i] is not None: col = data.columns[i] dtype = data[col].dtype - if dtype != np.dtype(object) and dtype != self.dtyplist[i]: + if dtype != np.dtype(object) and dtype != self._dtyplist[i]: requires_type_conversion = True data_formatted.append( - (col, Series(data[col], ix, self.dtyplist[i])) + (col, Series(data[col], ix, self._dtyplist[i])) ) else: data_formatted.append((col, data[col])) @@ -1763,20 +1765,20 @@ def read( def any_startswith(x: str) -> bool: return any(x.startswith(fmt) for fmt in _date_formats) - cols = np.where([any_startswith(x) for x in self.fmtlist])[0] + cols = np.where([any_startswith(x) for x in self._fmtlist])[0] for i in cols: col = data.columns[i] try: data[col] = _stata_elapsed_date_to_datetime_vec( - data[col], self.fmtlist[i] + data[col], self._fmtlist[i] ) except ValueError: self.close() raise - if convert_categoricals and self.format_version > 108: + if convert_categoricals and self._format_version > 108: data = self._do_convert_categoricals( - data, self.value_label_dict, self.lbllist, order_categoricals + data, self._value_label_dict, self._lbllist, order_categoricals ) if not preserve_dtypes: @@ -1807,7 +1809,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra # Check for missing values, and replace if found 
replacements = {} for i, colname in enumerate(data): - fmt = self.typlist[i] + fmt = self._typlist[i] if fmt not in self.VALID_RANGE: continue @@ -1853,7 +1855,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra def _insert_strls(self, data: DataFrame) -> DataFrame: if not hasattr(self, "GSO") or len(self.GSO) == 0: return data - for i, typ in enumerate(self.typlist): + for i, typ in enumerate(self._typlist): if typ != "Q": continue # Wrap v_o in a string to allow uint64 values as keys on 32bit OS @@ -1879,15 +1881,15 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra lbllist = [] for col in columns: i = data.columns.get_loc(col) - dtyplist.append(self.dtyplist[i]) - typlist.append(self.typlist[i]) - fmtlist.append(self.fmtlist[i]) - lbllist.append(self.lbllist[i]) - - self.dtyplist = dtyplist - self.typlist = typlist - self.fmtlist = fmtlist - self.lbllist = lbllist + dtyplist.append(self._dtyplist[i]) + typlist.append(self._typlist[i]) + fmtlist.append(self._fmtlist[i]) + lbllist.append(self._lbllist[i]) + + self._dtyplist = dtyplist + self._typlist = typlist + self._fmtlist = fmtlist + self._lbllist = lbllist self._column_selector_set = True return data[columns] @@ -1976,6 +1978,48 @@ def data_label(self) -> str: """ return self._data_label + @property + def typlist(self) -> list[int | str]: + """ + Return list of variable types. + """ + return self._typlist + + @property + def dtyplist(self) -> list[str | np.dtype]: + """ + Return list of variable types. + """ + return self._dtyplist + + @property + def lbllist(self) -> list[str]: + """ + Return list of variable labels. + """ + return self._lbllist + + @property + def varlist(self) -> list[str]: + """ + Return list of variable names. + """ + return self._varlist + + @property + def fmtlist(self) -> list[str]: + """ + Return list of variable formats. 
+ """ + return self._fmtlist + + @property + def time_stamp(self) -> str: + """ + Return time stamp of Stata file. + """ + return self._time_stamp + def variable_labels(self) -> dict[str, str]: """ Return a dict associating each variable name with corresponding label. @@ -1984,7 +2028,7 @@ def variable_labels(self) -> dict[str, str]: ------- dict """ - return dict(zip(self.varlist, self._variable_labels)) + return dict(zip(self._varlist, self._variable_labels)) def value_labels(self) -> dict[str, dict[float, str]]: """ @@ -1997,7 +2041,7 @@ def value_labels(self) -> dict[str, dict[float, str]]: if not self._value_labels_read: self._read_value_labels() - return self.value_label_dict + return self._value_label_dict @Appender(_read_stata_doc) From 4602aeddaebbbd44896197c58c93302d783dfcbe Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 21 Oct 2022 15:16:19 +0300 Subject: [PATCH 4/8] FIX: StataReader: defer opening file to when data is required --- pandas/io/stata.py | 64 ++++++++++++++--------------------- pandas/tests/io/test_stata.py | 6 ++-- 2 files changed, 28 insertions(+), 42 deletions(-) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 22385cf0877a8..c66f79d919dae 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1114,6 +1114,8 @@ def __init__(self) -> None: class StataReader(StataParser, abc.Iterator): __doc__ = _stata_reader_doc + _path_or_buf: IO[bytes] # Will be assigned by `_open_file`. 
+ def __init__( self, path_or_buf: FilePath | ReadBuffer[bytes], @@ -1140,6 +1142,9 @@ def __init__( self._preserve_dtypes = preserve_dtypes self._columns = columns self._order_categoricals = order_categoricals + self._original_path_or_buf = path_or_buf + self._compression = compression + self._storage_options = storage_options self._encoding = "" self._chunksize = chunksize self._using_iterator = False @@ -1149,6 +1154,7 @@ def __init__( raise ValueError("chunksize must be a positive integer when set.") # State variables for the file + self._close_file: Callable[[], None] | None = None self._has_string_data = False self._missing_values = False self._can_read_value_labels = False @@ -1159,12 +1165,24 @@ def __init__( self._lines_read = 0 self._native_byteorder = _set_endianness(sys.byteorder) + + def _ensure_open(self) -> None: + """ + Ensure the file has been opened and its header data read. + """ + if not hasattr(self, "_path_or_buf"): + self._open_file() + + def _open_file(self) -> None: + """ + Open the file (with compression options, etc.), and read header information. + """ with get_handle( - path_or_buf, + self._original_path_or_buf, "rb", - storage_options=storage_options, + storage_options=self._storage_options, is_text=False, - compression=compression, + compression=self._compression, ) as handles: # Copy to BytesIO, and ensure no encoding self._path_or_buf = BytesIO(handles.handle.read()) @@ -1530,6 +1548,7 @@ def _decode(self, s: bytes) -> str: return s.decode("latin-1") def _read_value_labels(self) -> None: + self._ensure_open() if self._value_labels_read: # Don't read twice return @@ -1649,6 +1668,7 @@ def read( columns: Sequence[str] | None = None, order_categoricals: bool | None = None, ) -> DataFrame: + self._ensure_open() # Handle empty file or chunk. If reading incrementally raise # StopIteration. If reading the whole thing return an empty # data frame. @@ -1976,48 +1996,15 @@ def data_label(self) -> str: """ Return data label of Stata file. 
""" + self._ensure_open() return self._data_label - @property - def typlist(self) -> list[int | str]: - """ - Return list of variable types. - """ - return self._typlist - - @property - def dtyplist(self) -> list[str | np.dtype]: - """ - Return list of variable types. - """ - return self._dtyplist - - @property - def lbllist(self) -> list[str]: - """ - Return list of variable labels. - """ - return self._lbllist - - @property - def varlist(self) -> list[str]: - """ - Return list of variable names. - """ - return self._varlist - - @property - def fmtlist(self) -> list[str]: - """ - Return list of variable formats. - """ - return self._fmtlist - @property def time_stamp(self) -> str: """ Return time stamp of Stata file. """ + self._ensure_open() return self._time_stamp def variable_labels(self) -> dict[str, str]: @@ -2028,6 +2015,7 @@ def variable_labels(self) -> dict[str, str]: ------- dict """ + self._ensure_open() return dict(zip(self._varlist, self._variable_labels)) def value_labels(self) -> dict[str, dict[float, str]]: diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index 5393a15cff19b..b6b037674c689 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -736,10 +736,8 @@ def test_minimal_size_col(self): original.to_stata(path, write_index=False) with StataReader(path) as sr: - typlist = sr.typlist - variables = sr.varlist - formats = sr.fmtlist - for variable, fmt, typ in zip(variables, formats, typlist): + sr._ensure_open() # The `_*list` variables are initialized here + for variable, fmt, typ in zip(sr._varlist, sr._fmtlist, sr._typlist): assert int(variable[1:]) == int(fmt[1:-1]) assert int(variable[1:]) == typ From d72d5f91d05be140cdebf5128293daa6ba299bf0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 21 Oct 2022 15:17:10 +0300 Subject: [PATCH 5/8] FIX: StataReader: don't buffer entire file into memory unless necessary Refs #48922 --- doc/source/whatsnew/v2.0.0.rst | 2 ++ pandas/io/stata.py | 
47 ++++++++++++++++++++++++++++------ pandas/tests/io/test_stata.py | 38 +++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 8 deletions(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index bdbde438217b9..f1206ddab71a2 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -1163,6 +1163,8 @@ Performance improvements - Fixed a reference leak in :func:`read_hdf` (:issue:`37441`) - Fixed a memory leak in :meth:`DataFrame.to_json` and :meth:`Series.to_json` when serializing datetimes and timedeltas (:issue:`40443`) - Decreased memory usage in many :class:`DataFrameGroupBy` methods (:issue:`51090`) +- Memory improvement in :class:`StataReader` when reading seekable files (:issue:`48922`) + .. --------------------------------------------------------------------------- .. _whatsnew_200.bug_fixes: diff --git a/pandas/io/stata.py b/pandas/io/stata.py index c66f79d919dae..920ba93826c15 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -23,6 +23,7 @@ TYPE_CHECKING, Any, AnyStr, + Callable, Final, Hashable, Sequence, @@ -1148,6 +1149,7 @@ def __init__( self._encoding = "" self._chunksize = chunksize self._using_iterator = False + self._entered = False if self._chunksize is None: self._chunksize = 1 elif not isinstance(chunksize, int) or chunksize <= 0: @@ -1177,21 +1179,36 @@ def _open_file(self) -> None: """ Open the file (with compression options, etc.), and read header information. """ - with get_handle( + if not self._entered: + warnings.warn( + "StataReader is being used without using a context manager. 
" + "Using StataReader as a context manager is the only supported method.", + ResourceWarning, + stacklevel=find_stack_level(), + ) + handles = get_handle( self._original_path_or_buf, "rb", storage_options=self._storage_options, is_text=False, compression=self._compression, - ) as handles: - # Copy to BytesIO, and ensure no encoding - self._path_or_buf = BytesIO(handles.handle.read()) + ) + if hasattr(handles.handle, "seekable") and handles.handle.seekable(): + # If the handle is directly seekable, use it without an extra copy. + self._path_or_buf = handles.handle + self._close_file = handles.close + else: + # Copy to memory, and ensure no encoding. + with handles: + self._path_or_buf = BytesIO(handles.handle.read()) + self._close_file = self._path_or_buf.close self._read_header() self._setup_dtype() def __enter__(self) -> StataReader: """enter context manager""" + self._entered = True return self def __exit__( @@ -1200,12 +1217,26 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - """exit context manager""" - self.close() + if self._close_file: + self._close_file() def close(self) -> None: - """close the handle if its open""" - self._path_or_buf.close() + """Close the handle if its open. + + .. deprecated: 2.0.0 + + The close method is not part of the public API. + The only supported way to use StataReader is to use it as a context manager. + """ + warnings.warn( + "The StataReader.close() method is not part of the public API and " + "will be removed in a future version without notice. 
" + "Using StataReader as a context manager is the only supported method.", + FutureWarning, + stacklevel=2, + ) + if self._close_file: + self._close_file() def _set_encoding(self) -> None: """ diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index b6b037674c689..75e9f7b744caa 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -1889,6 +1889,44 @@ def test_backward_compat(version, datapath): tm.assert_frame_equal(old_dta, expected, check_dtype=False) +def test_direct_read(datapath, monkeypatch): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + + # Test that opening a file path doesn't buffer the file. + with StataReader(file_path) as reader: + # Must not have been buffered to memory + assert not reader.read().empty + assert not isinstance(reader._path_or_buf, io.BytesIO) + + # Test that we use a given fp exactly, if possible. + with open(file_path, "rb") as fp: + with StataReader(fp) as reader: + assert not reader.read().empty + assert reader._path_or_buf is fp + + # Test that we use a given BytesIO exactly, if possible. 
+ with open(file_path, "rb") as fp: + with io.BytesIO(fp.read()) as bio: + with StataReader(bio) as reader: + assert not reader.read().empty + assert reader._path_or_buf is bio + + +def test_statareader_warns_when_used_without_context(datapath): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + with tm.assert_produces_warning( + ResourceWarning, + match="without using a context manager", + ): + sr = StataReader(file_path) + sr.read() + with tm.assert_produces_warning( + FutureWarning, + match="is not part of the public API", + ): + sr.close() + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) @pytest.mark.parametrize("use_dict", [True, False]) @pytest.mark.parametrize("infer", [True, False]) From f34b04c9fc7054f52aada24b8e4d8361a9c476aa Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 21 Oct 2022 13:35:50 +0300 Subject: [PATCH 6/8] DOC: Note that StataReaders are context managers --- doc/source/user_guide/io.rst | 8 ++++++++ pandas/io/stata.py | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/doc/source/user_guide/io.rst b/doc/source/user_guide/io.rst index 91cd3335d9db6..3c3a655626bb6 100644 --- a/doc/source/user_guide/io.rst +++ b/doc/source/user_guide/io.rst @@ -6033,6 +6033,14 @@ values will have ``object`` data type. ``int64`` for all integer types and ``float64`` for floating point data. By default, the Stata data types are preserved when importing. +.. note:: + + All :class:`~pandas.io.stata.StataReader` objects, whether created by :func:`~pandas.read_stata` + (when using ``iterator=True`` or ``chunksize``) or instantiated by hand, must be used as context + managers (e.g. the ``with`` statement). + While the :meth:`~pandas.io.stata.StataReader.close` method is available, its use is unsupported. + It is not part of the public API and will be removed in with future without warning. + .. 
ipython:: python
   :suppress:

diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 920ba93826c15..a7447f78a1330 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -183,10 +183,10 @@
 >>> df = pd.DataFrame(values, columns=["i"]) # doctest: +SKIP
 >>> df.to_stata('filename.dta') # doctest: +SKIP
->>> itr = pd.read_stata('filename.dta', chunksize=10000) # doctest: +SKIP
->>> for chunk in itr:
-...     # Operate on a single chunk, e.g., chunk.mean()
-...     pass # doctest: +SKIP
+>>> with pd.read_stata('filename.dta', chunksize=10000) as itr: # doctest: +SKIP
+...     for chunk in itr:
+...         # Operate on a single chunk, e.g., chunk.mean()
+...         pass # doctest: +SKIP
 """

 _read_method_doc = f"""\

From 2f8df0b2c6b949a5bcbb1a420c66e10699c1a34f Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 21 Oct 2022 15:37:02 +0300
Subject: [PATCH 7/8] FIX: StataReader: don't close stream implicitly

---
 pandas/io/stata.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index a7447f78a1330..1aa725dea6b92 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -1706,7 +1706,6 @@ def read(
         if (self._nobs == 0) and (nrows is None):
             self._can_read_value_labels = True
             self._data_read = True
-            self.close()
             return DataFrame(columns=self._varlist)

         # Handle options
@@ -1743,7 +1742,6 @@ def read(
             # we are reading the file incrementally
             if convert_categoricals:
                 self._read_value_labels()
-            self.close()
             raise StopIteration
         offset = self._lines_read * dtype.itemsize
         self._path_or_buf.seek(self._data_location + offset)
@@ -1776,11 +1774,7 @@ def read(
             data.index = Index(rng) # set attr instead of set_index to avoid copy

         if columns is not None:
-            try:
-                data = self._do_select_columns(data, columns)
-            except ValueError:
-                self.close()
-                raise
+            data = self._do_select_columns(data, columns)

         # Decode strings
         for col, typ in zip(data, self._typlist):
@@ -1819,13 +1813,9 @@ def any_startswith(x: str) -> bool:
         cols = 
np.where([any_startswith(x) for x in self._fmtlist])[0] for i in cols: col = data.columns[i] - try: - data[col] = _stata_elapsed_date_to_datetime_vec( - data[col], self._fmtlist[i] - ) - except ValueError: - self.close() - raise + data[col] = _stata_elapsed_date_to_datetime_vec( + data[col], self._fmtlist[i] + ) if convert_categoricals and self._format_version > 108: data = self._do_convert_categoricals( From b71a0bc746cdb0a05c16af7c56c31d8d417a762d Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 23 Feb 2023 16:16:32 +0200 Subject: [PATCH 8/8] Apply review changes --- doc/source/whatsnew/v2.0.0.rst | 1 + pandas/io/stata.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index f1206ddab71a2..a8d6f3fce5bb7 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -857,6 +857,7 @@ Deprecations - Deprecated :meth:`Series.backfill` in favor of :meth:`Series.bfill` (:issue:`33396`) - Deprecated :meth:`DataFrame.pad` in favor of :meth:`DataFrame.ffill` (:issue:`33396`) - Deprecated :meth:`DataFrame.backfill` in favor of :meth:`DataFrame.bfill` (:issue:`33396`) +- Deprecated :meth:`~pandas.io.stata.StataReader.close`. Use :class:`~pandas.io.stata.StataReader` as a context manager instead (:issue:`49228`) .. --------------------------------------------------------------------------- .. _whatsnew_200.prior_deprecations: diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 1aa725dea6b92..5cc13892224c5 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1233,7 +1233,7 @@ def close(self) -> None: "will be removed in a future version without notice. " "Using StataReader as a context manager is the only supported method.", FutureWarning, - stacklevel=2, + stacklevel=find_stack_level(), ) if self._close_file: self._close_file()