From 3f86e053e19301a019e93061bed41e35fe02eeb6 Mon Sep 17 00:00:00 2001 From: Siddhartha Gandhi Date: Fri, 25 Nov 2022 20:23:12 -0500 Subject: [PATCH 1/2] Permit covariance of key type in read_csv converters argument --- pandas-stubs/io/parsers/readers.pyi | 30 ++++++++++++++++----- tests/test_frame.py | 41 +++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/io/parsers/readers.pyi b/pandas-stubs/io/parsers/readers.pyi index cf2d7d3f1..e1daa2ea1 100644 --- a/pandas-stubs/io/parsers/readers.pyi +++ b/pandas-stubs/io/parsers/readers.pyi @@ -48,7 +48,10 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: dict[int | str, Callable[[str], Any]] + | dict[int, Callable[[str], Any]] + | dict[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., @@ -111,7 +114,10 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: dict[int | str, Callable[[str], Any]] + | dict[int, Callable[[str], Any]] + | dict[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., @@ -174,7 +180,10 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: dict[int | str, Callable[[str], Any]] + | dict[int, Callable[[str], Any]] + | dict[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., @@ -237,7 +246,10 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: dict[int | str, Callable[[str], Any]] + | dict[int, Callable[[str], Any]] + | dict[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., @@ -300,7 +312,10 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: dict[int | str, Callable[[str], Any]] + | dict[int, Callable[[str], Any]] + | dict[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., @@ -363,7 +378,10 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] | None = ..., + converters: dict[int | str, Callable[[str], Any]] + | dict[int, Callable[[str], Any]] + | dict[str, Callable[[str], Any]] + | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., skipinitialspace: bool = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index 9e79bef2f..48a0683e9 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -19,6 +19,7 @@ List, Mapping, Tuple, + TypedDict, TypeVar, Union, ) @@ -1278,6 +1279,46 @@ def test_read_csv() -> None: pd.DataFrame, ) + # Allow a variety of dict types for the converters parameter + converters1 = {"A": lambda x: str, "B": lambda x: str} + check( + assert_type(pd.read_csv(path, converters=converters1), pd.DataFrame), + pd.DataFrame, + ) + converters2 = {"A": lambda x: str, "B": lambda x: float} + check( + assert_type(pd.read_csv(path, converters=converters2), pd.DataFrame), + pd.DataFrame, + ) + converters3 = {0: lambda x: str, 1: lambda x: str} + check( + assert_type(pd.read_csv(path, converters=converters3), pd.DataFrame), + pd.DataFrame, + ) + converters4 = {0: lambda x: str, 1: lambda x: float} + check( + assert_type(pd.read_csv(path, converters=converters4), pd.DataFrame), + pd.DataFrame, + ) + converters5: dict[int | str, Callable[[str], Any]] = { + 0: lambda x: str, + "A": lambda x: float, + } + check( + assert_type(pd.read_csv(path, converters=converters5), pd.DataFrame), + pd.DataFrame, + ) + + class ReadCsvKwargs(TypedDict): + converters: dict[int, Callable[[str], Any]] + + read_csv_kwargs: ReadCsvKwargs = {"converters": {0: int}} + + check( + assert_type(pd.read_csv(path, **read_csv_kwargs), pd.DataFrame), + pd.DataFrame, + ) + def test_groupby_series_methods() -> None: df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]}) From 02ca9f6f02a33b93b6dc1762c49faf108c674075 Mon Sep 17 00:00:00 2001 From: Siddhartha Gandhi Date: Sat, 26 Nov 2022 11:46:35 -0500 Subject: [PATCH 2/2] Allow Mapping inputs to all dict inputs to read_csv to permit value covariance, and add tests for na_values and parse_dates --- pandas-stubs/io/parsers/readers.pyi | 75 +++++++++++++++-------------- tests/test_frame.py | 60 ++++++++++++++++++++--- 2 files changed, 92 insertions(+), 43 deletions(-) diff --git a/pandas-stubs/io/parsers/readers.pyi b/pandas-stubs/io/parsers/readers.pyi index e1daa2ea1..01f99e18c 100644 --- a/pandas-stubs/io/parsers/readers.pyi +++ b/pandas-stubs/io/parsers/readers.pyi @@ -8,6 +8,7 @@ from typing import ( Any, Callable, Literal, + Mapping, Sequence, overload, ) @@ -48,9 +49,9 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] - | dict[int, Callable[[str], Any]] - | dict[str, Callable[[str], Any]] + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., @@ -58,16 +59,16 @@ def read_csv( skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -114,9 +115,9 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] - | dict[int, Callable[[str], Any]] - | dict[str, Callable[[str], Any]] + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., @@ -124,16 +125,16 @@ def read_csv( skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -180,9 +181,9 @@ def read_csv( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] - | dict[int, Callable[[str], Any]] - | dict[str, Callable[[str], Any]] + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., @@ -190,16 +191,16 @@ def read_csv( skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -246,9 +247,9 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] - | dict[int, Callable[[str], Any]] - | dict[str, Callable[[str], Any]] + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., @@ -256,16 +257,16 @@ def read_table( skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -312,9 +313,9 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] - | dict[int, Callable[[str], Any]] - | dict[str, Callable[[str], Any]] + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., @@ -322,16 +323,16 @@ def read_table( skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -378,9 +379,9 @@ def read_table( | None = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., - converters: dict[int | str, Callable[[str], Any]] - | dict[int, Callable[[str], Any]] - | dict[str, Callable[[str], Any]] + converters: Mapping[int | str, Callable[[str], Any]] + | Mapping[int, Callable[[str], Any]] + | Mapping[str, Callable[[str], Any]] | None = ..., true_values: list[str] = ..., false_values: list[str] = ..., @@ -388,16 +389,16 @@ def read_table( skiprows: int | Sequence[int] | Callable[[int], bool] = ..., skipfooter: int = ..., nrows: int | None = ..., - na_values: Sequence[str] | dict[str, Sequence[str]] = ..., + na_values: Sequence[str] | Mapping[str, Sequence[str]] = ..., keep_default_na: bool = ..., na_filter: bool = ..., verbose: bool = ..., skip_blank_lines: bool = ..., parse_dates: bool - | Sequence[int] + | list[int] | list[str] | Sequence[Sequence[int]] - | dict[str, Sequence[int]] = ..., + | Mapping[str, Sequence[int | str]] = ..., infer_datetime_format: bool = ..., keep_date_col: bool = ..., date_parser: Callable = ..., @@ -461,7 +462,7 @@ def read_fwf( class TextFileReader(abc.Iterator): engine: CSVEngine - orig_options: dict[str, Any] + orig_options: Mapping[str, Any] chunksize: int | None nrows: int | None squeeze: bool diff --git a/tests/test_frame.py b/tests/test_frame.py index 48a0683e9..74f3b17b5 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1280,29 +1280,29 @@ def test_read_csv() -> None: ) # Allow a variety of dict types for the converters parameter - converters1 = {"A": lambda x: str, "B": lambda x: str} + converters1 = {"A": str, "B": str} check( assert_type(pd.read_csv(path, converters=converters1), pd.DataFrame), pd.DataFrame, ) - converters2 = {"A": lambda x: str, "B": lambda x: float} + converters2 = {"A": str, "B": float} check( assert_type(pd.read_csv(path, converters=converters2), pd.DataFrame), pd.DataFrame, ) - converters3 = {0: lambda x: str, 1: lambda x: str} + converters3 = {0: str, 1: str} check( assert_type(pd.read_csv(path, converters=converters3), pd.DataFrame), pd.DataFrame, ) - converters4 = {0: lambda x: str, 1: lambda x: float} + converters4 = {0: str, 1: float} check( assert_type(pd.read_csv(path, converters=converters4), pd.DataFrame), pd.DataFrame, ) converters5: dict[int | str, Callable[[str], Any]] = { - 0: lambda x: str, - "A": lambda x: float, + 0: str, + "A": float, } check( assert_type(pd.read_csv(path, converters=converters5), pd.DataFrame), @@ -1319,6 +1319,54 @@ class ReadCsvKwargs(TypedDict): pd.DataFrame, ) + # Check value covariance for various other parameters too (these only accept a str key) + na_values = {"A": ["1"], "B": ["1"]} + check( + assert_type(pd.read_csv(path, na_values=na_values), pd.DataFrame), + pd.DataFrame, + ) + + # There are several possible inputs for parse_dates + with ensure_clean() as path: + Path(path).write_text("Date,Year,Month,Day\n20221125,2022,11,25") + parse_dates_1 = ["Date"] + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_1), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + pd.read_csv(path, index_col="Date", parse_dates=True), pd.DataFrame + ), + pd.DataFrame, + ) + parse_dates_2 = {"combined_date": ["Year", "Month", "Day"]} + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_2), pd.DataFrame), + pd.DataFrame, + ) + parse_dates_3 = {"combined_date": [1, 2, 3]} + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_3), pd.DataFrame), + pd.DataFrame, + ) + # MyPy calls this Dict[str, object] by default which necessitates the explicit annotation (Pyright does not) + parse_dates_4: dict[str, list[str | int]] = {"combined_date": [1, "Month", 3]} + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_4), pd.DataFrame), + pd.DataFrame, + ) + parse_dates_5 = [2] + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_5), pd.DataFrame), + pd.DataFrame, + ) + parse_dates_6 = [[1, 2], [1, 2, 3]] + check( + assert_type(pd.read_csv(path, parse_dates=parse_dates_6), pd.DataFrame), + pd.DataFrame, + ) + def test_groupby_series_methods() -> None: df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]})