Skip to content

TYP: SelectionMixin #41384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from pandas.core.generic import NDFrame
from pandas.core.groupby.generic import (
DataFrameGroupBy,
GroupBy,
SeriesGroupBy,
)
from pandas.core.indexes.base import Index
Expand Down Expand Up @@ -158,6 +159,7 @@
AggObjType = Union[
"Series",
"DataFrame",
"GroupBy",
"SeriesGroupBy",
"DataFrameGroupBy",
"BaseWindow",
Expand Down
10 changes: 3 additions & 7 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AggFuncTypeDict,
AggObjType,
Axis,
FrameOrSeries,
FrameOrSeriesUnion,
)
from pandas.util._decorators import cache_readonly
Expand Down Expand Up @@ -60,10 +61,7 @@
Index,
Series,
)
from pandas.core.groupby import (
DataFrameGroupBy,
SeriesGroupBy,
)
from pandas.core.groupby import GroupBy
from pandas.core.resample import Resampler
from pandas.core.window.rolling import BaseWindow

Expand Down Expand Up @@ -1089,11 +1087,9 @@ def apply_standard(self) -> FrameOrSeriesUnion:


class GroupByApply(Apply):
obj: SeriesGroupBy | DataFrameGroupBy

def __init__(
self,
obj: SeriesGroupBy | DataFrameGroupBy,
obj: GroupBy[FrameOrSeries],
func: AggFuncType,
args,
kwargs,
Expand Down
54 changes: 18 additions & 36 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import (
TYPE_CHECKING,
Any,
Generic,
Hashable,
TypeVar,
cast,
)
Expand All @@ -19,6 +21,7 @@
ArrayLike,
Dtype,
DtypeObj,
FrameOrSeries,
IndexLabel,
Shape,
final,
Expand Down Expand Up @@ -163,13 +166,15 @@ class SpecificationError(Exception):
pass


class SelectionMixin:
class SelectionMixin(Generic[FrameOrSeries]):
"""
mixin implementing the selection & aggregation interface on a group-like
object sub-classes need to define: obj, exclusions
"""

obj: FrameOrSeries
_selection: IndexLabel | None = None
exclusions: frozenset[Hashable]
_internal_names = ["_cache", "__setstate__"]
_internal_names_set = set(_internal_names)

Expand All @@ -194,15 +199,10 @@ def _selection_list(self):

@cache_readonly
def _selected_obj(self):
# error: "SelectionMixin" has no attribute "obj"
if self._selection is None or isinstance(
self.obj, ABCSeries # type: ignore[attr-defined]
):
# error: "SelectionMixin" has no attribute "obj"
return self.obj # type: ignore[attr-defined]
if self._selection is None or isinstance(self.obj, ABCSeries):
return self.obj
else:
# error: "SelectionMixin" has no attribute "obj"
return self.obj[self._selection] # type: ignore[attr-defined]
return self.obj[self._selection]

@cache_readonly
def ndim(self) -> int:
Expand All @@ -211,49 +211,31 @@ def ndim(self) -> int:
@final
@cache_readonly
def _obj_with_exclusions(self):
# error: "SelectionMixin" has no attribute "obj"
if self._selection is not None and isinstance(
self.obj, ABCDataFrame # type: ignore[attr-defined]
):
# error: "SelectionMixin" has no attribute "obj"
return self.obj.reindex( # type: ignore[attr-defined]
columns=self._selection_list
)
if self._selection is not None and isinstance(self.obj, ABCDataFrame):
return self.obj.reindex(columns=self._selection_list)

# error: "SelectionMixin" has no attribute "exclusions"
if len(self.exclusions) > 0: # type: ignore[attr-defined]
# error: "SelectionMixin" has no attribute "obj"
# error: "SelectionMixin" has no attribute "exclusions"
return self.obj.drop(self.exclusions, axis=1) # type: ignore[attr-defined]
if len(self.exclusions) > 0:
return self.obj.drop(self.exclusions, axis=1)
else:
# error: "SelectionMixin" has no attribute "obj"
return self.obj # type: ignore[attr-defined]
return self.obj

def __getitem__(self, key):
if self._selection is not None:
raise IndexError(f"Column(s) {self._selection} already selected")

if isinstance(key, (list, tuple, ABCSeries, ABCIndex, np.ndarray)):
# error: "SelectionMixin" has no attribute "obj"
if len(
self.obj.columns.intersection(key) # type: ignore[attr-defined]
) != len(key):
# error: "SelectionMixin" has no attribute "obj"
bad_keys = list(
set(key).difference(self.obj.columns) # type: ignore[attr-defined]
)
if len(self.obj.columns.intersection(key)) != len(key):
bad_keys = list(set(key).difference(self.obj.columns))
raise KeyError(f"Columns not found: {str(bad_keys)[1:-1]}")
return self._gotitem(list(key), ndim=2)

elif not getattr(self, "as_index", False):
# error: "SelectionMixin" has no attribute "obj"
if key not in self.obj.columns: # type: ignore[attr-defined]
if key not in self.obj.columns:
raise KeyError(f"Column not found: {key}")
return self._gotitem(key, ndim=2)

else:
# error: "SelectionMixin" has no attribute "obj"
if key not in self.obj: # type: ignore[attr-defined]
if key not in self.obj:
raise KeyError(f"Column not found: {key}")
return self._gotitem(key, ndim=1)

Expand Down
9 changes: 3 additions & 6 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class providing the base-class of operations.
from typing import (
TYPE_CHECKING,
Callable,
Generic,
Hashable,
Iterable,
Iterator,
Expand Down Expand Up @@ -567,7 +566,7 @@ def group_selection_context(groupby: GroupBy) -> Iterator[GroupBy]:
]


class BaseGroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):
class BaseGroupBy(PandasObject, SelectionMixin[FrameOrSeries]):
_group_selection: IndexLabel | None = None
_apply_allowlist: frozenset[str] = frozenset()
_hidden_attrs = PandasObject._hidden_attrs | {
Expand All @@ -588,7 +587,6 @@ class BaseGroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):

