diff --git a/pandas-stubs/io/parsers/readers.pyi b/pandas-stubs/io/parsers/readers.pyi index cf2d7d3f1..01f99e18c 100644 --- a/pandas-stubs/io/parsers/readers.pyi +++ b/pandas-stubs/io/parsers/readers.pyi @@ -8,6 +8,7 @@ from typing import ( Any, Callable, Literal, + Mapping, Sequence, overload, ) @@ -48,23 +49,26 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -111,23 +115,26 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -174,23 +181,26 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -237,23 +247,26 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -300,23 +313,26 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -363,23 +379,26 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -443,7 +462,7 @@ def read_fwf( class TextFileReader(abc.Iterator): engine: CSVEngine - orig_options: dict[str, Any] + orig_options: Mapping[str, Any] chunksize: int | None nrows: int | None squeeze: bool diff --git a/tests/test_frame.py b/tests/test_frame.py index 9e79bef2f..74f3b17b5 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -19,6 +19,7 @@ List, Mapping, Tuple, + TypedDict, TypeVar, Union, ) @@ -1278,6 +1279,94 @@ def test_read_csv() -> None: pd.DataFrame, ) + # Allow a variety of dict types for the converters parameter + converters1 = {"A": str, "B": str} + check( + assert_type(pd.read_csv(path, converters=converters1), pd.DataFrame), + pd.DataFrame, + ) + converters2 = {"A": str, "B": float} + check( + assert_type(pd.read_csv(path, converters=converters2), pd.DataFrame), + pd.DataFrame, + ) + converters3 = {0: str, 1: str} + check( + assert_type(pd.read_csv(path, converters=converters3), pd.DataFrame), + pd.DataFrame, + ) + converters4 = {0: str, 1: float} + check( + assert_type(pd.read_csv(path, converters=converters4), pd.DataFrame), + pd.DataFrame, + ) + converters5: dict[int | str, Callable[[str], Any]] = { + 0: str, + "A": float, + } + check( + assert_type(pd.read_csv(path, converters=converters5), pd.DataFrame), + pd.DataFrame, + ) + + class ReadCsvKwargs(TypedDict): + converters: dict[int, Callable[[str], Any]] + + read_csv_kwargs: ReadCsvKwargs = {"converters": {0: int}} + + check( + assert_type(pd.read_csv(path, **read_csv_kwargs), pd.DataFrame), + pd.DataFrame, + ) + + # Check value covariance for various other parameters too (these only accept a str key) + na_values = {"A": ["1"], "B": ["1"]} + check( + assert_type(pd.read_csv(path, na_values=na_values), pd.DataFrame), + pd.DataFrame, + ) + + # There are several possible inputs for parse_dates + with ensure_clean() as path: + Path(path).write_text("Date,Year,Month,Day\n20221125,2022,11,25") + parse_dates_1 = ["Date"] + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_1), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + pd.read_csv(path, index_col="Date", parse_dates=True), pd.DataFrame + ), + pd.DataFrame, + ) + parse_dates_2 = {"combined_date": ["Year", "Month", "Day"]} + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_2), pd.DataFrame), + pd.DataFrame, + ) + parse_dates_3 = {"combined_date": [1, 2, 3]} + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_3), pd.DataFrame), + pd.DataFrame, + ) + # MyPy calls this Dict[str, object] by default which necessitates the explicit annotation (Pyright does not) + parse_dates_4: dict[str, list[str | int]] = {"combined_date": [1, "Month", 3]} + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_4), pd.DataFrame), + pd.DataFrame, + ) + parse_dates_5 = [2] + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_5), pd.DataFrame), + pd.DataFrame, + ) + parse_dates_6 = [[1, 2], [1, 2, 3]] + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_6), pd.DataFrame), + pd.DataFrame, + ) + def test_groupby_series_methods() -> None: df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]})