diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a54a5827adacb..1c3dd35ef47f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -333,3 +333,13 @@ repos: additional_dependencies: - autotyping==22.9.0 - libcst==0.4.7 + - id: check-test-naming + name: check that test names start with 'test' + entry: python -m scripts.check_test_naming + types: [python] + files: ^pandas/tests + language: python + exclude: | + (?x) + ^pandas/tests/generic/test_generic.py # GH50380 + |^pandas/tests/io/json/test_readlines.py # GH50378 diff --git a/pandas/tests/computation/test_eval.py b/pandas/tests/computation/test_eval.py index a97205f2921fe..18559b9b4f899 100644 --- a/pandas/tests/computation/test_eval.py +++ b/pandas/tests/computation/test_eval.py @@ -353,7 +353,7 @@ def test_pow(self, lhs, rhs, engine, parser): expected = _eval_single_bin(middle, "**", rhs, engine) tm.assert_almost_equal(result, expected) - def check_single_invert_op(self, lhs, engine, parser): + def test_check_single_invert_op(self, lhs, engine, parser): # simple try: elb = lhs.astype(bool) diff --git a/pandas/tests/frame/methods/test_dtypes.py b/pandas/tests/frame/methods/test_dtypes.py index 87e6ed5b1b135..f3b77c27d75bd 100644 --- a/pandas/tests/frame/methods/test_dtypes.py +++ b/pandas/tests/frame/methods/test_dtypes.py @@ -15,13 +15,6 @@ import pandas._testing as tm -def _check_cast(df, v): - """ - Check if all dtypes of df are equal to v - """ - assert all(s.dtype.name == v for _, s in df.items()) - - class TestDataFrameDataTypes: def test_empty_frame_dtypes(self): empty_df = DataFrame() diff --git a/pandas/tests/frame/methods/test_to_timestamp.py b/pandas/tests/frame/methods/test_to_timestamp.py index acbb51fe79643..d1c10ce37bf3d 100644 --- a/pandas/tests/frame/methods/test_to_timestamp.py +++ b/pandas/tests/frame/methods/test_to_timestamp.py @@ -121,7 +121,7 @@ def test_to_timestamp_columns(self): assert result1.columns.freqstr == "AS-JAN" assert result2.columns.freqstr == "AS-JAN" - def to_timestamp_invalid_axis(self): + def test_to_timestamp_invalid_axis(self): index = period_range(freq="A", start="1/1/2001", end="12/1/2009") obj = DataFrame(np.random.randn(len(index), 5), index=index) diff --git a/pandas/tests/internals/test_internals.py b/pandas/tests/internals/test_internals.py index b2c2df52e5ce0..3044aecc26b4f 100644 --- a/pandas/tests/internals/test_internals.py +++ b/pandas/tests/internals/test_internals.py @@ -1323,10 +1323,6 @@ def test_period_can_hold_element(self, element): elem = element(dti) self.check_series_setitem(elem, pi, False) - def check_setting(self, elem, index: Index, inplace: bool): - self.check_series_setitem(elem, index, inplace) - self.check_frame_setitem(elem, index, inplace) - def check_can_hold_element(self, obj, elem, inplace: bool): blk = obj._mgr.blocks[0] if inplace: @@ -1350,23 +1346,6 @@ def check_series_setitem(self, elem, index: Index, inplace: bool): else: assert ser.dtype == object - def check_frame_setitem(self, elem, index: Index, inplace: bool): - arr = index._data.copy() - df = DataFrame(arr) - - self.check_can_hold_element(df, elem, inplace) - - if is_scalar(elem): - df.iloc[0, 0] = elem - else: - df.iloc[: len(elem), 0] = elem - - if inplace: - # assertion here implies setting was done inplace - assert df._mgr.arrays[0] is arr - else: - assert df.dtypes[0] == object - class TestShouldStore: def test_should_store_categorical(self): diff --git a/pandas/tests/io/test_feather.py b/pandas/tests/io/test_feather.py index e58df00c65608..88bf04f518e12 100644 --- a/pandas/tests/io/test_feather.py +++ b/pandas/tests/io/test_feather.py @@ -113,10 +113,11 @@ def test_read_columns(self): columns = ["col1", "col3"] self.check_round_trip(df, expected=df[columns], columns=columns) - def read_columns_different_order(self): + def test_read_columns_different_order(self): # GH 33878 df = pd.DataFrame({"A": [1, 2], "B": ["x", "y"], "C": [True, False]}) - self.check_round_trip(df, columns=["B", "A"]) + expected = df[["B", "A"]] + self.check_round_trip(df, expected, columns=["B", "A"]) def test_unsupported_other(self): diff --git a/pandas/tests/reshape/concat/test_append_common.py b/pandas/tests/reshape/concat/test_append_common.py index e0275fa85d66e..938d18be8657a 100644 --- a/pandas/tests/reshape/concat/test_append_common.py +++ b/pandas/tests/reshape/concat/test_append_common.py @@ -55,21 +55,6 @@ def item(self, request): item2 = item - def _check_expected_dtype(self, obj, label): - """ - Check whether obj has expected dtype depending on label - considering not-supported dtypes - """ - if isinstance(obj, Index): - assert obj.dtype == label - elif isinstance(obj, Series): - if label.startswith("period"): - assert obj.dtype == "Period[M]" - else: - assert obj.dtype == label - else: - raise ValueError - def test_dtypes(self, item, index_or_series): # to confirm test case covers intended dtypes typ, vals = item diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py index c73737dad89aa..0dc3ef25a39a4 100644 --- a/pandas/tests/series/methods/test_explode.py +++ b/pandas/tests/series/methods/test_explode.py @@ -76,7 +76,7 @@ def test_invert_array(): @pytest.mark.parametrize( "s", [pd.Series([1, 2, 3]), pd.Series(pd.date_range("2019", periods=3, tz="UTC"))] ) -def non_object_dtype(s): +def test_non_object_dtype(s): result = s.explode() tm.assert_series_equal(result, s) diff --git a/pandas/tests/strings/test_cat.py b/pandas/tests/strings/test_cat.py index 01c5bf25e0601..ff2898107a9e4 100644 --- a/pandas/tests/strings/test_cat.py +++ b/pandas/tests/strings/test_cat.py @@ -11,7 +11,13 @@ _testing as tm, concat, ) -from pandas.tests.strings.test_strings import assert_series_or_index_equal + + +def assert_series_or_index_equal(left, right): + if isinstance(left, Series): + tm.assert_series_equal(left, right) + else: # Index + tm.assert_index_equal(left, right) @pytest.mark.parametrize("other", [None, Series, Index]) diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 4385f71dc653f..a9335e156d9db 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -26,13 +26,6 @@ def test_startswith_endswith_non_str_patterns(pattern): ser.str.endswith(pattern) -def assert_series_or_index_equal(left, right): - if isinstance(left, Series): - tm.assert_series_equal(left, right) - else: # Index - tm.assert_index_equal(left, right) - - # test integer/float dtypes (inferred by constructor) and mixed diff --git a/pandas/tests/tseries/offsets/test_dst.py b/pandas/tests/tseries/offsets/test_dst.py index 9c6d6a686e9a5..347c91a67ebb5 100644 --- a/pandas/tests/tseries/offsets/test_dst.py +++ b/pandas/tests/tseries/offsets/test_dst.py @@ -30,13 +30,18 @@ YearEnd, ) -from pandas.tests.tseries.offsets.test_offsets import get_utc_offset_hours from pandas.util.version import Version # error: Module has no attribute "__version__" pytz_version = Version(pytz.__version__) # type: ignore[attr-defined] +def get_utc_offset_hours(ts): + # take a Timestamp and compute total hours of utc offset + o = ts.utcoffset() + return (o.days * 24 * 3600 + o.seconds) / 3600.0 + + class TestDST: # one microsecond before the DST transition diff --git a/pandas/tests/tseries/offsets/test_offsets.py b/pandas/tests/tseries/offsets/test_offsets.py index 135227d66d541..933723edd6e66 100644 --- a/pandas/tests/tseries/offsets/test_offsets.py +++ b/pandas/tests/tseries/offsets/test_offsets.py @@ -900,12 +900,6 @@ def test_str_for_named_is_name(self): assert offset.freqstr == name -def get_utc_offset_hours(ts): - # take a Timestamp and compute total hours of utc offset - o = ts.utcoffset() - return (o.days * 24 * 3600 + o.seconds) / 3600.0 - - # --------------------------------------------------------------------- diff --git a/pandas/tests/util/test_assert_frame_equal.py b/pandas/tests/util/test_assert_frame_equal.py index 1fe2a7428486e..a6e29e243b0c8 100644 --- a/pandas/tests/util/test_assert_frame_equal.py +++ b/pandas/tests/util/test_assert_frame_equal.py @@ -34,24 +34,6 @@ def _assert_frame_equal_both(a, b, **kwargs): tm.assert_frame_equal(b, a, **kwargs) -def _assert_not_frame_equal(a, b, **kwargs): - """ - Check that two DataFrame are not equal. - - Parameters - ---------- - a : DataFrame - The first DataFrame to compare. - b : DataFrame - The second DataFrame to compare. - kwargs : dict - The arguments passed to `tm.assert_frame_equal`. - """ - msg = "The two DataFrames were equal when they shouldn't have been" - with pytest.raises(AssertionError, match=msg): - tm.assert_frame_equal(a, b, **kwargs) - - @pytest.mark.parametrize("check_like", [True, False]) def test_frame_equal_row_order_mismatch(check_like, obj_fixture): df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) diff --git a/scripts/check_test_naming.py b/scripts/check_test_naming.py new file mode 100644 index 0000000000000..33890feb8692d --- /dev/null +++ b/scripts/check_test_naming.py @@ -0,0 +1,152 @@ +""" +Check that test names start with `test`, and that test classes start with `Test`. + +This is meant to be run as a pre-commit hook - to run it manually, you can do: + + pre-commit run check-test-naming --all-files + +NOTE: if this finds a false positive, you can add the comment `# not a test` to the +class or function definition. Though hopefully that shouldn't be necessary. +""" +from __future__ import annotations + +import argparse +import ast +import os +from pathlib import Path +import sys +from typing import ( + Iterator, + Sequence, +) + +PRAGMA = "# not a test" + + +def _find_names(node: ast.Module) -> Iterator[str]: + for _node in ast.walk(node): + if isinstance(_node, ast.Name): + yield _node.id + elif isinstance(_node, ast.Attribute): + yield _node.attr + + +def _is_fixture(node: ast.expr) -> bool: + if isinstance(node, ast.Call): + node = node.func + return ( + isinstance(node, ast.Attribute) + and node.attr == "fixture" + and isinstance(node.value, ast.Name) + and node.value.id == "pytest" + ) + + +def _is_register_dtype(node): + return isinstance(node, ast.Name) and node.id == "register_extension_dtype" + + +def is_misnamed_test_func( + node: ast.expr | ast.stmt, names: Sequence[str], line: str +) -> bool: + return ( + isinstance(node, ast.FunctionDef) + and not node.name.startswith("test") + and names.count(node.name) == 0 + and not any(_is_fixture(decorator) for decorator in node.decorator_list) + and PRAGMA not in line + and node.name + not in ("teardown_method", "setup_method", "teardown_class", "setup_class") + ) + + +def is_misnamed_test_class( + node: ast.expr | ast.stmt, names: Sequence[str], line: str +) -> bool: + return ( + isinstance(node, ast.ClassDef) + and not node.name.startswith("Test") + and names.count(node.name) == 0 + and not any(_is_register_dtype(decorator) for decorator in node.decorator_list) + and PRAGMA not in line + ) + + +def main(content: str, file: str) -> int: + lines = content.splitlines() + tree = ast.parse(content) + names = list(_find_names(tree)) + ret = 0 + for node in tree.body: + if is_misnamed_test_func(node, names, lines[node.lineno - 1]): + print( + f"{file}:{node.lineno}:{node.col_offset} " + "found test function which does not start with 'test'" + ) + ret = 1 + elif is_misnamed_test_class(node, names, lines[node.lineno - 1]): + print( + f"{file}:{node.lineno}:{node.col_offset} " + "found test class which does not start with 'Test'" + ) + ret = 1 + if ( + isinstance(node, ast.ClassDef) + and names.count(node.name) == 0 + and not any( + _is_register_dtype(decorator) for decorator in node.decorator_list + ) + and PRAGMA not in lines[node.lineno - 1] + ): + for _node in node.body: + if is_misnamed_test_func(_node, names, lines[_node.lineno - 1]): + # It could be that this function is used somewhere by the + # parent class. For example, there might be a base class + # with + # + # class Foo: + # def foo(self): + # assert 1+1==2 + # def test_foo(self): + # self.foo() + # + # and then some subclass overwrites `foo`. So, we check that + # `self.foo` doesn't appear in any of the test classes. + # Note some false negatives might get through, but that's OK. + # This is good enough that has helped identify several examples + # of tests not being run. + assert isinstance(_node, ast.FunctionDef) # help mypy + should_continue = False + for _file in (Path("pandas") / "tests").rglob("*.py"): + with open(os.path.join(_file)) as fd: + _content = fd.read() + if f"self.{_node.name}" in _content: + should_continue = True + break + if should_continue: + continue + + print( + f"{file}:{_node.lineno}:{_node.col_offset} " + "found test function which does not start with 'test'" + ) + ret = 1 + return ret + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("paths", nargs="*") + args = parser.parse_args() + + ret = 0 + + for file in args.paths: + filename = os.path.basename(file) + if not (filename.startswith("test") and filename.endswith(".py")): + continue + with open(file, encoding="utf-8") as fd: + content = fd.read() + ret |= main(content, file) + + sys.exit(ret) diff --git a/scripts/tests/test_check_test_naming.py b/scripts/tests/test_check_test_naming.py new file mode 100644 index 0000000000000..9ddaf2fe2a97d --- /dev/null +++ b/scripts/tests/test_check_test_naming.py @@ -0,0 +1,54 @@ +import pytest + +from scripts.check_test_naming import main + + +@pytest.mark.parametrize( + "src, expected_out, expected_ret", + [ + ( + "def foo(): pass\n", + "t.py:1:0 found test function which does not start with 'test'\n", + 1, + ), + ( + "class Foo:\n def test_foo(): pass\n", + "t.py:1:0 found test class which does not start with 'Test'\n", + 1, + ), + ("def test_foo(): pass\n", "", 0), + ( + "class TestFoo:\n def foo(): pass\n", + "t.py:2:4 found test function which does not start with 'test'\n", + 1, + ), + ("class TestFoo:\n def test_foo(): pass\n", "", 0), + ( + "class Foo:\n def foo(): pass\n", + "t.py:1:0 found test class which does not start with 'Test'\n" + "t.py:2:4 found test function which does not start with 'test'\n", + 1, + ), + ( + "def foo():\n pass\ndef test_foo():\n foo()\n", + "", + 0, + ), + ( + "class Foo: # not a test\n" + " pass\n" + "def test_foo():\n" + " Class.foo()\n", + "", + 0, + ), + ("@pytest.fixture\ndef foo(): pass\n", "", 0), + ("@pytest.fixture()\ndef foo(): pass\n", "", 0), + ("@register_extension_dtype\nclass Foo: pass\n", "", 0), + ], +) +def test_main(capsys, src, expected_out, expected_ret): + ret = main(src, "t.py") + out, _ = capsys.readouterr() + assert out == expected_out + assert ret == expected_ret