axis: int
grouper: ops.BaseGrouper
obj: FrameOrSeries
group_keys: bool

@final
Expand Down Expand Up @@ -840,7 +838,6 @@ class GroupBy(BaseGroupBy[FrameOrSeries]):
more
"""

obj: FrameOrSeries
grouper: ops.BaseGrouper
as_index: bool

Expand All @@ -852,7 +849,7 @@ def __init__(
axis: int = 0,
level: IndexLabel | None = None,
grouper: ops.BaseGrouper | None = None,
exclusions: set[Hashable] | None = None,
exclusions: frozenset[Hashable] | None = None,
selection: IndexLabel | None = None,
as_index: bool = True,
sort: bool = True,
Expand Down Expand Up @@ -901,7 +898,7 @@ def __init__(
self.obj = obj
self.axis = obj._get_axis_number(axis)
self.grouper = grouper
self.exclusions = exclusions or set()
self.exclusions = frozenset(exclusions) if exclusions else frozenset()

def __getattr__(self, attr: str):
if attr in self._internal_names_set:
Expand Down
10 changes: 5 additions & 5 deletions pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def get_grouper(
mutated: bool = False,
validate: bool = True,
dropna: bool = True,
) -> tuple[ops.BaseGrouper, set[Hashable], FrameOrSeries]:
) -> tuple[ops.BaseGrouper, frozenset[Hashable], FrameOrSeries]:
"""
Create and return a BaseGrouper, which is an internal
mapping of how to create the grouper indexers.
Expand Down Expand Up @@ -728,13 +728,13 @@ def get_grouper(
if isinstance(key, Grouper):
binner, grouper, obj = key._get_grouper(obj, validate=False)
if key.key is None:
return grouper, set(), obj
return grouper, frozenset(), obj
else:
return grouper, {key.key}, obj
return grouper, frozenset({key.key}), obj

# already have a BaseGrouper, just return it
elif isinstance(key, ops.BaseGrouper):
return key, set(), obj
return key, frozenset(), obj

if not isinstance(key, list):
keys = [key]
Expand Down Expand Up @@ -861,7 +861,7 @@ def is_in_obj(gpr) -> bool:
grouper = ops.BaseGrouper(
group_axis, groupings, sort=sort, mutated=mutated, dropna=dropna
)
return grouper, exclusions, obj
return grouper, frozenset(exclusions), obj


def _is_label_like(val) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TYPE_CHECKING,
Any,
Callable,
Hashable,
)
import warnings

Expand Down Expand Up @@ -109,7 +110,7 @@ class BaseWindow(SelectionMixin):
"""Provides utilities for performing windowing operations."""

_attributes: list[str] = []
exclusions: set[str] = set()
exclusions: frozenset[Hashable] = frozenset()

def __init__(
self,
Expand Down