Skip to content

Commit 9105941

Browse files
authored
feat: introduce ElementOpsMixin (#1424)
* refactor: simplify add * refactor: simplify mul * chore: remove S2_CO * refactor: simplify typevar * typo * test_dist * fix(comment): #1424 (comment) * attempt to simplify the script
1 parent 1341968 commit 9105941

File tree

6 files changed

+223
-320
lines changed

6 files changed

+223
-320
lines changed

pandas-stubs/_typing.pyi

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -905,15 +905,11 @@ SeriesDType: TypeAlias = (
905905
| datetime.timedelta # includes pd.Timedelta
906906
)
907907
S1 = TypeVar("S1", bound=SeriesDType, default=Any)
908-
S1_CT_NDT = TypeVar(
909-
"S1_CT_NDT", bound=SeriesDTypeNoDateTime, default=Any, contravariant=True
910-
)
911-
S1_CO = TypeVar("S1_CO", bound=SeriesDType, default=Any, covariant=True)
912-
S1_CT = TypeVar("S1_CT", bound=SeriesDType, default=Any, contravariant=True)
913908
# Like S1, but without `default=Any`.
914909
S2 = TypeVar("S2", bound=SeriesDType)
915910
S2_CT = TypeVar("S2_CT", bound=SeriesDType, contravariant=True)
916-
S2_CO_NSDT = TypeVar("S2_CO_NSDT", bound=SeriesDTypeNoStrDateTime, covariant=True)
911+
S2_CT_NDT = TypeVar("S2_CT_NDT", bound=SeriesDTypeNoDateTime, contravariant=True)
912+
S2_NSDT = TypeVar("S2_NSDT", bound=SeriesDTypeNoStrDateTime)
917913
S3 = TypeVar("S3", bound=SeriesDType)
918914

919915
# Constraint, instead of bound

pandas-stubs/core/base.pyi

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,34 @@ from typing import (
77
Any,
88
Generic,
99
Literal,
10+
Protocol,
1011
TypeAlias,
1112
final,
1213
overload,
14+
type_check_only,
1315
)
1416

17+
from _typeshed import _T_contra
1518
import numpy as np
16-
from pandas import (
17-
Index,
18-
Series,
19-
)
2019
from pandas.core.arraylike import OpsMixin
2120
from pandas.core.arrays import ExtensionArray
2221
from pandas.core.arrays.categorical import Categorical
2322
from pandas.core.indexes.accessors import ArrayDescriptor
23+
from pandas.core.indexes.base import Index
24+
from pandas.core.series import Series
2425
from typing_extensions import Self
2526

27+
from pandas._libs.tslibs.timedeltas import Timedelta
2628
from pandas._typing import (
2729
S1,
30+
S2,
2831
ArrayLike,
2932
AxisIndex,
3033
DropKeep,
3134
DTypeLike,
3235
GenericT,
3336
GenericT_co,
37+
Just,
3438
NDFrameT,
3539
Scalar,
3640
SequenceNotStr,
@@ -176,3 +180,108 @@ NumListLike: TypeAlias = (
176180
| Sequence[complex]
177181
| IndexOpsMixin[complex]
178182
)
183+
184+
@type_check_only
185+
class ElementOpsMixin(Generic[S2]):
186+
@overload
187+
def _proto_add(
188+
self: ElementOpsMixin[bool], other: bool | np.bool_
189+
) -> ElementOpsMixin[bool]: ...
190+
@overload
191+
def _proto_add(
192+
self: ElementOpsMixin[int], other: int | np.integer
193+
) -> ElementOpsMixin[int]: ...
194+
@overload
195+
def _proto_add(
196+
self: ElementOpsMixin[float], other: float | np.floating
197+
) -> ElementOpsMixin[float]: ...
198+
@overload
199+
def _proto_add(
200+
self: ElementOpsMixin[complex], other: complex | np.complexfloating
201+
) -> ElementOpsMixin[complex]: ...
202+
@overload
203+
def _proto_add(self: ElementOpsMixin[str], other: str) -> ElementOpsMixin[str]: ...
204+
@overload
205+
def _proto_radd(
206+
self: ElementOpsMixin[bool], other: bool | np.bool_
207+
) -> ElementOpsMixin[bool]: ...
208+
@overload
209+
def _proto_radd(
210+
self: ElementOpsMixin[int], other: int | np.integer
211+
) -> ElementOpsMixin[int]: ...
212+
@overload
213+
def _proto_radd(
214+
self: ElementOpsMixin[float], other: float | np.floating
215+
) -> ElementOpsMixin[float]: ...
216+
@overload
217+
def _proto_radd(
218+
self: ElementOpsMixin[complex], other: complex | np.complexfloating
219+
) -> ElementOpsMixin[complex]: ...
220+
@overload
221+
def _proto_radd(self: ElementOpsMixin[str], other: str) -> ElementOpsMixin[str]: ...
222+
@overload
223+
def _proto_mul(
224+
self: ElementOpsMixin[bool], other: bool | np.bool_
225+
) -> ElementOpsMixin[bool]: ...
226+
@overload
227+
def _proto_mul(
228+
self: ElementOpsMixin[int], other: int | np.integer
229+
) -> ElementOpsMixin[int]: ...
230+
@overload
231+
def _proto_mul(
232+
self: ElementOpsMixin[float], other: float | np.floating
233+
) -> ElementOpsMixin[float]: ...
234+
@overload
235+
def _proto_mul(
236+
self: ElementOpsMixin[complex], other: complex | np.complexfloating
237+
) -> ElementOpsMixin[complex]: ...
238+
@overload
239+
def _proto_mul(
240+
self: ElementOpsMixin[Timedelta],
241+
other: Just[int] | Just[float] | np.integer | np.floating,
242+
) -> ElementOpsMixin[Timedelta]: ...
243+
@overload
244+
def _proto_mul(
245+
self: ElementOpsMixin[str], other: Just[int] | np.integer
246+
) -> ElementOpsMixin[str]: ...
247+
@overload
248+
def _proto_rmul(
249+
self: ElementOpsMixin[bool], other: bool | np.bool_
250+
) -> ElementOpsMixin[bool]: ...
251+
@overload
252+
def _proto_rmul(
253+
self: ElementOpsMixin[int], other: int | np.integer
254+
) -> ElementOpsMixin[int]: ...
255+
@overload
256+
def _proto_rmul(
257+
self: ElementOpsMixin[float], other: float | np.floating
258+
) -> ElementOpsMixin[float]: ...
259+
@overload
260+
def _proto_rmul(
261+
self: ElementOpsMixin[complex], other: complex | np.complexfloating
262+
) -> ElementOpsMixin[complex]: ...
263+
@overload
264+
def _proto_rmul(
265+
self: ElementOpsMixin[Timedelta],
266+
other: Just[int] | Just[float] | np.integer | np.floating,
267+
) -> ElementOpsMixin[Timedelta]: ...
268+
@overload
269+
def _proto_rmul(
270+
self: ElementOpsMixin[str], other: Just[int] | np.integer
271+
) -> ElementOpsMixin[str]: ...
272+
273+
@type_check_only
274+
class Supports_ProtoAdd(Protocol[_T_contra, S2]):
275+
def _proto_add(self, other: _T_contra, /) -> ElementOpsMixin[S2]: ...
276+
277+
@type_check_only
278+
class Supports_ProtoRAdd(Protocol[_T_contra, S2]):
279+
def _proto_radd(self, other: _T_contra, /) -> ElementOpsMixin[S2]: ...
280+
281+
@type_check_only
282+
class Supports_ProtoMul(Protocol[_T_contra, S2]):
283+
def _proto_mul(self, other: _T_contra, /) -> ElementOpsMixin[S2]: ...
284+
285+
@type_check_only
286+
class Supports_ProtoRMul(Protocol[_T_contra, S2]):
287+
def _proto_rmul(self, other: _T_contra, /) -> ElementOpsMixin[S2]: ...

pandas-stubs/core/indexes/base.pyi

Lines changed: 36 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ from _typeshed import (
2424
SupportsMul,
2525
SupportsRAdd,
2626
SupportsRMul,
27+
_T_contra,
2728
)
2829
import numpy as np
2930
from pandas import (
@@ -39,8 +40,13 @@ from pandas import (
3940
TimedeltaIndex,
4041
)
4142
from pandas.core.base import (
43+
ElementOpsMixin,
4244
IndexOpsMixin,
4345
NumListLike,
46+
Supports_ProtoAdd,
47+
Supports_ProtoMul,
48+
Supports_ProtoRAdd,
49+
Supports_ProtoRMul,
4450
_ListLike,
4551
)
4652
from pandas.core.indexes.category import CategoricalIndex
@@ -55,12 +61,10 @@ from pandas._libs.tslibs.timedeltas import Timedelta
5561
from pandas._typing import (
5662
C2,
5763
S1,
58-
S1_CO,
59-
S1_CT,
60-
S2_CO_NSDT,
64+
S2,
6165
S2_CT,
66+
S2_NSDT,
6267
T_COMPLEX,
63-
T_INT,
6468
AnyAll,
6569
ArrayLike,
6670
AxesData,
@@ -99,8 +103,8 @@ from pandas._typing import (
99103

100104
class InvalidIndexError(Exception): ...
101105

102-
class Index(IndexOpsMixin[S1]):
103-
__hash__: ClassVar[None] # type: ignore[assignment]
106+
class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
107+
__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]
104108
# overloads with additional dtypes
105109
@overload
106110
def __new__( # pyright: ignore[reportOverlappingOverload]
@@ -506,20 +510,14 @@ class Index(IndexOpsMixin[S1]):
506510
@overload
507511
def __add__(self, other: Index[Never]) -> Index: ...
508512
@overload
509-
def __add__(self: Index[bool], other: bool | Sequence[bool]) -> Index[bool]: ...
510-
@overload
511-
def __add__(self: Index[int], other: bool | Sequence[bool]) -> Index[int]: ...
512-
@overload
513-
def __add__(self: Index[float], other: int | Sequence[int]) -> Index[float]: ...
514-
@overload
515513
def __add__(
516-
self: Index[complex], other: float | Sequence[float]
517-
) -> Index[complex]: ...
514+
self: Supports_ProtoAdd[_T_contra, S2], other: _T_contra | Sequence[_T_contra]
515+
) -> Index[S2]: ...
518516
@overload
519517
def __add__(
520-
self: Index[S1_CT],
521-
other: SupportsRAdd[S1_CT, S1_CO] | Sequence[SupportsRAdd[S1_CT, S1_CO]],
522-
) -> Index[S1_CO]: ...
518+
self: Index[S2_CT],
519+
other: SupportsRAdd[S2_CT, S2] | Sequence[SupportsRAdd[S2_CT, S2]],
520+
) -> Index[S2]: ...
523521
@overload
524522
def __add__(
525523
self: Index[T_COMPLEX], other: np_ndarray_bool | Index[bool]
@@ -553,27 +551,22 @@ class Index(IndexOpsMixin[S1]):
553551
) -> Never: ...
554552
@overload
555553
def __add__(
556-
self: Index[_str], other: _str | Sequence[_str] | np_ndarray_str | Index[_str]
554+
self: Index[_str], other: np_ndarray_str | Index[_str]
557555
) -> Index[_str]: ...
558556
@overload
559557
def __radd__(self: Index[Never], other: _str) -> Never: ...
560558
@overload
561559
def __radd__(self: Index[Never], other: complex | _ListLike | Index) -> Index: ...
562560
@overload
563-
def __radd__(self: Index[bool], other: bool | Sequence[bool]) -> Index[bool]: ...
564-
@overload
565-
def __radd__(self: Index[int], other: bool | Sequence[bool]) -> Index[int]: ...
566-
@overload
567-
def __radd__(self: Index[float], other: int | Sequence[int]) -> Index[float]: ...
568-
@overload
569561
def __radd__(
570-
self: Index[complex], other: float | Sequence[float]
571-
) -> Index[complex]: ...
562+
self: Supports_ProtoRAdd[_T_contra, S2],
563+
other: _T_contra | Sequence[_T_contra],
564+
) -> Index[S2]: ...
572565
@overload
573566
def __radd__(
574-
self: Index[S1_CT],
575-
other: SupportsAdd[S1_CT, S1_CO] | Sequence[SupportsAdd[S1_CT, S1_CO]],
576-
) -> Index[S1_CO]: ...
567+
self: Index[S2_CT],
568+
other: SupportsAdd[S2_CT, S2] | Sequence[SupportsAdd[S2_CT, S2]],
569+
) -> Index[S2]: ...
577570
@overload
578571
def __radd__(
579572
self: Index[T_COMPLEX], other: np_ndarray_bool | Index[bool]
@@ -607,7 +600,7 @@ class Index(IndexOpsMixin[S1]):
607600
) -> Never: ...
608601
@overload
609602
def __radd__(
610-
self: Index[_str], other: _str | Sequence[_str] | np_ndarray_str | Index[_str]
603+
self: Index[_str], other: np_ndarray_str | Index[_str]
611604
) -> Index[_str]: ...
612605
@overload
613606
def __sub__(self: Index[Never], other: DatetimeIndex) -> Never: ...
@@ -773,16 +766,7 @@ class Index(IndexOpsMixin[S1]):
773766
@overload
774767
def __mul__(
775768
self: Index[Timedelta],
776-
other: (
777-
Just[int]
778-
| Just[float]
779-
| Sequence[Just[int]]
780-
| Sequence[Just[float]]
781-
| np_ndarray_anyint
782-
| np_ndarray_float
783-
| Index[int]
784-
| Index[float]
785-
),
769+
other: np_ndarray_anyint | np_ndarray_float | Index[int] | Index[float],
786770
) -> Index[Timedelta]: ...
787771
@overload
788772
def __mul__(
@@ -797,24 +781,17 @@ class Index(IndexOpsMixin[S1]):
797781
) -> Never: ...
798782
@overload
799783
def __mul__(
800-
self: Index[_str],
801-
other: Just[int] | Sequence[Just[int]] | np_ndarray_anyint | Index[int],
784+
self: Index[_str], other: np_ndarray_anyint | Index[int]
802785
) -> Index[_str]: ...
803786
@overload
804-
def __mul__(self: Index[T_INT], other: bool | Sequence[bool]) -> Index[T_INT]: ...
805-
@overload
806-
def __mul__(self: Index[float], other: int | Sequence[int]) -> Index[float]: ...
807-
@overload
808787
def __mul__(
809-
self: Index[complex], other: float | Sequence[float]
810-
) -> Index[complex]: ...
788+
self: Supports_ProtoMul[_T_contra, S2], other: _T_contra | Sequence[_T_contra]
789+
) -> Index[S2]: ...
811790
@overload
812791
def __mul__(
813792
self: Index[S2_CT],
814-
other: (
815-
SupportsRMul[S2_CT, S2_CO_NSDT] | Sequence[SupportsRMul[S2_CT, S2_CO_NSDT]]
816-
),
817-
) -> Index[S2_CO_NSDT]: ...
793+
other: SupportsRMul[S2_CT, S2_NSDT] | Sequence[SupportsRMul[S2_CT, S2_NSDT]],
794+
) -> Index[S2_NSDT]: ...
818795
@overload
819796
def __mul__(
820797
self: Index[T_COMPLEX], other: np_ndarray_bool | Index[bool]
@@ -865,16 +842,7 @@ class Index(IndexOpsMixin[S1]):
865842
@overload
866843
def __rmul__(
867844
self: Index[Timedelta],
868-
other: (
869-
Just[int]
870-
| Just[float]
871-
| Sequence[Just[int]]
872-
| Sequence[Just[float]]
873-
| np_ndarray_anyint
874-
| np_ndarray_float
875-
| Index[int]
876-
| Index[float]
877-
),
845+
other: np_ndarray_anyint | np_ndarray_float | Index[int] | Index[float],
878846
) -> Index[Timedelta]: ...
879847
@overload
880848
def __rmul__(
@@ -889,24 +857,18 @@ class Index(IndexOpsMixin[S1]):
889857
) -> Never: ...
890858
@overload
891859
def __rmul__(
892-
self: Index[_str],
893-
other: Just[int] | Sequence[Just[int]] | np_ndarray_anyint | Index[int],
860+
self: Index[_str], other: np_ndarray_anyint | Index[int]
894861
) -> Index[_str]: ...
895862
@overload
896-
def __rmul__(self: Index[T_INT], other: bool | Sequence[bool]) -> Index[T_INT]: ...
897-
@overload
898-
def __rmul__(self: Index[float], other: int | Sequence[int]) -> Index[float]: ...
899-
@overload
900863
def __rmul__(
901-
self: Index[complex], other: float | Sequence[float]
902-
) -> Index[complex]: ...
864+
self: Supports_ProtoRMul[_T_contra, S2],
865+
other: _T_contra | Sequence[_T_contra],
866+
) -> Index[S2]: ...
903867
@overload
904868
def __rmul__(
905869
self: Index[S2_CT],
906-
other: (
907-
SupportsMul[S2_CT, S2_CO_NSDT] | Sequence[SupportsMul[S2_CT, S2_CO_NSDT]]
908-
),
909-
) -> Index[S2_CO_NSDT]: ...
870+
other: SupportsMul[S2_CT, S2_NSDT] | Sequence[SupportsMul[S2_CT, S2_NSDT]],
871+
) -> Index[S2_NSDT]: ...
910872
@overload
911873
def __rmul__(
912874
self: Index[T_COMPLEX], other: np_ndarray_bool | Index[bool]

0 commit comments

Comments
 (0)