From 5a2314d58e60e2f4a6d06e909a779bfca6ff75c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Sun, 26 Jun 2022 12:25:12 -0400 Subject: [PATCH] TYP: some return annotations in pytables.py --- pandas/io/pytables.py | 164 ++++++++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 71 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index c20ce0c847b61..b96fa4a57f188 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -19,9 +19,11 @@ Any, Callable, Hashable, + Iterator, Literal, Sequence, cast, + overload, ) import warnings @@ -592,7 +594,7 @@ def __init__( self._filters = None self.open(mode=mode, **kwargs) - def __fspath__(self): + def __fspath__(self) -> str: return self._path @property @@ -603,16 +605,16 @@ def root(self): return self._handle.root @property - def filename(self): + def filename(self) -> str: return self._path def __getitem__(self, key: str): return self.get(key) - def __setitem__(self, key: str, value): + def __setitem__(self, key: str, value) -> None: self.put(key, value) - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: return self.remove(key) def __getattr__(self, name: str): @@ -644,10 +646,10 @@ def __repr__(self) -> str: pstr = pprint_thing(self._path) return f"{type(self)}\nFile path: {pstr}\n" - def __enter__(self): + def __enter__(self) -> HDFStore: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: self.close() def keys(self, include: str = "pandas") -> list[str]: @@ -684,7 +686,7 @@ def keys(self, include: str = "pandas") -> list[str]: f"`include` should be either 'pandas' or 'native' but is '{include}'" ) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self.keys()) def items(self): @@ -706,7 +708,7 @@ def iteritems(self): ) yield from self.items() - def open(self, mode: str = "a", **kwargs): + def open(self, mode: str = "a", **kwargs) -> None: """ Open the file in the specified mode @@ -751,7 +753,7 @@ def open(self, mode: str = "a", **kwargs): self._handle = tables.open_file(self._path, self._mode, **kwargs) - def close(self): + def close(self) -> None: """ Close the PyTables file handle """ @@ -768,7 +770,7 @@ def is_open(self) -> bool: return False return bool(self._handle.isopen) - def flush(self, fsync: bool = False): + def flush(self, fsync: bool = False) -> None: """ Force all buffered modifications to be written to disk. @@ -1096,7 +1098,7 @@ def put( errors: str = "strict", track_times: bool = True, dropna: bool = False, - ): + ) -> None: """ Store object in HDFStore. @@ -1152,7 +1154,7 @@ def put( dropna=dropna, ) - def remove(self, key: str, where=None, start=None, stop=None): + def remove(self, key: str, where=None, start=None, stop=None) -> None: """ Remove pandas object partially by specifying the where condition @@ -1228,7 +1230,7 @@ def append( data_columns: Literal[True] | list[str] | None = None, encoding=None, errors: str = "strict", - ): + ) -> None: """ Append to Table in file. Node must already exist and be Table format. @@ -1305,7 +1307,7 @@ def append_to_multiple( axes=None, dropna=False, **kwargs, - ): + ) -> None: """ Append to multiple tables @@ -1399,7 +1401,7 @@ def create_table_index( columns=None, optlevel: int | None = None, kind: str | None = None, - ): + ) -> None: """ Create a pytables index on the table. @@ -1545,7 +1547,7 @@ def copy( complevel: int | None = None, fletcher32: bool = False, overwrite=True, - ): + ) -> HDFStore: """ Copy the existing store to a new file, updating in place. @@ -1933,7 +1935,7 @@ def __iter__(self): self.close() - def close(self): + def close(self) -> None: if self.auto_close: self.store.close() @@ -2037,7 +2039,7 @@ def itemsize(self) -> int: def kind_attr(self) -> str: return f"{self.name}_kind" - def set_pos(self, pos: int): + def set_pos(self, pos: int) -> None: """set the position of this column in the Table""" self.pos = pos if pos is not None and self.typ is not None: @@ -2072,7 +2074,9 @@ def is_indexed(self) -> bool: return False return getattr(self.table.cols, self.cname).is_indexed - def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): + def convert( + self, values: np.ndarray, nan_rep, encoding: str, errors: str + ) -> tuple[np.ndarray, np.ndarray] | tuple[DatetimeIndex, DatetimeIndex]: """ Convert the data from this selection to the appropriate pandas type. """ @@ -2140,7 +2144,7 @@ def cvalues(self): def __iter__(self): return iter(self.values) - def maybe_set_size(self, min_itemsize=None): + def maybe_set_size(self, min_itemsize=None) -> None: """ maybe set a string col itemsize: min_itemsize can be an integer or a dict with this columns name @@ -2153,10 +2157,10 @@ def maybe_set_size(self, min_itemsize=None): if min_itemsize is not None and self.typ.itemsize < min_itemsize: self.typ = _tables().StringCol(itemsize=min_itemsize, pos=self.pos) - def validate_names(self): + def validate_names(self) -> None: pass - def validate_and_set(self, handler: AppendableTable, append: bool): + def validate_and_set(self, handler: AppendableTable, append: bool) -> None: self.table = handler.table self.validate_col() self.validate_attr(append) @@ -2183,7 +2187,7 @@ def validate_col(self, itemsize=None): return None - def validate_attr(self, append: bool): + def validate_attr(self, append: bool) -> None: # check for backwards incompatibility if append: existing_kind = getattr(self.attrs, self.kind_attr, None) @@ -2192,7 +2196,7 @@ def validate_attr(self, append: bool): f"incompatible kind in col [{existing_kind} - {self.kind}]" ) - def update_info(self, info): + def update_info(self, info) -> None: """ set/update the info for this indexable with the key/value if there is a conflict raise/warn as needed @@ -2225,17 +2229,17 @@ def update_info(self, info): if value is not None or existing_value is not None: idx[key] = value - def set_info(self, info): + def set_info(self, info) -> None: """set my state from the passed info""" idx = info.get(self.name) if idx is not None: self.__dict__.update(idx) - def set_attr(self): + def set_attr(self) -> None: """set the kind for this column""" setattr(self.attrs, self.kind_attr, self.kind) - def validate_metadata(self, handler: AppendableTable): + def validate_metadata(self, handler: AppendableTable) -> None: """validate that kind=category does not change the categories""" if self.meta == "category": new_metadata = self.metadata @@ -2250,7 +2254,7 @@ def validate_metadata(self, handler: AppendableTable): "different categories to the existing" ) - def write_metadata(self, handler: AppendableTable): + def write_metadata(self, handler: AppendableTable) -> None: """set the meta data""" if self.metadata is not None: handler.write_metadata(self.cname, self.metadata) @@ -2263,7 +2267,13 @@ class GenericIndexCol(IndexCol): def is_indexed(self) -> bool: return False - def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): + # error: Return type "Tuple[Int64Index, Int64Index]" of "convert" + # incompatible with return type "Union[Tuple[ndarray[Any, Any], + # ndarray[Any, Any]], Tuple[DatetimeIndex, DatetimeIndex]]" in + # supertype "IndexCol" + def convert( # type: ignore[override] + self, values: np.ndarray, nan_rep, encoding: str, errors: str + ) -> tuple[Int64Index, Int64Index]: """ Convert the data from this selection to the appropriate pandas type. @@ -2276,12 +2286,10 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): """ assert isinstance(values, np.ndarray), type(values) - # error: Incompatible types in assignment (expression has type - # "Int64Index", variable has type "ndarray") - values = Int64Index(np.arange(len(values))) # type: ignore[assignment] - return values, values + index = Int64Index(np.arange(len(values))) + return index, index - def set_attr(self): + def set_attr(self) -> None: pass @@ -2362,7 +2370,7 @@ def __eq__(self, other: Any) -> bool: for a in ["name", "cname", "dtype", "pos"] ) - def set_data(self, data: ArrayLike): + def set_data(self, data: ArrayLike) -> None: assert data is not None assert self.dtype is None @@ -2448,7 +2456,7 @@ def cvalues(self): """return my cython values""" return self.data - def validate_attr(self, append): + def validate_attr(self, append) -> None: """validate that we have the same order as the existing & same dtype""" if append: existing_fields = getattr(self.attrs, self.kind_attr, None) @@ -2562,7 +2570,7 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): return self.values, converted - def set_attr(self): + def set_attr(self) -> None: """set the data for this column""" setattr(self.attrs, self.kind_attr, self.values) setattr(self.attrs, self.meta_attr, self.meta) @@ -2575,7 +2583,7 @@ class DataIndexableCol(DataCol): is_data_indexable = True - def validate_names(self): + def validate_names(self) -> None: if not Index(self.values).is_object(): # TODO: should the message here be more specifically non-str? raise ValueError("cannot have non-object label DataIndexableCol") @@ -2672,12 +2680,12 @@ def __repr__(self) -> str: return f"{self.pandas_type:12.12} (shape->{s})" return self.pandas_type - def set_object_info(self): + def set_object_info(self) -> None: """set my pandas type & version""" self.attrs.pandas_type = str(self.pandas_kind) self.attrs.pandas_version = str(_version) - def copy(self): + def copy(self) -> Fixed: new_self = copy.copy(self) return new_self @@ -2709,11 +2717,11 @@ def _fletcher32(self) -> bool: def attrs(self): return self.group._v_attrs - def set_attrs(self): + def set_attrs(self) -> None: """set our object attributes""" pass - def get_attrs(self): + def get_attrs(self) -> None: """get our object attributes""" pass @@ -2730,17 +2738,17 @@ def is_exists(self) -> bool: def nrows(self): return getattr(self.storable, "nrows", None) - def validate(self, other): + def validate(self, other) -> Literal[True] | None: """validate against an existing storable""" if other is None: - return + return None return True - def validate_version(self, where=None): + def validate_version(self, where=None) -> None: """are we trying to operate on an old version?""" - return True + pass - def infer_axes(self): + def infer_axes(self) -> bool: """ infer the axes of my storer return a boolean indicating if we have a valid storer or not @@ -2767,7 +2775,9 @@ def write(self, **kwargs): "cannot write on an abstract storer: subclasses should implement" ) - def delete(self, where=None, start: int | None = None, stop: int | None = None): + def delete( + self, where=None, start: int | None = None, stop: int | None = None + ) -> None: """ support fully deleting the node in its entirety (only) - where specification must be None @@ -2842,7 +2852,7 @@ def f(values, freq=None, tz=None): return factory, kwargs - def validate_read(self, columns, where): + def validate_read(self, columns, where) -> None: """ raise if any keywords are passed which are not-None """ @@ -2861,12 +2871,12 @@ def validate_read(self, columns, where): def is_exists(self) -> bool: return True - def set_attrs(self): + def set_attrs(self) -> None: """set our object attributes""" self.attrs.encoding = self.encoding self.attrs.errors = self.errors - def get_attrs(self): + def get_attrs(self) -> None: """retrieve our attributes""" self.encoding = _ensure_encoding(getattr(self.attrs, "encoding", None)) self.errors = _ensure_decoded(getattr(self.attrs, "errors", "strict")) @@ -2924,7 +2934,7 @@ def read_index( else: # pragma: no cover raise TypeError(f"unrecognized index variety: {variety}") - def write_index(self, key: str, index: Index): + def write_index(self, key: str, index: Index) -> None: if isinstance(index, MultiIndex): setattr(self.attrs, f"{key}_variety", "multi") self.write_multi_index(key, index) @@ -2947,7 +2957,7 @@ def write_index(self, key: str, index: Index): if isinstance(index, DatetimeIndex) and index.tz is not None: node._v_attrs.tz = _get_tz(index.tz) - def write_multi_index(self, key: str, index: MultiIndex): + def write_multi_index(self, key: str, index: MultiIndex) -> None: setattr(self.attrs, f"{key}_nlevels", index.nlevels) for i, (lev, level_codes, name) in enumerate( @@ -3033,7 +3043,7 @@ def read_index_node( return index - def write_array_empty(self, key: str, value: ArrayLike): + def write_array_empty(self, key: str, value: ArrayLike) -> None: """write a 0-len array""" # ugly hack for length 0 axes arr = np.empty((1,) * value.ndim) @@ -3152,7 +3162,7 @@ def read( columns=None, start: int | None = None, stop: int | None = None, - ): + ) -> Series: self.validate_read(columns, where) index = self.read_index("index", start=start, stop=stop) values = self.read_array("values", start=start, stop=stop) @@ -3203,7 +3213,7 @@ def read( columns=None, start: int | None = None, stop: int | None = None, - ): + ) -> DataFrame: # start, stop applied to rows, so 0th axis only self.validate_read(columns, where) select_axis = self.obj_type()._get_block_manager_axis(0) @@ -3352,7 +3362,7 @@ def __getitem__(self, c: str): return a return None - def validate(self, other): + def validate(self, other) -> None: """validate against an existing table""" if other is None: return @@ -3449,7 +3459,7 @@ def is_transposed(self) -> bool: return False @property - def data_orientation(self): + def data_orientation(self) -> tuple[int, ...]: """return a tuple of my permutated axes, non_indexable at the front""" return tuple( itertools.chain( @@ -3488,7 +3498,7 @@ def _get_metadata_path(self, key: str) -> str: group = self.group._v_pathname return f"{group}/meta/{key}/meta" - def write_metadata(self, key: str, values: np.ndarray): + def write_metadata(self, key: str, values: np.ndarray) -> None: """ Write out a metadata array to the key as a fixed-format Series. @@ -3512,7 +3522,7 @@ def read_metadata(self, key: str): return self.parent.select(self._get_metadata_path(key)) return None - def set_attrs(self): + def set_attrs(self) -> None: """set our table type & indexables""" self.attrs.table_type = str(self.table_type) self.attrs.index_cols = self.index_cols() @@ -3525,7 +3535,7 @@ def set_attrs(self): self.attrs.levels = self.levels self.attrs.info = self.info - def get_attrs(self): + def get_attrs(self) -> None: """retrieve our attributes""" self.non_index_axes = getattr(self.attrs, "non_index_axes", None) or [] self.data_columns = getattr(self.attrs, "data_columns", None) or [] @@ -3537,14 +3547,14 @@ def get_attrs(self): self.index_axes = [a for a in self.indexables if a.is_an_indexable] self.values_axes = [a for a in self.indexables if not a.is_an_indexable] - def validate_version(self, where=None): + def validate_version(self, where=None) -> None: """are we trying to operate on an old version?""" if where is not None: if self.version[0] <= 0 and self.version[1] <= 10 and self.version[2] < 1: ws = incompatibility_doc % ".".join([str(x) for x in self.version]) warnings.warn(ws, IncompatibilityWarning) - def validate_min_itemsize(self, min_itemsize): + def validate_min_itemsize(self, min_itemsize) -> None: """ validate the min_itemsize doesn't contain items that are not in the axes this needs data_columns to be defined @@ -3642,7 +3652,9 @@ def f(i, c): return _indexables - def create_index(self, columns=None, optlevel=None, kind: str | None = None): + def create_index( + self, columns=None, optlevel=None, kind: str | None = None + ) -> None: """ Create a pytables index on the specified columns. @@ -4100,7 +4112,7 @@ def get_blk_items(mgr): return blocks, blk_items - def process_axes(self, obj, selection: Selection, columns=None): + def process_axes(self, obj, selection: Selection, columns=None) -> DataFrame: """process axes filters""" # make a copy to avoid side effects if columns is not None: @@ -4354,7 +4366,7 @@ def write( # add the rows table.write_data(chunksize, dropna=dropna) - def write_data(self, chunksize: int | None, dropna: bool = False): + def write_data(self, chunksize: int | None, dropna: bool = False) -> None: """ we form the data into a 2-d including indexes,values,mask write chunk-by-chunk """ @@ -4419,7 +4431,7 @@ def write_data_chunk( indexes: list[np.ndarray], mask: npt.NDArray[np.bool_] | None, values: list[np.ndarray], - ): + ) -> None: """ Parameters ---------- @@ -4701,7 +4713,7 @@ def pandas_type(self) -> str: def storable(self): return getattr(self.group, "table", None) or self.group - def get_attrs(self): + def get_attrs(self) -> None: """retrieve our attributes""" self.non_index_axes = [] self.nan_rep = None @@ -4823,10 +4835,20 @@ def _get_tz(tz: tzinfo) -> str | tzinfo: return zone +@overload +def _set_tz( + values: np.ndarray | Index, tz: str | tzinfo, coerce: bool = False +) -> DatetimeIndex: + ... + + +@overload +def _set_tz(values: np.ndarray | Index, tz: None, coerce: bool = False) -> np.ndarray: + ... + + def _set_tz( - values: np.ndarray | Index, - tz: str | tzinfo | None, - coerce: bool = False, + values: np.ndarray | Index, tz: str | tzinfo | None, coerce: bool = False ) -> np.ndarray | DatetimeIndex: """ coerce the values to a DatetimeIndex if tz is set