From ca848bb563e7ce409263709923c610ac6affbf3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Thu, 30 Jun 2022 23:25:19 -0400 Subject: [PATCH 1/4] TYP: fix a few errors found by pandas-stub --- pandas/_testing/asserters.py | 2 +- pandas/core/generic.py | 2 +- pandas/core/groupby/groupby.py | 7 ++++++- pandas/core/indexes/multi.py | 5 ++++- pandas/plotting/_matplotlib/groupby.py | 4 +++- 5 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index 90ee600c1967d..7c89946a0c12f 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -864,7 +864,7 @@ def assert_series_equal( check_dtype: bool | Literal["equiv"] = True, check_index_type="equiv", check_series_type=True, - check_less_precise=no_default, + check_less_precise: bool | int | NoDefault = no_default, check_names=True, check_exact=False, check_datetimelike_compat=False, diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f896169d0ae44..0fa6bc952e901 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -1050,7 +1050,7 @@ def _rename( return result.__finalize__(self, method="rename") @rewrite_axis_style_signature("mapper", [("copy", True), ("inplace", False)]) - def rename_axis(self, mapper=lib.no_default, **kwargs): + def rename_axis(self, mapper: IndexLabel = lib.no_default, **kwargs): """ Set the name of the axis for the index or columns. diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index c2098fbe93a56..3f8b0e2e2bf77 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4050,7 +4050,12 @@ def _reindex_output( # "ndarray[Any, dtype[floating[_64Bit]]]"; expected "Index" levels_list.append(qs) # type: ignore[arg-type] names = names + [None] - index, _ = MultiIndex.from_product(levels_list, names=names).sortlevel() + # error: Argument "names" to "from_product" of "MultiIndex" has + # incompatible type "List[Hashable]"; expected "Union[Sequence[ + # Optional[str]], Literal[_NoDefault.no_default]]" + index, _ = MultiIndex.from_product( + levels_list, names=names # type: ignore[arg-type] + ).sortlevel() if self.as_index: # Always holds for SeriesGroupBy unless GH#36507 is implemented diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 4da39318579eb..8010801dd888b 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -578,7 +578,10 @@ def from_tuples( @classmethod def from_product( - cls, iterables, sortorder=None, names=lib.no_default + cls, + iterables: Sequence[Iterable[Hashable]], + sortorder: int | None = None, + names: Sequence[str | None] | lib.NoDefault = lib.no_default, ) -> MultiIndex: """ Make a MultiIndex from the cartesian product of multiple iterables. diff --git a/pandas/plotting/_matplotlib/groupby.py b/pandas/plotting/_matplotlib/groupby.py index 1b16eefb360ae..4f1cd3f38343a 100644 --- a/pandas/plotting/_matplotlib/groupby.py +++ b/pandas/plotting/_matplotlib/groupby.py @@ -112,7 +112,9 @@ def reconstruct_data_with_by( data_list = [] for key, group in grouped: - columns = MultiIndex.from_product([[key], cols]) + # error: List item 1 has incompatible type "Union[Hashable, + # Sequence[Hashable]]"; expected "Iterable[Hashable]" + columns = MultiIndex.from_product([[key], cols]) # type: ignore[list-item] sub_group = group[cols] sub_group.columns = columns data_list.append(sub_group) From cdd9d367121821751b1a930345e62cf95cfc94b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 1 Jul 2022 15:08:43 -0400 Subject: [PATCH 2/4] from_product --- pandas/core/generic.py | 4 +++- pandas/core/groupby/groupby.py | 7 +------ pandas/core/indexes/multi.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 0fa6bc952e901..ef7e8421a0cdd 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -1050,7 +1050,9 @@ def _rename( return result.__finalize__(self, method="rename") @rewrite_axis_style_signature("mapper", [("copy", True), ("inplace", False)]) - def rename_axis(self, mapper: IndexLabel = lib.no_default, **kwargs): + def rename_axis( + self, mapper: IndexLabel | lib.NoDefault = lib.no_default, **kwargs + ): """ Set the name of the axis for the index or columns. diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 3f8b0e2e2bf77..c2098fbe93a56 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -4050,12 +4050,7 @@ def _reindex_output( # "ndarray[Any, dtype[floating[_64Bit]]]"; expected "Index" levels_list.append(qs) # type: ignore[arg-type] names = names + [None] - # error: Argument "names" to "from_product" of "MultiIndex" has - # incompatible type "List[Hashable]"; expected "Union[Sequence[ - # Optional[str]], Literal[_NoDefault.no_default]]" - index, _ = MultiIndex.from_product( - levels_list, names=names # type: ignore[arg-type] - ).sortlevel() + index, _ = MultiIndex.from_product(levels_list, names=names).sortlevel() if self.as_index: # Always holds for SeriesGroupBy unless GH#36507 is implemented diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 8010801dd888b..583612b4659b6 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -581,7 +581,7 @@ def from_product( cls, iterables: Sequence[Iterable[Hashable]], sortorder: int | None = None, - names: Sequence[str | None] | lib.NoDefault = lib.no_default, + names: Sequence[Hashable] | lib.NoDefault = lib.no_default, ) -> MultiIndex: """ Make a MultiIndex from the cartesian product of multiple iterables. From 1f7867adc288e06ccd6cf483a9e2092a24004ca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 1 Jul 2022 18:06:26 -0400 Subject: [PATCH 3/4] adjust tests --- pandas/tests/util/test_assert_attr_equal.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas/tests/util/test_assert_attr_equal.py b/pandas/tests/util/test_assert_attr_equal.py index 115ef58e085cc..bbbb0bf2172b1 100644 --- a/pandas/tests/util/test_assert_attr_equal.py +++ b/pandas/tests/util/test_assert_attr_equal.py @@ -10,7 +10,7 @@ def test_assert_attr_equal(nulls_fixture): obj = SimpleNamespace() obj.na_value = nulls_fixture - assert tm.assert_attr_equal("na_value", obj, obj) + tm.assert_attr_equal("na_value", obj, obj) def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2): @@ -21,13 +21,13 @@ def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2): obj2.na_value = nulls_fixture2 if nulls_fixture is nulls_fixture2: - assert tm.assert_attr_equal("na_value", obj, obj2) + tm.assert_attr_equal("na_value", obj, obj2) elif is_float(nulls_fixture) and is_float(nulls_fixture2): # we consider float("nan") and np.float64("nan") to be equivalent - assert tm.assert_attr_equal("na_value", obj, obj2) + tm.assert_attr_equal("na_value", obj, obj2) elif type(nulls_fixture) is type(nulls_fixture2): # e.g. Decimal("NaN") - assert tm.assert_attr_equal("na_value", obj, obj2) + tm.assert_attr_equal("na_value", obj, obj2) else: with pytest.raises(AssertionError, match='"na_value" are different'): tm.assert_attr_equal("na_value", obj, obj2) From 878bad068cba464ea5986b1b64f4af451a425be7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 1 Jul 2022 18:08:00 -0400 Subject: [PATCH 4/4] Revert "adjust tests" This reverts commit 1f7867adc288e06ccd6cf483a9e2092a24004ca2. --- pandas/tests/util/test_assert_attr_equal.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas/tests/util/test_assert_attr_equal.py b/pandas/tests/util/test_assert_attr_equal.py index bbbb0bf2172b1..115ef58e085cc 100644 --- a/pandas/tests/util/test_assert_attr_equal.py +++ b/pandas/tests/util/test_assert_attr_equal.py @@ -10,7 +10,7 @@ def test_assert_attr_equal(nulls_fixture): obj = SimpleNamespace() obj.na_value = nulls_fixture - tm.assert_attr_equal("na_value", obj, obj) + assert tm.assert_attr_equal("na_value", obj, obj) def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2): @@ -21,13 +21,13 @@ def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2): obj2.na_value = nulls_fixture2 if nulls_fixture is nulls_fixture2: - tm.assert_attr_equal("na_value", obj, obj2) + assert tm.assert_attr_equal("na_value", obj, obj2) elif is_float(nulls_fixture) and is_float(nulls_fixture2): # we consider float("nan") and np.float64("nan") to be equivalent - tm.assert_attr_equal("na_value", obj, obj2) + assert tm.assert_attr_equal("na_value", obj, obj2) elif type(nulls_fixture) is type(nulls_fixture2): # e.g. Decimal("NaN") - tm.assert_attr_equal("na_value", obj, obj2) + assert tm.assert_attr_equal("na_value", obj, obj2) else: with pytest.raises(AssertionError, match='"na_value" are different'): tm.assert_attr_equal("na_value", obj, obj2)