Skip to content

Commit 02ca9f6

Browse files
committed
Accept Mapping wherever read_csv and read_table previously required a dict, to permit value covariance, and add tests for na_values and parse_dates
1 parent 3f86e05 commit 02ca9f6

File tree

2 files changed

+92
-43
lines changed

2 files changed

+92
-43
lines changed

pandas-stubs/io/parsers/readers.pyi

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import (
88
Any,
99
Callable,
1010
Literal,
11+
Mapping,
1112
Sequence,
1213
overload,
1314
)
@@ -48,26 +49,26 @@ def read_csv(
4849
| None = ...,
4950
dtype: DtypeArg | defaultdict | None = ...,
5051
engine: CSVEngine | None = ...,
51-
converters: dict[int | str, Callable[[str], Any]]
52-
| dict[int, Callable[[str], Any]]
53-
| dict[str, Callable[[str], Any]]
52+
converters: Mapping[int | str, Callable[[str], Any]]
53+
| Mapping[int, Callable[[str], Any]]
54+
| Mapping[str, Callable[[str], Any]]
5455
| None = ...,
5556
true_values: list[str] = ...,
5657
false_values: list[str] = ...,
5758
skipinitialspace: bool = ...,
5859
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
5960
skipfooter: int = ...,
6061
nrows: int | None = ...,
61-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
62+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
6263
keep_default_na: bool = ...,
6364
na_filter: bool = ...,
6465
verbose: bool = ...,
6566
skip_blank_lines: bool = ...,
6667
parse_dates: bool
67-
| Sequence[int]
68+
| list[int]
6869
| list[str]
6970
| Sequence[Sequence[int]]
70-
| dict[str, Sequence[int]] = ...,
71+
| Mapping[str, Sequence[int | str]] = ...,
7172
infer_datetime_format: bool = ...,
7273
keep_date_col: bool = ...,
7374
date_parser: Callable = ...,
@@ -114,26 +115,26 @@ def read_csv(
114115
| None = ...,
115116
dtype: DtypeArg | defaultdict | None = ...,
116117
engine: CSVEngine | None = ...,
117-
converters: dict[int | str, Callable[[str], Any]]
118-
| dict[int, Callable[[str], Any]]
119-
| dict[str, Callable[[str], Any]]
118+
converters: Mapping[int | str, Callable[[str], Any]]
119+
| Mapping[int, Callable[[str], Any]]
120+
| Mapping[str, Callable[[str], Any]]
120121
| None = ...,
121122
true_values: list[str] = ...,
122123
false_values: list[str] = ...,
123124
skipinitialspace: bool = ...,
124125
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
125126
skipfooter: int = ...,
126127
nrows: int | None = ...,
127-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
128+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
128129
keep_default_na: bool = ...,
129130
na_filter: bool = ...,
130131
verbose: bool = ...,
131132
skip_blank_lines: bool = ...,
132133
parse_dates: bool
133-
| Sequence[int]
134+
| list[int]
134135
| list[str]
135136
| Sequence[Sequence[int]]
136-
| dict[str, Sequence[int]] = ...,
137+
| Mapping[str, Sequence[int | str]] = ...,
137138
infer_datetime_format: bool = ...,
138139
keep_date_col: bool = ...,
139140
date_parser: Callable = ...,
@@ -180,26 +181,26 @@ def read_csv(
180181
| None = ...,
181182
dtype: DtypeArg | defaultdict | None = ...,
182183
engine: CSVEngine | None = ...,
183-
converters: dict[int | str, Callable[[str], Any]]
184-
| dict[int, Callable[[str], Any]]
185-
| dict[str, Callable[[str], Any]]
184+
converters: Mapping[int | str, Callable[[str], Any]]
185+
| Mapping[int, Callable[[str], Any]]
186+
| Mapping[str, Callable[[str], Any]]
186187
| None = ...,
187188
true_values: list[str] = ...,
188189
false_values: list[str] = ...,
189190
skipinitialspace: bool = ...,
190191
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
191192
skipfooter: int = ...,
192193
nrows: int | None = ...,
193-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
194+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
194195
keep_default_na: bool = ...,
195196
na_filter: bool = ...,
196197
verbose: bool = ...,
197198
skip_blank_lines: bool = ...,
198199
parse_dates: bool
199-
| Sequence[int]
200+
| list[int]
200201
| list[str]
201202
| Sequence[Sequence[int]]
202-
| dict[str, Sequence[int]] = ...,
203+
| Mapping[str, Sequence[int | str]] = ...,
203204
infer_datetime_format: bool = ...,
204205
keep_date_col: bool = ...,
205206
date_parser: Callable = ...,
@@ -246,26 +247,26 @@ def read_table(
246247
| None = ...,
247248
dtype: DtypeArg | defaultdict | None = ...,
248249
engine: CSVEngine | None = ...,
249-
converters: dict[int | str, Callable[[str], Any]]
250-
| dict[int, Callable[[str], Any]]
251-
| dict[str, Callable[[str], Any]]
250+
converters: Mapping[int | str, Callable[[str], Any]]
251+
| Mapping[int, Callable[[str], Any]]
252+
| Mapping[str, Callable[[str], Any]]
252253
| None = ...,
253254
true_values: list[str] = ...,
254255
false_values: list[str] = ...,
255256
skipinitialspace: bool = ...,
256257
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
257258
skipfooter: int = ...,
258259
nrows: int | None = ...,
259-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
260+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
260261
keep_default_na: bool = ...,
261262
na_filter: bool = ...,
262263
verbose: bool = ...,
263264
skip_blank_lines: bool = ...,
264265
parse_dates: bool
265-
| Sequence[int]
266+
| list[int]
266267
| list[str]
267268
| Sequence[Sequence[int]]
268-
| dict[str, Sequence[int]] = ...,
269+
| Mapping[str, Sequence[int | str]] = ...,
269270
infer_datetime_format: bool = ...,
270271
keep_date_col: bool = ...,
271272
date_parser: Callable = ...,
@@ -312,26 +313,26 @@ def read_table(
312313
| None = ...,
313314
dtype: DtypeArg | defaultdict | None = ...,
314315
engine: CSVEngine | None = ...,
315-
converters: dict[int | str, Callable[[str], Any]]
316-
| dict[int, Callable[[str], Any]]
317-
| dict[str, Callable[[str], Any]]
316+
converters: Mapping[int | str, Callable[[str], Any]]
317+
| Mapping[int, Callable[[str], Any]]
318+
| Mapping[str, Callable[[str], Any]]
318319
| None = ...,
319320
true_values: list[str] = ...,
320321
false_values: list[str] = ...,
321322
skipinitialspace: bool = ...,
322323
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
323324
skipfooter: int = ...,
324325
nrows: int | None = ...,
325-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
326+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
326327
keep_default_na: bool = ...,
327328
na_filter: bool = ...,
328329
verbose: bool = ...,
329330
skip_blank_lines: bool = ...,
330331
parse_dates: bool
331-
| Sequence[int]
332+
| list[int]
332333
| list[str]
333334
| Sequence[Sequence[int]]
334-
| dict[str, Sequence[int]] = ...,
335+
| Mapping[str, Sequence[int | str]] = ...,
335336
infer_datetime_format: bool = ...,
336337
keep_date_col: bool = ...,
337338
date_parser: Callable = ...,
@@ -378,26 +379,26 @@ def read_table(
378379
| None = ...,
379380
dtype: DtypeArg | defaultdict | None = ...,
380381
engine: CSVEngine | None = ...,
381-
converters: dict[int | str, Callable[[str], Any]]
382-
| dict[int, Callable[[str], Any]]
383-
| dict[str, Callable[[str], Any]]
382+
converters: Mapping[int | str, Callable[[str], Any]]
383+
| Mapping[int, Callable[[str], Any]]
384+
| Mapping[str, Callable[[str], Any]]
384385
| None = ...,
385386
true_values: list[str] = ...,
386387
false_values: list[str] = ...,
387388
skipinitialspace: bool = ...,
388389
skiprows: int | Sequence[int] | Callable[[int], bool] = ...,
389390
skipfooter: int = ...,
390391
nrows: int | None = ...,
391-
na_values: Sequence[str] | dict[str, Sequence[str]] = ...,
392+
na_values: Sequence[str] | Mapping[str, Sequence[str]] = ...,
392393
keep_default_na: bool = ...,
393394
na_filter: bool = ...,
394395
verbose: bool = ...,
395396
skip_blank_lines: bool = ...,
396397
parse_dates: bool
397-
| Sequence[int]
398+
| list[int]
398399
| list[str]
399400
| Sequence[Sequence[int]]
400-
| dict[str, Sequence[int]] = ...,
401+
| Mapping[str, Sequence[int | str]] = ...,
401402
infer_datetime_format: bool = ...,
402403
keep_date_col: bool = ...,
403404
date_parser: Callable = ...,
@@ -461,7 +462,7 @@ def read_fwf(
461462

462463
class TextFileReader(abc.Iterator):
463464
engine: CSVEngine
464-
orig_options: dict[str, Any]
465+
orig_options: Mapping[str, Any]
465466
chunksize: int | None
466467
nrows: int | None
467468
squeeze: bool

tests/test_frame.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,29 +1280,29 @@ def test_read_csv() -> None:
12801280
)
12811281

12821282
# Allow a variety of dict types for the converters parameter
1283-
converters1 = {"A": lambda x: str, "B": lambda x: str}
1283+
converters1 = {"A": str, "B": str}
12841284
check(
12851285
assert_type(pd.read_csv(path, converters=converters1), pd.DataFrame),
12861286
pd.DataFrame,
12871287
)
1288-
converters2 = {"A": lambda x: str, "B": lambda x: float}
1288+
converters2 = {"A": str, "B": float}
12891289
check(
12901290
assert_type(pd.read_csv(path, converters=converters2), pd.DataFrame),
12911291
pd.DataFrame,
12921292
)
1293-
converters3 = {0: lambda x: str, 1: lambda x: str}
1293+
converters3 = {0: str, 1: str}
12941294
check(
12951295
assert_type(pd.read_csv(path, converters=converters3), pd.DataFrame),
12961296
pd.DataFrame,
12971297
)
1298-
converters4 = {0: lambda x: str, 1: lambda x: float}
1298+
converters4 = {0: str, 1: float}
12991299
check(
13001300
assert_type(pd.read_csv(path, converters=converters4), pd.DataFrame),
13011301
pd.DataFrame,
13021302
)
13031303
converters5: dict[int | str, Callable[[str], Any]] = {
1304-
0: lambda x: str,
1305-
"A": lambda x: float,
1304+
0: str,
1305+
"A": float,
13061306
}
13071307
check(
13081308
assert_type(pd.read_csv(path, converters=converters5), pd.DataFrame),
@@ -1319,6 +1319,54 @@ class ReadCsvKwargs(TypedDict):
13191319
pd.DataFrame,
13201320
)
13211321

1322+
# Check value covariance for various other parameters too (these only accept a str key)
1323+
na_values = {"A": ["1"], "B": ["1"]}
1324+
check(
1325+
assert_type(pd.read_csv(path, na_values=na_values), pd.DataFrame),
1326+
pd.DataFrame,
1327+
)
1328+
1329+
# There are several possible inputs for parse_dates
1330+
with ensure_clean() as path:
1331+
Path(path).write_text("Date,Year,Month,Day\n20221125,2022,11,25")
1332+
parse_dates_1 = ["Date"]
1333+
check(
1334+
assert_type(pd.read_csv(path, parse_dates=parse_dates_1), pd.DataFrame),
1335+
pd.DataFrame,
1336+
)
1337+
check(
1338+
assert_type(
1339+
pd.read_csv(path, index_col="Date", parse_dates=True), pd.DataFrame
1340+
),
1341+
pd.DataFrame,
1342+
)
1343+
parse_dates_2 = {"combined_date": ["Year", "Month", "Day"]}
1344+
check(
1345+
assert_type(pd.read_csv(path, parse_dates=parse_dates_2), pd.DataFrame),
1346+
pd.DataFrame,
1347+
)
1348+
parse_dates_3 = {"combined_date": [1, 2, 3]}
1349+
check(
1350+
assert_type(pd.read_csv(path, parse_dates=parse_dates_3), pd.DataFrame),
1351+
pd.DataFrame,
1352+
)
1353+
# MyPy infers Dict[str, object] for this literal by default, which necessitates the explicit annotation (Pyright does not need it)
1354+
parse_dates_4: dict[str, list[str | int]] = {"combined_date": [1, "Month", 3]}
1355+
check(
1356+
assert_type(pd.read_csv(path, parse_dates=parse_dates_4), pd.DataFrame),
1357+
pd.DataFrame,
1358+
)
1359+
parse_dates_5 = [2]
1360+
check(
1361+
assert_type(pd.read_csv(path, parse_dates=parse_dates_5), pd.DataFrame),
1362+
pd.DataFrame,
1363+
)
1364+
parse_dates_6 = [[1, 2], [1, 2, 3]]
1365+
check(
1366+
assert_type(pd.read_csv(path, parse_dates=parse_dates_6), pd.DataFrame),
1367+
pd.DataFrame,
1368+
)
1369+
13221370

13231371
def test_groupby_series_methods() -> None:
13241372
df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]})

0 commit comments

Comments
 (0)