diff --git a/.gitignore b/.gitignore index 56828fa1d9331..e85da9c9b976b 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,9 @@ coverage_html_report # hypothesis test database .hypothesis/ __pycache__ +# pytest-monkeytype +monkeytype.sqlite3 + # OS generated files # ###################### diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index 11b69064723c5..c5b0b7b222e27 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -6,16 +6,33 @@ from functools import partial from io import StringIO from shutil import get_terminal_size -from typing import TYPE_CHECKING, List, Optional, TextIO, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + TextIO, + Tuple, + Type, + Union, + cast, +) from unicodedata import east_asian_width +from dateutil.tz.tz import tzutc +from dateutil.zoneinfo import tzfile import numpy as np +from numpy import float64, int32, ndarray from pandas._config.config import get_option, set_option from pandas._libs import lib from pandas._libs.tslib import format_array_from_datetime from pandas._libs.tslibs import NaT, Timedelta, Timestamp, iNaT +from pandas._libs.tslibs.nattype import NaTType from pandas.core.dtypes.common import ( is_categorical_dtype, @@ -40,10 +57,14 @@ ) from pandas.core.dtypes.missing import isna, notna +from pandas._typing import FilePathOrBuffer +from pandas.core.arrays.datetimes import DatetimeArray +from pandas.core.arrays.timedeltas import TimedeltaArray from pandas.core.base import PandasObject import pandas.core.common as com from pandas.core.index import Index, ensure_index from pandas.core.indexes.datetimes import DatetimeIndex +from pandas.core.indexes.timedeltas import TimedeltaIndex from pandas.io.common import _expand_user, _stringify_path from pandas.io.formats.printing import adjoin, justify, pprint_thing @@ -51,6 +72,11 @@ if TYPE_CHECKING: from pandas import Series, DataFrame, Categorical +formatters_type = Union[ + List[Callable], Tuple[Callable, ...], Dict[Union[str, int], Callable] +] +float_format_type = Union[str, Callable, "EngFormatter"] + common_docstring = """ Parameters ---------- @@ -66,11 +92,11 @@ Whether to print index (row) labels. na_rep : str, optional, default 'NaN' String representation of NAN to use. - formatters : list or dict of one-param. functions, optional + formatters : list, tuple or dict of one-param. functions, optional Formatter functions to apply to columns' elements by position or name. The result of each function must be a unicode string. - List must be of length equal to the number of columns. + List/tuple must be of length equal to the number of columns. float_format : one-parameter function, optional, default None Formatter function to apply to columns' elements if they are floats. The result of this function must be a unicode string. @@ -354,13 +380,13 @@ class TextAdjustment: def __init__(self): self.encoding = get_option("display.encoding") - def len(self, text): + def len(self, text: str) -> int: return len(text) - def justify(self, texts, max_len, mode="right"): + def justify(self, texts: Any, max_len: int, mode: str = "right") -> List[str]: return justify(texts, max_len, mode=mode) - def adjoin(self, space, *lists, **kwargs): + def adjoin(self, space: int, *lists, **kwargs) -> str: return adjoin(space, *lists, strlen=self.len, justfunc=self.justify, **kwargs) @@ -377,7 +403,7 @@ def __init__(self): # Ambiguous width can be changed by option self._EAW_MAP = {"Na": 1, "N": 1, "W": 2, "F": 2, "H": 1} - def len(self, text): + def len(self, text: str) -> int: """ Calculate display width considering unicode East Asian Width """ @@ -388,7 +414,9 @@ def len(self, text): self._EAW_MAP.get(east_asian_width(c), self.ambiguous_width) for c in text ) - def justify(self, texts, max_len, mode="right"): + def justify( + self, texts: Iterable[str], max_len: int, mode: str = "right" + ) -> List[str]: # re-calculate padding space per str considering East Asian Width def _get_pad(t): return max_len - self.len(t) + len(t) @@ -401,7 +429,7 @@ def _get_pad(t): return [x.rjust(_get_pad(x)) for x in texts] -def _get_adjustment(): +def _get_adjustment() -> TextAdjustment: use_east_asian_width = get_option("display.unicode.east_asian_width") if use_east_asian_width: return EastAsianTextAdjustment() @@ -411,17 +439,21 @@ def _get_adjustment(): class TableFormatter: - show_dimensions = None + show_dimensions = None # type: bool + is_truncated = None # type: bool + formatters = None # type: formatters_type + columns = None # type: Index @property - def should_show_dimensions(self): + def should_show_dimensions(self) -> Optional[bool]: return self.show_dimensions is True or ( self.show_dimensions == "truncate" and self.is_truncated ) - def _get_formatter(self, i): + def _get_formatter(self, i: Union[str, int]) -> Optional[Callable]: if isinstance(self.formatters, (list, tuple)): if is_integer(i): + i = cast(int, i) return self.formatters[i] else: return None @@ -446,26 +478,26 @@ class DataFrameFormatter(TableFormatter): def __init__( self, - frame, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - justify=None, - float_format=None, - sparsify=None, - index_names=True, - line_width=None, - max_rows=None, - min_rows=None, - max_cols=None, - show_dimensions=False, - decimal=".", - table_id=None, - render_links=False, + frame: "DataFrame", + buf: Optional[FilePathOrBuffer] = None, + columns: Optional[List[str]] = None, + col_space: Optional[Union[str, int]] = None, + header: Union[bool, List[str]] = True, + index: bool = True, + na_rep: str = "NaN", + formatters: Optional[formatters_type] = None, + justify: Optional[str] = None, + float_format: Optional[float_format_type] = None, + sparsify: Optional[bool] = None, + index_names: bool = True, + line_width: Optional[int] = None, + max_rows: Optional[int] = None, + min_rows: Optional[int] = None, + max_cols: Optional[int] = None, + show_dimensions: bool = False, + decimal: str = ".", + table_id: Optional[str] = None, + render_links: bool = False, **kwds ): self.frame = frame @@ -532,9 +564,12 @@ def _chk_truncate(self) -> None: prompt_row = 1 if self.show_dimensions: show_dimension_rows = 3 + # assume we only get here if self.header is boolean. + # i.e. not to_latex() where self.header may be List[str] + self.header = cast(bool, self.header) n_add_rows = self.header + dot_row + show_dimension_rows + prompt_row # rows available to fill with actual data - max_rows_adj = self.h - n_add_rows + max_rows_adj = self.h - n_add_rows # type: Optional[int] self.max_rows_adj = max_rows_adj # Format only rows and columns that could potentially fit the @@ -561,9 +596,12 @@ def _chk_truncate(self) -> None: frame = self.frame if truncate_h: + # cast here since if truncate_h is True, max_cols_adj is not None + max_cols_adj = cast(int, max_cols_adj) if max_cols_adj == 0: col_num = len(frame.columns) elif max_cols_adj == 1: + max_cols = cast(int, max_cols) frame = frame.iloc[:, :max_cols] col_num = max_cols else: @@ -573,6 +611,8 @@ def _chk_truncate(self) -> None: ) self.tr_col_num = col_num if truncate_v: + # cast here since if truncate_v is True, max_rows_adj is not None + max_rows_adj = cast(int, max_rows_adj) if max_rows_adj == 1: row_num = max_rows frame = frame.iloc[:max_rows, :] @@ -586,12 +626,16 @@ def _chk_truncate(self) -> None: self.tr_frame = frame self.truncate_h = truncate_h self.truncate_v = truncate_v - self.is_truncated = self.truncate_h or self.truncate_v + self.is_truncated = bool(self.truncate_h or self.truncate_v) def _to_str_columns(self) -> List[List[str]]: """ Render a DataFrame to a list of columns (as lists of strings). """ + # this method is not used by to_html where self.col_space + # could be a string so safe to cast + self.col_space = cast(int, self.col_space) + frame = self.tr_frame # may include levels names also @@ -610,6 +654,8 @@ def _to_str_columns(self) -> List[List[str]]: stringified.append(fmt_values) else: if is_list_like(self.header): + # cast here since can't be bool if is_list_like + self.header = cast(List[str], self.header) if len(self.header) != len(self.columns): raise ValueError( ( @@ -656,6 +702,8 @@ def _to_str_columns(self) -> List[List[str]]: if truncate_v: n_header_rows = len(str_index) - len(frame) row_num = self.tr_row_num + # cast here since if truncate_v is True, self.tr_row_num is not None + row_num = cast(int, row_num) for ix, col in enumerate(strcols): # infer from above row cwidth = self.adj.len(strcols[ix][row_num]) @@ -704,8 +752,8 @@ def to_string(self) -> None: ): # need to wrap around text = self._join_multiline(*strcols) else: # max_cols == 0. Try to fit frame to terminal - text = self.adj.adjoin(1, *strcols).split("\n") - max_len = Series(text).str.len().max() + lines = self.adj.adjoin(1, *strcols).split("\n") + max_len = Series(lines).str.len().max() # plus truncate dot col dif = max_len - self.w # '+ 1' to avoid too wide repr (GH PR #17023) @@ -742,10 +790,10 @@ def to_string(self) -> None: ) ) - def _join_multiline(self, *strcols): + def _join_multiline(self, *args) -> str: lwidth = self.line_width adjoin_width = 1 - strcols = list(strcols) + strcols = list(args) if self.index: idx = strcols.pop(0) lwidth -= np.array([self.adj.len(x) for x in idx]).max() + adjoin_width @@ -758,6 +806,8 @@ def _join_multiline(self, *strcols): nbins = len(col_bins) if self.truncate_v: + # cast here since if truncate_v is True, max_rows_adj is not None + self.max_rows_adj = cast(int, self.max_rows_adj) nrows = self.max_rows_adj + 1 else: nrows = len(self.frame) @@ -779,13 +829,13 @@ def _join_multiline(self, *strcols): def to_latex( self, - column_format=None, - longtable=False, - encoding=None, - multicolumn=False, - multicolumn_format=None, - multirow=False, - ): + column_format: Optional[str] = None, + longtable: bool = False, + encoding: Optional[str] = None, + multicolumn: bool = False, + multicolumn_format: Optional[str] = None, + multirow: bool = False, + ) -> None: """ Render a DataFrame to a LaTeX tabular/longtable environment output. """ @@ -920,7 +970,8 @@ def show_col_idx_names(self) -> bool: def _get_formatted_index(self, frame: "DataFrame") -> List[str]: # Note: this is only used by to_string() and to_latex(), not by - # to_html(). + # to_html(). so safe to cast col_space here. + self.col_space = cast(int, self.col_space) index = frame.index columns = frame.columns fmt = self._get_formatter("__index__") @@ -972,16 +1023,16 @@ def _get_column_name_list(self) -> List[str]: def format_array( - values, - formatter, - float_format=None, - na_rep="NaN", - digits=None, - space=None, - justify="right", - decimal=".", - leading_space=None, -): + values: Any, + formatter: Optional[Callable], + float_format: Optional[float_format_type] = None, + na_rep: str = "NaN", + digits: Optional[int] = None, + space: Optional[Union[str, int]] = None, + justify: str = "right", + decimal: str = ".", + leading_space: Optional[bool] = None, +) -> List[str]: """ Format an array for printing. @@ -1010,7 +1061,7 @@ def format_array( """ if is_datetime64_dtype(values.dtype): - fmt_klass = Datetime64Formatter + fmt_klass = Datetime64Formatter # type: Type[GenericArrayFormatter] elif is_datetime64tz_dtype(values): fmt_klass = Datetime64TZFormatter elif is_timedelta64_dtype(values.dtype): @@ -1051,17 +1102,17 @@ def format_array( class GenericArrayFormatter: def __init__( self, - values, - digits=7, - formatter=None, - na_rep="NaN", - space=12, - float_format=None, - justify="right", - decimal=".", - quoting=None, - fixed_width=True, - leading_space=None, + values: Any, + digits: int = 7, + formatter: Optional[Callable] = None, + na_rep: str = "NaN", + space: Union[str, int] = 12, + float_format: Optional[float_format_type] = None, + justify: str = "right", + decimal: str = ".", + quoting: Optional[int] = None, + fixed_width: bool = True, + leading_space: Optional[bool] = None, ): self.values = values self.digits = digits @@ -1075,11 +1126,11 @@ def __init__( self.fixed_width = fixed_width self.leading_space = leading_space - def get_result(self): + def get_result(self) -> Union[ndarray, List[str]]: fmt_values = self._format_strings() return _make_fixed_width(fmt_values, self.justify) - def _format_strings(self): + def _format_strings(self) -> List[str]: if self.float_format is None: float_format = get_option("display.float_format") if float_format is None: @@ -1161,7 +1212,11 @@ def __init__(self, *args, **kwargs): self.formatter = self.float_format self.float_format = None - def _value_formatter(self, float_format=None, threshold=None): + def _value_formatter( + self, + float_format: Optional[float_format_type] = None, + threshold: Optional[Union[float, int]] = None, + ) -> Callable: """Returns a function to be applied on each value to format it """ @@ -1207,7 +1262,7 @@ def formatter(value): return formatter - def get_result_as_array(self): + def get_result_as_array(self) -> Union[ndarray, List[str]]: """ Returns the float values converted into strings using the parameters given at initialisation, as a numpy array @@ -1259,7 +1314,7 @@ def format_values_with(float_format): if self.fixed_width: float_format = partial( "{value: .{digits:d}f}".format, digits=self.digits - ) + ) # type: Optional[float_format_type] else: float_format = self.float_format else: @@ -1296,7 +1351,7 @@ def format_values_with(float_format): return formatted_values - def _format_strings(self): + def _format_strings(self) -> List[str]: # shortcut if self.formatter is not None: return [self.formatter(x) for x in self.values] @@ -1305,19 +1360,25 @@ def _format_strings(self): class IntArrayFormatter(GenericArrayFormatter): - def _format_strings(self): + def _format_strings(self) -> List[str]: formatter = self.formatter or (lambda x: "{x: d}".format(x=x)) fmt_values = [formatter(x) for x in self.values] return fmt_values class Datetime64Formatter(GenericArrayFormatter): - def __init__(self, values, nat_rep="NaT", date_format=None, **kwargs): + def __init__( + self, + values: Union[ndarray, "Series", DatetimeIndex, DatetimeArray], + nat_rep: str = "NaT", + date_format: None = None, + **kwargs + ): super().__init__(values, **kwargs) self.nat_rep = nat_rep self.date_format = date_format - def _format_strings(self): + def _format_strings(self) -> List[str]: """ we by definition have DO NOT have a TZ """ values = self.values @@ -1337,7 +1398,7 @@ def _format_strings(self): class ExtensionArrayFormatter(GenericArrayFormatter): - def _format_strings(self): + def _format_strings(self) -> List[str]: values = self.values if isinstance(values, (ABCIndexClass, ABCSeries)): values = values._values @@ -1363,7 +1424,11 @@ def _format_strings(self): return fmt_values -def format_percentiles(percentiles): +def format_percentiles( + percentiles: Union[ + ndarray, List[Union[int, float]], List[float], List[Union[str, float]] + ] +) -> List[str]: """ Outputs rounded and formatted percentiles. @@ -1429,7 +1494,7 @@ def format_percentiles(percentiles): return [i + "%" for i in out] -def _is_dates_only(values): +def _is_dates_only(values: Union[ndarray, DatetimeArray, Index, DatetimeIndex]) -> bool: # return a boolean if we are only dates (and don't have a timezone) assert values.ndim == 1 @@ -1448,7 +1513,11 @@ def _is_dates_only(values): return False -def _format_datetime64(x, tz=None, nat_rep="NaT"): +def _format_datetime64( + x: Union[NaTType, Timestamp], + tz: Optional[Union[tzfile, tzutc]] = None, + nat_rep: str = "NaT", +) -> str: if x is None or (is_scalar(x) and isna(x)): return nat_rep @@ -1461,7 +1530,9 @@ def _format_datetime64(x, tz=None, nat_rep="NaT"): return str(x) -def _format_datetime64_dateonly(x, nat_rep="NaT", date_format=None): +def _format_datetime64_dateonly( + x: Union[NaTType, Timestamp], nat_rep: str = "NaT", date_format: None = None +) -> str: if x is None or (is_scalar(x) and isna(x)): return nat_rep @@ -1474,7 +1545,9 @@ def _format_datetime64_dateonly(x, nat_rep="NaT", date_format=None): return x._date_repr -def _get_format_datetime64(is_dates_only, nat_rep="NaT", date_format=None): +def _get_format_datetime64( + is_dates_only: bool, nat_rep: str = "NaT", date_format: None = None +) -> Callable: if is_dates_only: return lambda x, tz=None: _format_datetime64_dateonly( @@ -1484,7 +1557,9 @@ def _get_format_datetime64(is_dates_only, nat_rep="NaT", date_format=None): return lambda x, tz=None: _format_datetime64(x, tz=tz, nat_rep=nat_rep) -def _get_format_datetime64_from_values(values, date_format): +def _get_format_datetime64_from_values( + values: Union[ndarray, DatetimeArray, DatetimeIndex], date_format: Optional[str] +) -> Optional[str]: """ given values and a date_format, return a string format """ if isinstance(values, np.ndarray) and values.ndim > 1: @@ -1499,7 +1574,7 @@ def _get_format_datetime64_from_values(values, date_format): class Datetime64TZFormatter(Datetime64Formatter): - def _format_strings(self): + def _format_strings(self) -> List[str]: """ we by definition have a TZ """ values = self.values.astype(object) @@ -1513,12 +1588,18 @@ def _format_strings(self): class Timedelta64Formatter(GenericArrayFormatter): - def __init__(self, values, nat_rep="NaT", box=False, **kwargs): + def __init__( + self, + values: Union[ndarray, TimedeltaIndex], + nat_rep: str = "NaT", + box: bool = False, + **kwargs + ): super().__init__(values, **kwargs) self.nat_rep = nat_rep self.box = box - def _format_strings(self): + def _format_strings(self) -> ndarray: formatter = self.formatter or _get_format_timedelta64( self.values, nat_rep=self.nat_rep, box=self.box ) @@ -1526,7 +1607,11 @@ def _format_strings(self): return fmt_values -def _get_format_timedelta64(values, nat_rep="NaT", box=False): +def _get_format_timedelta64( + values: Union[ndarray, TimedeltaIndex, TimedeltaArray], + nat_rep: str = "NaT", + box: bool = False, +) -> Callable: """ Return a formatter function for a range of timedeltas. These will all have the same format argument @@ -1567,7 +1652,12 @@ def _formatter(x): return _formatter -def _make_fixed_width(strings, justify="right", minimum=None, adj=None): +def _make_fixed_width( + strings: Union[ndarray, List[str]], + justify: str = "right", + minimum: Optional[int] = None, + adj: Optional[TextAdjustment] = None, +) -> Union[ndarray, List[str]]: if len(strings) == 0 or justify == "all": return strings @@ -1595,7 +1685,7 @@ def just(x): return result -def _trim_zeros_complex(str_complexes, na_rep="NaN"): +def _trim_zeros_complex(str_complexes: ndarray, na_rep: str = "NaN") -> List[str]: """ Separates the real and imaginary parts from the complex number, and executes the _trim_zeros_float method on each of those. @@ -1613,7 +1703,9 @@ def separate_and_trim(str_complex, na_rep): return ["".join(separate_and_trim(x, na_rep)) for x in str_complexes] -def _trim_zeros_float(str_floats, na_rep="NaN"): +def _trim_zeros_float( + str_floats: Union[ndarray, List[str]], na_rep: str = "NaN" +) -> List[str]: """ Trims zeros, leaving just one before the decimal points if need be. """ @@ -1637,7 +1729,7 @@ def _cond(values): return [x + "0" if x.endswith(".") and _is_number(x) else x for x in trimmed] -def _has_names(index): +def _has_names(index: Index) -> bool: if isinstance(index, ABCMultiIndex): return com._any_not_none(*index.names) else: @@ -1672,11 +1764,11 @@ class EngFormatter: 24: "Y", } - def __init__(self, accuracy=None, use_eng_prefix=False): + def __init__(self, accuracy: Optional[int] = None, use_eng_prefix: bool = False): self.accuracy = accuracy self.use_eng_prefix = use_eng_prefix - def __call__(self, num): + def __call__(self, num: Union[float64, int, float]) -> str: """ Formats a number in engineering notation, appending a letter representing the power of 1000 of the original number. Some examples: @@ -1743,7 +1835,7 @@ def __call__(self, num): return formatted -def set_eng_float_format(accuracy=3, use_eng_prefix=False): +def set_eng_float_format(accuracy: int = 3, use_eng_prefix: bool = False) -> None: """ Alter default behavior on how float is formatted in DataFrame. Format float in engineering format. By accuracy, we mean the number of @@ -1756,7 +1848,7 @@ def set_eng_float_format(accuracy=3, use_eng_prefix=False): set_option("display.column_space", max(12, accuracy + 9)) -def _binify(cols, line_width): +def _binify(cols: List[int32], line_width: Union[int32, int]) -> List[int]: adjoin_width = 1 bins = [] curr_width = 0 @@ -1776,7 +1868,9 @@ def _binify(cols, line_width): return bins -def get_level_lengths(levels, sentinel=""): +def get_level_lengths( + levels: Any, sentinel: Union[bool, object, str] = "" +) -> List[Dict[int, int]]: """For each index in each level the function returns lengths of indexes. Parameters @@ -1816,7 +1910,7 @@ def get_level_lengths(levels, sentinel=""): return result -def buffer_put_lines(buf, lines): +def buffer_put_lines(buf: TextIO, lines: List[str]) -> None: """ Appends lines to a buffer. diff --git a/pandas/io/formats/html.py b/pandas/io/formats/html.py index 91e90a78d87a7..19305126f4e5f 100644 --- a/pandas/io/formats/html.py +++ b/pandas/io/formats/html.py @@ -4,7 +4,7 @@ from collections import OrderedDict from textwrap import dedent -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast from pandas._config import get_option @@ -82,8 +82,9 @@ def row_levels(self) -> int: def _get_columns_formatted_values(self) -> Iterable: return self.columns + # https://github.com/python/mypy/issues/1237 @property - def is_truncated(self) -> bool: + def is_truncated(self) -> bool: # type: ignore return self.fmt.is_truncated @property @@ -458,6 +459,8 @@ def _write_hierarchical_rows( # Insert ... row and adjust idx_values and # level_lengths to take this into account. ins_row = self.fmt.tr_row_num + # cast here since if truncate_v is True, self.fmt.tr_row_num is not None + ins_row = cast(int, ins_row) inserted = False for lnum, records in enumerate(level_lengths): rec_new = {}