diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 8eab623a2b5f7..85cf59dc7135b 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -73,6 +73,8 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ + +- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`) - :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`) - DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`) - diff --git a/pandas/io/parsers/arrow_parser_wrapper.py b/pandas/io/parsers/arrow_parser_wrapper.py index bb6bcd3c4d6a0..765a4ffcd2cb9 100644 --- a/pandas/io/parsers/arrow_parser_wrapper.py +++ b/pandas/io/parsers/arrow_parser_wrapper.py @@ -1,11 +1,17 @@ from __future__ import annotations from typing import TYPE_CHECKING +import warnings from pandas._config import using_pyarrow_string_dtype from pandas._libs import lib from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + ParserError, + ParserWarning, +) +from pandas.util._exceptions import find_stack_level from pandas.core.dtypes.inference import is_integer @@ -85,6 +91,30 @@ def _get_pyarrow_options(self) -> None: and option_name in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines") } + + on_bad_lines = self.kwds.get("on_bad_lines") + if on_bad_lines is not None: + if callable(on_bad_lines): + self.parse_options["invalid_row_handler"] = on_bad_lines + elif on_bad_lines == ParserBase.BadLineHandleMethod.ERROR: + self.parse_options[ + "invalid_row_handler" + ] = None # PyArrow raises an exception by default + elif on_bad_lines == ParserBase.BadLineHandleMethod.WARN: + + def handle_warning(invalid_row): + warnings.warn( + f"Expected {invalid_row.expected_columns} columns, but found " +
f"{invalid_row.actual_columns}: {invalid_row.text}", + ParserWarning, + stacklevel=find_stack_level(), + ) + return "skip" + + self.parse_options["invalid_row_handler"] = handle_warning + elif on_bad_lines == ParserBase.BadLineHandleMethod.SKIP: + self.parse_options["invalid_row_handler"] = lambda _: "skip" + self.convert_options = { option_name: option_value for option_name, option_value in self.kwds.items() @@ -190,12 +220,15 @@ def read(self) -> DataFrame: pyarrow_csv = import_optional_dependency("pyarrow.csv") self._get_pyarrow_options() - table = pyarrow_csv.read_csv( - self.src, - read_options=pyarrow_csv.ReadOptions(**self.read_options), - parse_options=pyarrow_csv.ParseOptions(**self.parse_options), - convert_options=pyarrow_csv.ConvertOptions(**self.convert_options), - ) + try: + table = pyarrow_csv.read_csv( + self.src, + read_options=pyarrow_csv.ReadOptions(**self.read_options), + parse_options=pyarrow_csv.ParseOptions(**self.parse_options), + convert_options=pyarrow_csv.ConvertOptions(**self.convert_options), + ) + except pa.ArrowInvalid as e: + raise ParserError(e) from e dtype_backend = self.kwds["dtype_backend"] diff --git a/pandas/io/parsers/readers.py b/pandas/io/parsers/readers.py index acf35ebd6afe5..6ce6ac71b1ddd 100644 --- a/pandas/io/parsers/readers.py +++ b/pandas/io/parsers/readers.py @@ -401,6 +401,13 @@ expected, a ``ParserWarning`` will be emitted while dropping extra elements. Only supported when ``engine='python'`` + .. versionchanged:: 2.2.0 + + - Callable, function with signature + as described in `pyarrow documentation + _` when ``engine='pyarrow'`` + delim_whitespace : bool, default False Specifies whether or not whitespace (e.g. ``' '`` or ``'\\t'``) will be used as the ``sep`` delimiter. Equivalent to setting ``sep='\\s+'``. 
If this option @@ -494,7 +501,6 @@ class _Fwf_Defaults(TypedDict): "thousands", "memory_map", "dialect", - "on_bad_lines", "delim_whitespace", "quoting", "lineterminator", @@ -2142,9 +2148,10 @@ def _refine_defaults_read( elif on_bad_lines == "skip": kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP elif callable(on_bad_lines): - if engine != "python": + if engine not in ["python", "pyarrow"]: raise ValueError( - "on_bad_line can only be a callable function if engine='python'" + "on_bad_line can only be a callable function " + "if engine='python' or 'pyarrow'" ) kwds["on_bad_lines"] = on_bad_lines else: diff --git a/pandas/tests/io/parser/common/test_read_errors.py b/pandas/tests/io/parser/common/test_read_errors.py index 0c5a2e0d04e5a..4e82dca83e2d0 100644 --- a/pandas/tests/io/parser/common/test_read_errors.py +++ b/pandas/tests/io/parser/common/test_read_errors.py @@ -1,5 +1,5 @@ """ -Tests that work on both the Python and C engines but do not have a +Tests that work on the Python, C and PyArrow engines but do not have a specific classification into the other test modules.
""" import codecs @@ -21,7 +21,8 @@ from pandas import DataFrame import pandas._testing as tm -pytestmark = pytest.mark.usefixtures("pyarrow_skip") +xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail") +skip_pyarrow = pytest.mark.usefixtures("pyarrow_skip") def test_empty_decimal_marker(all_parsers): @@ -33,10 +34,17 @@ def test_empty_decimal_marker(all_parsers): msg = "Only length-1 decimal markers supported" parser = all_parsers + if parser.engine == "pyarrow": + msg = ( + "only single character unicode strings can be " + "converted to Py_UCS4, got length 0" + ) + with pytest.raises(ValueError, match=msg): parser.read_csv(StringIO(data), decimal="") +@skip_pyarrow def test_bad_stream_exception(all_parsers, csv_dir_path): # see gh-13652 # @@ -57,6 +65,7 @@ def test_bad_stream_exception(all_parsers, csv_dir_path): parser.read_csv(stream) +@skip_pyarrow def test_malformed(all_parsers): # see gh-6607 parser = all_parsers @@ -71,6 +80,7 @@ def test_malformed(all_parsers): parser.read_csv(StringIO(data), header=1, comment="#") +@skip_pyarrow @pytest.mark.parametrize("nrows", [5, 3, None]) def test_malformed_chunks(all_parsers, nrows): data = """ignore @@ -90,6 +100,7 @@ def test_malformed_chunks(all_parsers, nrows): reader.read(nrows) +@skip_pyarrow def test_catch_too_many_names(all_parsers): # see gh-5156 data = """\ @@ -109,6 +120,7 @@ def test_catch_too_many_names(all_parsers): parser.read_csv(StringIO(data), header=0, names=["a", "b", "c", "d"]) +@skip_pyarrow @pytest.mark.parametrize("nrows", [0, 1, 2, 3, 4, 5]) def test_raise_on_no_columns(all_parsers, nrows): parser = all_parsers @@ -147,6 +159,10 @@ def test_error_bad_lines(all_parsers): data = "a\n1\n1,2,3\n4\n5,6,7" msg = "Expected 1 fields in line 3, saw 3" + + if parser.engine == "pyarrow": + msg = "CSV parse error: Expected 1 columns, got 3: 1,2,3" + with pytest.raises(ParserError, match=msg): parser.read_csv(StringIO(data), on_bad_lines="error") @@ -156,9 +172,13 @@ def 
test_warn_bad_lines(all_parsers): parser = all_parsers data = "a\n1\n1,2,3\n4\n5,6,7" expected = DataFrame({"a": [1, 4]}) + match_msg = "Skipping line" + + if parser.engine == "pyarrow": + match_msg = "Expected 1 columns, but found 3: 1,2,3" with tm.assert_produces_warning( - ParserWarning, match="Skipping line", check_stacklevel=False + ParserWarning, match=match_msg, check_stacklevel=False ): result = parser.read_csv(StringIO(data), on_bad_lines="warn") tm.assert_frame_equal(result, expected) @@ -174,10 +194,14 @@ def test_read_csv_wrong_num_columns(all_parsers): parser = all_parsers msg = "Expected 6 fields in line 3, saw 7" + if parser.engine == "pyarrow": + msg = "Expected 6 columns, got 7: 6,7,8,9,10,11,12" + with pytest.raises(ParserError, match=msg): parser.read_csv(StringIO(data)) +@skip_pyarrow def test_null_byte_char(request, all_parsers): # see gh-2741 data = "\x00,foo" @@ -200,6 +224,7 @@ def test_null_byte_char(request, all_parsers): parser.read_csv(StringIO(data), names=names) +@skip_pyarrow @pytest.mark.filterwarnings("always::ResourceWarning") def test_open_file(request, all_parsers): # GH 39024 @@ -238,6 +263,8 @@ def test_bad_header_uniform_error(all_parsers): "Could not construct index. Requested to use 1 " "number of columns, but 3 left to parse."
) + elif parser.engine == "pyarrow": + msg = "CSV parse error: Expected 1 columns, got 4: col1,col2,col3,col4" with pytest.raises(ParserError, match=msg): parser.read_csv(StringIO(data), index_col=0, on_bad_lines="error") @@ -253,9 +280,13 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers): a,b """ expected = DataFrame({"1": "a", "2": ["b"] * 2}) + match_msg = "Skipping line" + + if parser.engine == "pyarrow": + match_msg = "Expected 2 columns, but found 3: a,b,c" with tm.assert_produces_warning( - ParserWarning, match="Skipping line", check_stacklevel=False + ParserWarning, match=match_msg, check_stacklevel=False ): result = parser.read_csv(StringIO(data), on_bad_lines="warn") tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/io/parser/test_unsupported.py b/pandas/tests/io/parser/test_unsupported.py index ac171568187cd..b489c09e917af 100644 --- a/pandas/tests/io/parser/test_unsupported.py +++ b/pandas/tests/io/parser/test_unsupported.py @@ -151,13 +151,17 @@ def test_pyarrow_engine(self): with pytest.raises(ValueError, match=msg): read_csv(StringIO(data), engine="pyarrow", **kwargs) - def test_on_bad_lines_callable_python_only(self, all_parsers): + def test_on_bad_lines_callable_python_or_pyarrow(self, all_parsers): # GH 5686 + # GH 54643 sio = StringIO("a,b\n1,2") bad_lines_func = lambda x: x parser = all_parsers - if all_parsers.engine != "python": - msg = "on_bad_line can only be a callable function if engine='python'" + if all_parsers.engine not in ["python", "pyarrow"]: + msg = ( + "on_bad_line can only be a callable " + "function if engine='python' or 'pyarrow'" + ) with pytest.raises(ValueError, match=msg): parser.read_csv(sio, on_bad_lines=bad_lines_func) else: