Skip to content

Commit a53fdfe

Browse files
committed
support pandas 1.3.0 read_csv
1 parent adc8042 commit a53fdfe

File tree

1 file changed

+50
-39
lines changed
  • src/datasets/packaged_modules/csv

1 file changed

+50
-39
lines changed

src/datasets/packaged_modules/csv/csv.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# coding=utf-8
22

3+
import inspect
34
from dataclasses import dataclass
45
from typing import List, Optional, Union
56

@@ -11,6 +12,9 @@
1112

1213
logger = datasets.utils.logging.get_logger(__name__)
1314

15+
_PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS = ["names", "prefix"]
16+
_PANDAS_READ_CSV_DEPRECATED_PARAMETERS = ["warn_bad_lines", "error_bad_lines"]
17+
1418

1519
@dataclass
1620
class CsvConfig(datasets.BuilderConfig):
@@ -60,6 +64,51 @@ def __post_init__(self):
6064
if self.column_names is not None:
6165
self.names = self.column_names
6266

67+
@property
68+
def read_csv_kwargs(self):
69+
read_csv_kwargs = dict(
70+
sep=self.sep,
71+
header=self.header,
72+
names=self.names,
73+
index_col=self.index_col,
74+
usecols=self.usecols,
75+
prefix=self.prefix,
76+
mangle_dupe_cols=self.mangle_dupe_cols,
77+
engine=self.engine,
78+
true_values=self.true_values,
79+
false_values=self.false_values,
80+
skipinitialspace=self.skipinitialspace,
81+
skiprows=self.skiprows,
82+
nrows=self.nrows,
83+
na_values=self.na_values,
84+
keep_default_na=self.keep_default_na,
85+
na_filter=self.na_filter,
86+
verbose=self.verbose,
87+
skip_blank_lines=self.skip_blank_lines,
88+
thousands=self.thousands,
89+
decimal=self.decimal,
90+
lineterminator=self.lineterminator,
91+
quotechar=self.quotechar,
92+
quoting=self.quoting,
93+
escapechar=self.escapechar,
94+
comment=self.comment,
95+
encoding=self.encoding,
96+
dialect=self.dialect,
97+
error_bad_lines=self.error_bad_lines,
98+
warn_bad_lines=self.warn_bad_lines,
99+
skipfooter=self.skipfooter,
100+
doublequote=self.doublequote,
101+
memory_map=self.memory_map,
102+
float_precision=self.float_precision,
103+
chunksize=self.chunksize,
104+
)
105+
# some kwargs must not be passed if they don't have a default value
106+
# some others are deprecated and we can also not pass them if they are the default value
107+
for read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
108+
if read_csv_kwargs[read_csv_parameter] == getattr(CsvConfig(), read_csv_parameter):
109+
del read_csv_kwargs[read_csv_parameter]
110+
return read_csv_kwargs
111+
63112

64113
class Csv(datasets.ArrowBasedBuilder):
65114
BUILDER_CONFIG_CLASS = CsvConfig
@@ -89,45 +138,7 @@ def _generate_tables(self, files):
89138
# dtype allows reading an int column as str
90139
dtype = {name: dtype.to_pandas_dtype() for name, dtype in zip(schema.names, schema.types)} if schema else None
91140
for file_idx, file in enumerate(files):
92-
csv_file_reader = pd.read_csv(
93-
file,
94-
iterator=True,
95-
dtype=dtype,
96-
sep=self.config.sep,
97-
header=self.config.header,
98-
names=self.config.names,
99-
index_col=self.config.index_col,
100-
usecols=self.config.usecols,
101-
prefix=self.config.prefix,
102-
mangle_dupe_cols=self.config.mangle_dupe_cols,
103-
engine=self.config.engine,
104-
true_values=self.config.true_values,
105-
false_values=self.config.false_values,
106-
skipinitialspace=self.config.skipinitialspace,
107-
skiprows=self.config.skiprows,
108-
nrows=self.config.nrows,
109-
na_values=self.config.na_values,
110-
keep_default_na=self.config.keep_default_na,
111-
na_filter=self.config.na_filter,
112-
verbose=self.config.verbose,
113-
skip_blank_lines=self.config.skip_blank_lines,
114-
thousands=self.config.thousands,
115-
decimal=self.config.decimal,
116-
lineterminator=self.config.lineterminator,
117-
quotechar=self.config.quotechar,
118-
quoting=self.config.quoting,
119-
escapechar=self.config.escapechar,
120-
comment=self.config.comment,
121-
encoding=self.config.encoding,
122-
dialect=self.config.dialect,
123-
error_bad_lines=self.config.error_bad_lines,
124-
warn_bad_lines=self.config.warn_bad_lines,
125-
skipfooter=self.config.skipfooter,
126-
doublequote=self.config.doublequote,
127-
memory_map=self.config.memory_map,
128-
float_precision=self.config.float_precision,
129-
chunksize=self.config.chunksize,
130-
)
141+
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
131142

132143
try:
133144
for batch_idx, df in enumerate(csv_file_reader):

0 commit comments

Comments
 (0)