|
1 | 1 | # coding=utf-8
|
2 | 2 |
|
| 3 | +import inspect |
3 | 4 | from dataclasses import dataclass
|
4 | 5 | from typing import List, Optional, Union
|
5 | 6 |
|
|
11 | 12 |
|
12 | 13 | logger = datasets.utils.logging.get_logger(__name__)
|
13 | 14 |
|
# Parameters dropped from the kwargs built by CsvConfig.read_csv_kwargs while they
# still equal the value found on a default-constructed CsvConfig: they "must not
# be passed if they don't have a default value" (presumably on the pandas.read_csv
# side -- confirm against the supported pandas versions).
_PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS = ["names", "prefix"]
# Deprecated parameters, likewise dropped while they equal the CsvConfig default,
# so they are only forwarded when the user changed them.
_PANDAS_READ_CSV_DEPRECATED_PARAMETERS = ["warn_bad_lines", "error_bad_lines"]
| 17 | + |
14 | 18 |
|
15 | 19 | @dataclass
|
16 | 20 | class CsvConfig(datasets.BuilderConfig):
|
@@ -60,6 +64,51 @@ def __post_init__(self):
|
60 | 64 | if self.column_names is not None:
|
61 | 65 | self.names = self.column_names
|
62 | 66 |
|
@property
def read_csv_kwargs(self):
    """Build the keyword arguments forwarded to ``pandas.read_csv``.

    Property of ``CsvConfig``. Collects the CSV parsing options stored on this
    config into a dict, then removes every entry named in
    ``_PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS`` or
    ``_PANDAS_READ_CSV_DEPRECATED_PARAMETERS`` whose value equals the one on a
    default-constructed ``CsvConfig``.

    The file, ``iterator`` and ``dtype`` arguments are not included; the
    caller passes them separately.

    Returns:
        dict: keyword arguments to unpack into ``pd.read_csv(...)``.
    """
    read_csv_kwargs = dict(
        sep=self.sep,
        header=self.header,
        names=self.names,
        index_col=self.index_col,
        usecols=self.usecols,
        prefix=self.prefix,
        mangle_dupe_cols=self.mangle_dupe_cols,
        engine=self.engine,
        true_values=self.true_values,
        false_values=self.false_values,
        skipinitialspace=self.skipinitialspace,
        skiprows=self.skiprows,
        nrows=self.nrows,
        na_values=self.na_values,
        keep_default_na=self.keep_default_na,
        na_filter=self.na_filter,
        verbose=self.verbose,
        skip_blank_lines=self.skip_blank_lines,
        thousands=self.thousands,
        decimal=self.decimal,
        lineterminator=self.lineterminator,
        quotechar=self.quotechar,
        quoting=self.quoting,
        escapechar=self.escapechar,
        comment=self.comment,
        encoding=self.encoding,
        dialect=self.dialect,
        error_bad_lines=self.error_bad_lines,
        warn_bad_lines=self.warn_bad_lines,
        skipfooter=self.skipfooter,
        doublequote=self.doublequote,
        memory_map=self.memory_map,
        float_precision=self.float_precision,
        chunksize=self.chunksize,
    )
    # some kwargs must not be passed if they don't have a default value
    # some others are deprecated and we can also not pass them if they are the default value
    # The reference defaults never change inside the loop, so build the
    # default config once instead of once per checked parameter.
    default_config = CsvConfig()
    for read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
        if read_csv_kwargs[read_csv_parameter] == getattr(default_config, read_csv_parameter):
            del read_csv_kwargs[read_csv_parameter]
    return read_csv_kwargs
| 111 | + |
63 | 112 |
|
64 | 113 | class Csv(datasets.ArrowBasedBuilder):
|
65 | 114 | BUILDER_CONFIG_CLASS = CsvConfig
|
@@ -89,45 +138,7 @@ def _generate_tables(self, files):
|
89 | 138 | # dtype allows reading an int column as str
|
90 | 139 | dtype = {name: dtype.to_pandas_dtype() for name, dtype in zip(schema.names, schema.types)} if schema else None
|
91 | 140 | for file_idx, file in enumerate(files):
|
92 |
| - csv_file_reader = pd.read_csv( |
93 |
| - file, |
94 |
| - iterator=True, |
95 |
| - dtype=dtype, |
96 |
| - sep=self.config.sep, |
97 |
| - header=self.config.header, |
98 |
| - names=self.config.names, |
99 |
| - index_col=self.config.index_col, |
100 |
| - usecols=self.config.usecols, |
101 |
| - prefix=self.config.prefix, |
102 |
| - mangle_dupe_cols=self.config.mangle_dupe_cols, |
103 |
| - engine=self.config.engine, |
104 |
| - true_values=self.config.true_values, |
105 |
| - false_values=self.config.false_values, |
106 |
| - skipinitialspace=self.config.skipinitialspace, |
107 |
| - skiprows=self.config.skiprows, |
108 |
| - nrows=self.config.nrows, |
109 |
| - na_values=self.config.na_values, |
110 |
| - keep_default_na=self.config.keep_default_na, |
111 |
| - na_filter=self.config.na_filter, |
112 |
| - verbose=self.config.verbose, |
113 |
| - skip_blank_lines=self.config.skip_blank_lines, |
114 |
| - thousands=self.config.thousands, |
115 |
| - decimal=self.config.decimal, |
116 |
| - lineterminator=self.config.lineterminator, |
117 |
| - quotechar=self.config.quotechar, |
118 |
| - quoting=self.config.quoting, |
119 |
| - escapechar=self.config.escapechar, |
120 |
| - comment=self.config.comment, |
121 |
| - encoding=self.config.encoding, |
122 |
| - dialect=self.config.dialect, |
123 |
| - error_bad_lines=self.config.error_bad_lines, |
124 |
| - warn_bad_lines=self.config.warn_bad_lines, |
125 |
| - skipfooter=self.config.skipfooter, |
126 |
| - doublequote=self.config.doublequote, |
127 |
| - memory_map=self.config.memory_map, |
128 |
| - float_precision=self.config.float_precision, |
129 |
| - chunksize=self.config.chunksize, |
130 |
| - ) |
| 141 | + csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs) |
131 | 142 |
|
132 | 143 | try:
|
133 | 144 | for batch_idx, df in enumerate(csv_file_reader):
|
|
0 commit comments