diff --git a/pandas/_testing/_warnings.py b/pandas/_testing/_warnings.py index d9516077788c8..cd2e2b4141ffd 100644 --- a/pandas/_testing/_warnings.py +++ b/pandas/_testing/_warnings.py @@ -11,6 +11,7 @@ from typing import ( TYPE_CHECKING, Literal, + Union, cast, ) import warnings @@ -32,7 +33,8 @@ def assert_produces_warning( ] = "always", check_stacklevel: bool = True, raise_on_extra_warnings: bool = True, - match: str | None = None, + match: str | tuple[str | None, ...] | None = None, + must_find_all_warnings: bool = True, ) -> Generator[list[warnings.WarningMessage], None, None]: """ Context manager for running code expected to either raise a specific warning, @@ -68,8 +70,15 @@ class for all warnings. To raise multiple types of exceptions, raise_on_extra_warnings : bool, default True Whether extra warnings not of the type `expected_warning` should cause the test to fail. - match : str, optional - Match warning message. + match : {str, tuple[str, ...]}, optional + Match warning message. If it's a tuple, it has to be the size of + `expected_warning`. If additionally `must_find_all_warnings` is + True, each expected warning's message gets matched with a respective + match. Otherwise, multiple values get treated as an alternative. + must_find_all_warnings : bool, default True + If True and `expected_warning` is a tuple, each expected warning + type must get encountered. Otherwise, even one expected warning + results in success. Examples -------- @@ -97,13 +106,35 @@ class for all warnings. To raise multiple types of exceptions, yield w finally: if expected_warning: - expected_warning = cast(type[Warning], expected_warning) - _assert_caught_expected_warning( - caught_warnings=w, - expected_warning=expected_warning, - match=match, - check_stacklevel=check_stacklevel, - ) + if isinstance(expected_warning, tuple) and must_find_all_warnings: + match = ( + match + if isinstance(match, tuple) + else (match,) * len(expected_warning) + ) + for warning_type, warning_match in zip(expected_warning, match): + _assert_caught_expected_warnings( + caught_warnings=w, + expected_warning=warning_type, + match=warning_match, + check_stacklevel=check_stacklevel, + ) + else: + expected_warning = cast( + Union[type[Warning], tuple[type[Warning], ...]], + expected_warning, + ) + match = ( + "|".join(m for m in match if m) + if isinstance(match, tuple) + else match + ) + _assert_caught_expected_warnings( + caught_warnings=w, + expected_warning=expected_warning, + match=match, + check_stacklevel=check_stacklevel, + ) if raise_on_extra_warnings: _assert_caught_no_extra_warnings( caught_warnings=w, @@ -123,10 +154,10 @@ def maybe_produces_warning( return nullcontext() -def _assert_caught_expected_warning( +def _assert_caught_expected_warnings( *, caught_warnings: Sequence[warnings.WarningMessage], - expected_warning: type[Warning], + expected_warning: type[Warning] | tuple[type[Warning], ...], match: str | None, check_stacklevel: bool, ) -> None: @@ -134,6 +165,11 @@ def _assert_caught_expected_warning( saw_warning = False matched_message = False unmatched_messages = [] + warning_name = ( + tuple(x.__name__ for x in expected_warning) + if isinstance(expected_warning, tuple) + else expected_warning.__name__ + ) for actual_warning in caught_warnings: if issubclass(actual_warning.category, expected_warning): @@ -149,13 +185,11 @@ def _assert_caught_expected_warning( unmatched_messages.append(actual_warning.message) if not saw_warning: - raise AssertionError( - f"Did not see expected warning of class {expected_warning.__name__!r}" - ) + raise AssertionError(f"Did not see expected warning of class {warning_name!r}") if match and not matched_message: raise AssertionError( - f"Did not see warning {expected_warning.__name__!r} " + f"Did not see warning {warning_name!r} " f"matching '{match}'. The emitted warning messages are " f"{unmatched_messages}" ) diff --git a/pandas/_testing/contexts.py b/pandas/_testing/contexts.py index b986e03e25815..7ebed8857f0af 100644 --- a/pandas/_testing/contexts.py +++ b/pandas/_testing/contexts.py @@ -173,7 +173,7 @@ def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=() elif PYPY and extra_warnings: return assert_produces_warning( extra_warnings, - match="|".join(extra_match), + match=extra_match, ) else: if using_copy_on_write(): @@ -190,5 +190,5 @@ def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=() warning = (warning, *extra_warnings) # type: ignore[assignment] return assert_produces_warning( warning, - match="|".join((match, *extra_match)), + match=(match, *extra_match), ) diff --git a/pandas/tests/indexing/test_chaining_and_caching.py b/pandas/tests/indexing/test_chaining_and_caching.py index 2a2772d1b3453..b28c3cba7d310 100644 --- a/pandas/tests/indexing/test_chaining_and_caching.py +++ b/pandas/tests/indexing/test_chaining_and_caching.py @@ -284,7 +284,9 @@ def test_detect_chained_assignment_changing_dtype(self): with tm.raises_chained_assignment_error(): df.loc[2]["C"] = "foo" tm.assert_frame_equal(df, df_original) - with tm.raises_chained_assignment_error(extra_warnings=(FutureWarning,)): + with tm.raises_chained_assignment_error( + extra_warnings=(FutureWarning,), extra_match=(None,) + ): df["C"][2] = "foo" tm.assert_frame_equal(df, df_original) diff --git a/pandas/tests/io/parser/common/test_read_errors.py b/pandas/tests/io/parser/common/test_read_errors.py index 0827f64dccf46..bd47e045417ce 100644 --- a/pandas/tests/io/parser/common/test_read_errors.py +++ b/pandas/tests/io/parser/common/test_read_errors.py @@ -196,7 +196,6 @@ def test_warn_bad_lines(all_parsers): expected_warning = ParserWarning if parser.engine == "pyarrow": match_msg = "Expected 1 columns, but found 3: 1,2,3" - expected_warning = (ParserWarning, DeprecationWarning) with tm.assert_produces_warning( expected_warning, match=match_msg, check_stacklevel=False @@ -315,7 +314,6 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers): expected_warning = ParserWarning if parser.engine == "pyarrow": match_msg = "Expected 2 columns, but found 3: a,b,c" - expected_warning = (ParserWarning, DeprecationWarning) with tm.assert_produces_warning( expected_warning, match=match_msg, check_stacklevel=False diff --git a/pandas/tests/io/parser/test_parse_dates.py b/pandas/tests/io/parser/test_parse_dates.py index 0bc0c3e744db7..8968948df5fa9 100644 --- a/pandas/tests/io/parser/test_parse_dates.py +++ b/pandas/tests/io/parser/test_parse_dates.py @@ -343,7 +343,7 @@ def test_multiple_date_col(all_parsers, keep_date_col, request): "names": ["X0", "X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8"], } with tm.assert_produces_warning( - (DeprecationWarning, FutureWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): result = parser.read_csv(StringIO(data), **kwds) @@ -724,7 +724,7 @@ def test_multiple_date_col_name_collision(all_parsers, data, parse_dates, msg): ) with pytest.raises(ValueError, match=msg): with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): parser.read_csv(StringIO(data), parse_dates=parse_dates) @@ -1248,14 +1248,14 @@ def test_multiple_date_col_named_index_compat(all_parsers): "Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated" ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): with_indices = parser.read_csv( StringIO(data), parse_dates={"nominal": [1, 2]}, index_col="nominal" ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): with_names = parser.read_csv( StringIO(data), @@ -1280,13 +1280,13 @@ def test_multiple_date_col_multiple_index_compat(all_parsers): "Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated" ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): result = parser.read_csv( StringIO(data), index_col=["nominal", "ID"], parse_dates={"nominal": [1, 2]} ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): expected = parser.read_csv(StringIO(data), parse_dates={"nominal": [1, 2]}) @@ -2267,7 +2267,7 @@ def test_parse_dates_dict_format_two_columns(all_parsers, key, parse_dates): "Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated" ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): result = parser.read_csv( StringIO(data), date_format={key: "%d- %m-%Y"}, parse_dates=parse_dates diff --git a/pandas/tests/io/parser/usecols/test_parse_dates.py b/pandas/tests/io/parser/usecols/test_parse_dates.py index 75efe87c408c0..ab98857e0c178 100644 --- a/pandas/tests/io/parser/usecols/test_parse_dates.py +++ b/pandas/tests/io/parser/usecols/test_parse_dates.py @@ -146,7 +146,7 @@ def test_usecols_with_parse_dates4(all_parsers): "Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated" ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): result = parser.read_csv( StringIO(data), @@ -187,7 +187,7 @@ def test_usecols_with_parse_dates_and_names(all_parsers, usecols, names, request "Support for nested sequences for 'parse_dates' in pd.read_csv is deprecated" ) with tm.assert_produces_warning( - (FutureWarning, DeprecationWarning), match=depr_msg, check_stacklevel=False + FutureWarning, match=depr_msg, check_stacklevel=False ): result = parser.read_csv( StringIO(s), names=names, parse_dates=parse_dates, usecols=usecols diff --git a/pandas/tests/util/test_assert_produces_warning.py b/pandas/tests/util/test_assert_produces_warning.py index 80e3264690f81..5b917dbbe7ba7 100644 --- a/pandas/tests/util/test_assert_produces_warning.py +++ b/pandas/tests/util/test_assert_produces_warning.py @@ -42,7 +42,6 @@ def f(): warnings.warn("f2", RuntimeWarning) -@pytest.mark.filterwarnings("ignore:f1:FutureWarning") def test_assert_produces_warning_honors_filter(): # Raise by default. msg = r"Caused unexpected warning\(s\)" @@ -180,6 +179,44 @@ def test_match_multiple_warnings(): warnings.warn("Match this too", UserWarning) +def test_must_match_multiple_warnings(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + msg = "Did not see expected warning of class 'UserWarning'" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", FutureWarning) + + +def test_must_match_multiple_warnings_messages(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + msg = r"The emitted warning messages are \[UserWarning\('Not this'\)\]" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", FutureWarning) + warnings.warn("Not this", UserWarning) + + +def test_allow_partial_match_for_multiple_warnings(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + with tm.assert_produces_warning( + category, match=r"^Match this", must_find_all_warnings=False + ): + warnings.warn("Match this", FutureWarning) + + +def test_allow_partial_match_for_multiple_warnings_messages(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + with tm.assert_produces_warning( + category, match=r"^Match this", must_find_all_warnings=False + ): + warnings.warn("Match this", FutureWarning) + warnings.warn("Not this", UserWarning) + + def test_right_category_wrong_match_raises(pair_different_warnings): target_category, other_category = pair_different_warnings with pytest.raises(AssertionError, match="Did not see warning.*matching"):