From ae647cc845d69d8d67d1a1e640bca11db4e56330 Mon Sep 17 00:00:00 2001 From: cmp0xff Date: Fri, 16 May 2025 09:55:19 +0200 Subject: [PATCH 1/3] fix: #1212 Index.name currently has no typing https://github.com/pandas-dev/pandas-stubs/issues/1212#issuecomment-2884034784 --- pandas-stubs/core/indexes/base.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index cb0377ab8..d7961af00 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -298,13 +298,13 @@ class Index(IndexOpsMixin[S1]): def to_series(self, index=..., name: Hashable = ...) -> Series: ... def to_frame(self, index: bool = ..., name=...) -> DataFrame: ... @property - def name(self): ... + def name(self) -> Hashable | None: ... @name.setter def name(self, value) -> None: ... @property - def names(self) -> list[_str]: ... + def names(self) -> list[Hashable]: ... @names.setter - def names(self, names: list[_str]): ... + def names(self, names: Sequence[Hashable]) -> None: ... def set_names(self, names, *, level=..., inplace: bool = ...): ... @overload def rename(self, name, inplace: Literal[False] = False) -> Self: ... From 682aa19db121cd24cf2adf4dff394b114b9dfbab Mon Sep 17 00:00:00 2001 From: cmp0xff Date: Fri, 16 May 2025 09:58:37 +0200 Subject: [PATCH 2/3] fix: #1212 use typing from pandas.core.reshape.pivot --- pandas-stubs/core/frame.pyi | 12 +++++++---- pandas-stubs/core/reshape/pivot.pyi | 33 ++++++++++++++++------------- tests/test_frame.py | 11 ++++++++++ 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index aeafd9de1..d83195ea5 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -32,7 +32,6 @@ from pandas import ( from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame from pandas.core.groupby.generic import DataFrameGroupBy -from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import ( Index, @@ -50,6 +49,11 @@ from pandas.core.indexing import ( _LocIndexer, ) from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg +from pandas.core.reshape.pivot import ( + PivotTableColumnsTypes, + PivotTableIndexTypes, + PivotTableValuesTypes, +) from pandas.core.series import Series from pandas.core.window import ( Expanding, @@ -1287,9 +1291,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): ) -> Self: ... def pivot_table( self, - values: _str | None | Sequence[_str] = ..., - index: _str | Grouper | Sequence | None = ..., - columns: _str | Grouper | Sequence | None = ..., + values: PivotTableValuesTypes = ..., + index: PivotTableIndexTypes = ..., + columns: PivotTableColumnsTypes = ..., aggfunc=..., fill_value: Scalar | None = ..., margins: _bool = ..., diff --git a/pandas-stubs/core/reshape/pivot.pyi b/pandas-stubs/core/reshape/pivot.pyi index 042539565..5dda29943 100644 --- a/pandas-stubs/core/reshape/pivot.pyi +++ b/pandas-stubs/core/reshape/pivot.pyi @@ -51,19 +51,22 @@ _NonIterableHashable: TypeAlias = ( | pd.Timedelta ) -_PivotTableIndexTypes: TypeAlias = Label | list[HashableT1] | Series | Grouper | None -_PivotTableColumnsTypes: TypeAlias = Label | list[HashableT2] | Series | Grouper | None +PivotTableIndexTypes: TypeAlias = Label | Sequence[HashableT1] | Series | Grouper | None +PivotTableColumnsTypes: TypeAlias = ( + Label | Sequence[HashableT2] | Series | Grouper | None +) +PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None _ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike @overload def pivot_table( data: DataFrame, - values: Label | list[HashableT3] | None = ..., - index: _PivotTableIndexTypes = ..., - columns: _PivotTableColumnsTypes = ..., + values: PivotTableValuesTypes = ..., + index: PivotTableIndexTypes = ..., + columns: PivotTableColumnsTypes = ..., aggfunc: ( - _PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] + _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., fill_value: Scalar | None = ..., margins: bool = ..., @@ -77,12 +80,12 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: Label | list[HashableT3] | None = ..., + values: PivotTableValuesTypes = ..., *, index: Grouper, - columns: _PivotTableColumnsTypes | Index | npt.NDArray = ..., + columns: PivotTableColumnsTypes | Index | npt.NDArray = ..., aggfunc: ( - _PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] + _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., fill_value: Scalar | None = ..., margins: bool = ..., @@ -94,12 +97,12 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: Label | list[HashableT3] | None = ..., - index: _PivotTableIndexTypes | Index | npt.NDArray = ..., + values: PivotTableValuesTypes = ..., + index: PivotTableIndexTypes | Index | npt.NDArray = ..., *, columns: Grouper, aggfunc: ( - _PivotAggFunc | list[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] + _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., fill_value: Scalar | None = ..., margins: bool = ..., @@ -111,9 +114,9 @@ def pivot_table( def pivot( data: DataFrame, *, - index: _NonIterableHashable | list[HashableT1] = ..., - columns: _NonIterableHashable | list[HashableT2] = ..., - values: _NonIterableHashable | list[HashableT3] = ..., + index: _NonIterableHashable | Sequence[HashableT1] = ..., + columns: _NonIterableHashable | Sequence[HashableT2] = ..., + values: _NonIterableHashable | Sequence[HashableT3] = ..., ) -> DataFrame: ... @overload def crosstab( diff --git a/tests/test_frame.py b/tests/test_frame.py index ce05b7846..c70bdce23 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1337,6 +1337,17 @@ def test_types_pivot_table() -> None: ), pd.DataFrame, ) + check( + assert_type( + df.pivot_table( + index=df["col1"].name, + columns=df["col3"].name, + values=[df["col2"].name, df["col4"].name], + ), + pd.DataFrame, + ), + pd.DataFrame, + ) def test_pivot_table_sort(): From 4aa89d833b7eaa4232c7411bc7b8b19051eb49cd Mon Sep 17 00:00:00 2001 From: cmp0xff Date: Fri, 16 May 2025 16:11:29 +0200 Subject: [PATCH 3/3] fix(comment): https://github.com/pandas-dev/pandas-stubs/pull/1216#discussion_r2093091696 --- pandas-stubs/core/frame.pyi | 12 ++++++------ pandas-stubs/core/reshape/pivot.pyi | 22 ++++++++++++---------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index d83195ea5..1ee85d1c0 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -50,9 +50,9 @@ from pandas.core.indexing import ( ) from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg from pandas.core.reshape.pivot import ( - PivotTableColumnsTypes, - PivotTableIndexTypes, - PivotTableValuesTypes, + _PivotTableColumnsTypes, + _PivotTableIndexTypes, + _PivotTableValuesTypes, ) from pandas.core.series import Series from pandas.core.window import ( @@ -1291,9 +1291,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): ) -> Self: ... def pivot_table( self, - values: PivotTableValuesTypes = ..., - index: PivotTableIndexTypes = ..., - columns: PivotTableColumnsTypes = ..., + values: _PivotTableValuesTypes = ..., + index: _PivotTableIndexTypes = ..., + columns: _PivotTableColumnsTypes = ..., aggfunc=..., fill_value: Scalar | None = ..., margins: _bool = ..., diff --git a/pandas-stubs/core/reshape/pivot.pyi b/pandas-stubs/core/reshape/pivot.pyi index 5dda29943..5554ce6fb 100644 --- a/pandas-stubs/core/reshape/pivot.pyi +++ b/pandas-stubs/core/reshape/pivot.pyi @@ -51,20 +51,22 @@ _NonIterableHashable: TypeAlias = ( | pd.Timedelta ) -PivotTableIndexTypes: TypeAlias = Label | Sequence[HashableT1] | Series | Grouper | None -PivotTableColumnsTypes: TypeAlias = ( +_PivotTableIndexTypes: TypeAlias = ( + Label | Sequence[HashableT1] | Series | Grouper | None +) +_PivotTableColumnsTypes: TypeAlias = ( Label | Sequence[HashableT2] | Series | Grouper | None ) -PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None +_PivotTableValuesTypes: TypeAlias = Label | Sequence[HashableT3] | None _ExtendedAnyArrayLike: TypeAlias = AnyArrayLike | ArrayLike @overload def pivot_table( data: DataFrame, - values: PivotTableValuesTypes = ..., - index: PivotTableIndexTypes = ..., - columns: PivotTableColumnsTypes = ..., + values: _PivotTableValuesTypes = ..., + index: _PivotTableIndexTypes = ..., + columns: _PivotTableColumnsTypes = ..., aggfunc: ( _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., @@ -80,10 +82,10 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: PivotTableValuesTypes = ..., + values: _PivotTableValuesTypes = ..., *, index: Grouper, - columns: PivotTableColumnsTypes | Index | npt.NDArray = ..., + columns: _PivotTableColumnsTypes | Index | npt.NDArray = ..., aggfunc: ( _PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc] ) = ..., @@ -97,8 +99,8 @@ def pivot_table( @overload def pivot_table( data: DataFrame, - values: PivotTableValuesTypes = ..., - index: PivotTableIndexTypes | Index | npt.NDArray = ..., + values: _PivotTableValuesTypes = ..., + index: _PivotTableIndexTypes | Index | npt.NDArray = ..., *, columns: Grouper, aggfunc: (