Skip to content

Commit dc5ea0e

Browse files
authored
Permit covariance of key type in read_csv converters argument (#450)
* Permit covariance of key type in read_csv converters argument
* Allow Mapping inputs to all dict inputs to read_csv to permit value covariance, and add tests for na_values and parse_dates
1 parent 8902752 commit dc5ea0e

File tree

2 files changed

+133
-25
lines changed

2 files changed

+133
-25
lines changed

pandas-stubs/io/parsers/readers.pyi

+44-25
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ from typing import (
88
Any,
99
Callable,
1010
Literal,
11+
Mapping,
1112
Sequence,
1213
overload,
1314
)
@@ -48,23 +49,26 @@ def read_csv(
4849
| None = ...,
4950
dtype: DtypeArg | defaultdict | None = ...,
5051
engine: CSVEngine | None = ...,
51-
converters: dict[int | str, Callable[[str], Any]] | None = ...,
52+
converters: Mapping[int | str, Callable[[str], Any]]
53+
| Mapping[int, Callable[[str], Any]]
54+
| Mapping[str, Callable[[str], Any]]
55+
| None = ...,
5256
true_values: list[str] = ...,
5357
false_values: list[str] = ...,
5458
skipinitialspace: bool = ...,
5559
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
5660
skipfooter: int = ...,
5761
nrows: int | None = ...,
58-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
62+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
5963
keep_default_na: bool = ...,
6064
na_filter: bool = ...,
6165
verbose: bool = ...,
6266
skip_blank_lines: bool = ...,
6367
parse_dates: bool
64-
| Sequence[int]
68+
| list[int]
6569
| list[str]
6670
| Sequence[Sequence[int]]
67-
| dict[str, Sequence[int]] = ...,
71+
| Mapping[str, Sequence[int | str]] = ...,
6872
infer_datetime_format: bool = ...,
6973
keep_date_col: bool = ...,
7074
date_parser: Callable = ...,
@@ -111,23 +115,26 @@ def read_csv(
111115
| None = ...,
112116
dtype: DtypeArg | defaultdict | None = ...,
113117
engine: CSVEngine | None = ...,
114-
converters: dict[int | str, Callable[[str], Any]] | None = ...,
118+
converters: Mapping[int | str, Callable[[str], Any]]
119+
| Mapping[int, Callable[[str], Any]]
120+
| Mapping[str, Callable[[str], Any]]
121+
| None = ...,
115122
true_values: list[str] = ...,
116123
false_values: list[str] = ...,
117124
skipinitialspace: bool = ...,
118125
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
119126
skipfooter: int = ...,
120127
nrows: int | None = ...,
121-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
128+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
122129
keep_default_na: bool = ...,
123130
na_filter: bool = ...,
124131
verbose: bool = ...,
125132
skip_blank_lines: bool = ...,
126133
parse_dates: bool
127-
| Sequence[int]
134+
| list[int]
128135
| list[str]
129136
| Sequence[Sequence[int]]
130-
| dict[str, Sequence[int]] = ...,
137+
| Mapping[str, Sequence[int | str]] = ...,
131138
infer_datetime_format: bool = ...,
132139
keep_date_col: bool = ...,
133140
date_parser: Callable = ...,
@@ -174,23 +181,26 @@ def read_csv(
174181
| None = ...,
175182
dtype: DtypeArg | defaultdict | None = ...,
176183
engine: CSVEngine | None = ...,
177-
converters: dict[int | str, Callable[[str], Any]] | None = ...,
184+
converters: Mapping[int | str, Callable[[str], Any]]
185+
| Mapping[int, Callable[[str], Any]]
186+
| Mapping[str, Callable[[str], Any]]
187+
| None = ...,
178188
true_values: list[str] = ...,
179189
false_values: list[str] = ...,
180190
skipinitialspace: bool = ...,
181191
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
182192
skipfooter: int = ...,
183193
nrows: int | None = ...,
184-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
194+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
185195
keep_default_na: bool = ...,
186196
na_filter: bool = ...,
187197
verbose: bool = ...,
188198
skip_blank_lines: bool = ...,
189199
parse_dates: bool
190-
| Sequence[int]
200+
| list[int]
191201
| list[str]
192202
| Sequence[Sequence[int]]
193-
| dict[str, Sequence[int]] = ...,
203+
| Mapping[str, Sequence[int | str]] = ...,
194204
infer_datetime_format: bool = ...,
195205
keep_date_col: bool = ...,
196206
date_parser: Callable = ...,
@@ -237,23 +247,26 @@ def read_table(
237247
| None = ...,
238248
dtype: DtypeArg | defaultdict | None = ...,
239249
engine: CSVEngine | None = ...,
240-
converters: dict[int | str, Callable[[str], Any]] | None = ...,
250+
converters: Mapping[int | str, Callable[[str], Any]]
251+
| Mapping[int, Callable[[str], Any]]
252+
| Mapping[str, Callable[[str], Any]]
253+
| None = ...,
241254
true_values: list[str] = ...,
242255
false_values: list[str] = ...,
243256
skipinitialspace: bool = ...,
244257
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
245258
skipfooter: int = ...,
246259
nrows: int | None = ...,
247-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
260+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
248261
keep_default_na: bool = ...,
249262
na_filter: bool = ...,
250263
verbose: bool = ...,
251264
skip_blank_lines: bool = ...,
252265
parse_dates: bool
253-
| Sequence[int]
266+
| list[int]
254267
| list[str]
255268
| Sequence[Sequence[int]]
256-
| dict[str, Sequence[int]] = ...,
269+
| Mapping[str, Sequence[int | str]] = ...,
257270
infer_datetime_format: bool = ...,
258271
keep_date_col: bool = ...,
259272
date_parser: Callable = ...,
@@ -300,23 +313,26 @@ def read_table(
300313
| None = ...,
301314
dtype: DtypeArg | defaultdict | None = ...,
302315
engine: CSVEngine | None = ...,
303-
converters: dict[int | str, Callable[[str], Any]] | None = ...,
316+
converters: Mapping[int | str, Callable[[str], Any]]
317+
| Mapping[int, Callable[[str], Any]]
318+
| Mapping[str, Callable[[str], Any]]
319+
| None = ...,
304320
true_values: list[str] = ...,
305321
false_values: list[str] = ...,
306322
skipinitialspace: bool = ...,
307323
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
308324
skipfooter: int = ...,
309325
nrows: int | None = ...,
310-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
326+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
311327
keep_default_na: bool = ...,
312328
na_filter: bool = ...,
313329
verbose: bool = ...,
314330
skip_blank_lines: bool = ...,
315331
parse_dates: bool
316-
| Sequence[int]
332+
| list[int]
317333
| list[str]
318334
| Sequence[Sequence[int]]
319-
| dict[str, Sequence[int]] = ...,
335+
| Mapping[str, Sequence[int | str]] = ...,
320336
infer_datetime_format: bool = ...,
321337
keep_date_col: bool = ...,
322338
date_parser: Callable = ...,
@@ -363,23 +379,26 @@ def read_table(
363379
| None = ...,
364380
dtype: DtypeArg | defaultdict | None = ...,
365381
engine: CSVEngine | None = ...,
366-
converters: dict[int | str, Callable[[str], Any]] | None = ...,
382+
converters: Mapping[int | str, Callable[[str], Any]]
383+
| Mapping[int, Callable[[str], Any]]
384+
| Mapping[str, Callable[[str], Any]]
385+
| None = ...,
367386
true_values: list[str] = ...,
368387
false_values: list[str] = ...,
369388
skipinitialspace: bool = ...,
370389
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
371390
skipfooter: int = ...,
372391
nrows: int | None = ...,
373-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
392+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
374393
keep_default_na: bool = ...,
375394
na_filter: bool = ...,
376395
verbose: bool = ...,
377396
skip_blank_lines: bool = ...,
378397
parse_dates: bool
379-
| Sequence[int]
398+
| list[int]
380399
| list[str]
381400
| Sequence[Sequence[int]]
382-
| dict[str, Sequence[int]] = ...,
401+
| Mapping[str, Sequence[int | str]] = ...,
383402
infer_datetime_format: bool = ...,
384403
keep_date_col: bool = ...,
385404
date_parser: Callable = ...,
@@ -443,7 +462,7 @@ def read_fwf(
443462

444463
class TextFileReader(abc.Iterator):
445464
engine: CSVEngine
446-
orig_options: dict[str, Any]
465+
orig_options: Mapping[str, Any]
447466
chunksize: int | None
448467
nrows: int | None
449468
squeeze: bool

tests/test_frame.py

+89
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
List,
2020
Mapping,
2121
Tuple,
22+
TypedDict,
2223
TypeVar,
2324
Union,
2425
)
@@ -1278,6 +1279,94 @@ def test_read_csv() -> None:
12781279
pd.DataFrame,
12791280
)
12801281

1282+
# Allow a variety of dict types for the converters parameter
1283+
converters1 = {"A": str, "B": str}
1284+
check(
1285+
assert_type(pd.read_csv(path, converters=converters1), pd.DataFrame),
1286+
pd.DataFrame,
1287+
)
1288+
converters2 = {"A": str, "B": float}
1289+
check(
1290+
assert_type(pd.read_csv(path, converters=converters2), pd.DataFrame),
1291+
pd.DataFrame,
1292+
)
1293+
converters3 = {0: str, 1: str}
1294+
check(
1295+
assert_type(pd.read_csv(path, converters=converters3), pd.DataFrame),
1296+
pd.DataFrame,
1297+
)
1298+
converters4 = {0: str, 1: float}
1299+
check(
1300+
assert_type(pd.read_csv(path, converters=converters4), pd.DataFrame),
1301+
pd.DataFrame,
1302+
)
1303+
converters5: dict[int | str, Callable[[str], Any]] = {
1304+
0: str,
1305+
"A": float,
1306+
}
1307+
check(
1308+
assert_type(pd.read_csv(path, converters=converters5), pd.DataFrame),
1309+
pd.DataFrame,
1310+
)
1311+
1312+
class ReadCsvKwargs(TypedDict):
1313+
converters: dict[int, Callable[[str], Any]]
1314+
1315+
read_csv_kwargs: ReadCsvKwargs = {"converters": {0: int}}
1316+
1317+
check(
1318+
assert_type(pd.read_csv(path, **read_csv_kwargs), pd.DataFrame),
1319+
pd.DataFrame,
1320+
)
1321+
1322+
# Check value covariance for various other parameters too (these only accept a str key)
1323+
na_values = {"A": ["1"], "B": ["1"]}
1324+
check(
1325+
assert_type(pd.read_csv(path, na_values=na_values), pd.DataFrame),
1326+
pd.DataFrame,
1327+
)
1328+
1329+
# There are several possible inputs for parse_dates
1330+
with ensure_clean() as path:
1331+
Path(path).write_text("Date,Year,Month,Day\n20221125,2022,11,25")
1332+
parse_dates_1 = ["Date"]
1333+
check(
1334+
assert_type(pd.read_csv(path, parse_dates=parse_dates_1), pd.DataFrame),
1335+
pd.DataFrame,
1336+
)
1337+
check(
1338+
assert_type(
1339+
pd.read_csv(path, index_col="Date", parse_dates=True), pd.DataFrame
1340+
),
1341+
pd.DataFrame,
1342+
)
1343+
parse_dates_2 = {"combined_date": ["Year", "Month", "Day"]}
1344+
check(
1345+
assert_type(pd.read_csv(path, parse_dates=parse_dates_2), pd.DataFrame),
1346+
pd.DataFrame,
1347+
)
1348+
parse_dates_3 = {"combined_date": [1, 2, 3]}
1349+
check(
1350+
assert_type(pd.read_csv(path, parse_dates=parse_dates_3), pd.DataFrame),
1351+
pd.DataFrame,
1352+
)
1353+
# MyPy infers this as Dict[str, object] by default, which necessitates the explicit annotation (Pyright does not need it)
1354+
parse_dates_4: dict[str, list[str | int]] = {"combined_date": [1, "Month", 3]}
1355+
check(
1356+
assert_type(pd.read_csv(path, parse_dates=parse_dates_4), pd.DataFrame),
1357+
pd.DataFrame,
1358+
)
1359+
parse_dates_5 = [2]
1360+
check(
1361+
assert_type(pd.read_csv(path, parse_dates=parse_dates_5), pd.DataFrame),
1362+
pd.DataFrame,
1363+
)
1364+
parse_dates_6 = [[1, 2], [1, 2, 3]]
1365+
check(
1366+
assert_type(pd.read_csv(path, parse_dates=parse_dates_6), pd.DataFrame),
1367+
pd.DataFrame,
1368+
)
1369+
12811370

12821371
def test_groupby_series_methods() -> None:
12831372
df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]})

0 commit comments

Comments
 (0)