Skip to content

Commit 87f9e6a

Browse files
committed
TYP: Use Self for type checking (pandas/core/arrays/)
1 parent 2070bb8 commit 87f9e6a

File tree

10 files changed

+142
-209
lines changed

10 files changed

+142
-209
lines changed

pandas/_typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
tzinfo,
77
)
88
from os import PathLike
9+
import sys
910
from typing import (
1011
TYPE_CHECKING,
1112
Any,
@@ -83,8 +84,13 @@
8384
# Name "npt._ArrayLikeInt_co" is not defined [name-defined]
8485
NumpySorter = Optional[npt._ArrayLikeInt_co] # type: ignore[name-defined]
8586

87+
if sys.version_info >= (3, 11):
88+
from typing import Self
89+
else:
90+
from typing_extensions import Self # pyright: reportUnusedImport = false
8691
else:
8792
npt: Any = None
93+
Self: Any = None
8894

8995
HashableT = TypeVar("HashableT", bound=Hashable)
9096

pandas/core/arrays/_mixins.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
Any,
77
Literal,
88
Sequence,
9-
TypeVar,
109
cast,
1110
overload,
1211
)
@@ -23,11 +22,11 @@
2322
PositionalIndexer2D,
2423
PositionalIndexerTuple,
2524
ScalarIndexer,
25+
Self,
2626
SequenceIndexer,
2727
Shape,
2828
TakeIndexer,
2929
npt,
30-
type_t,
3130
)
3231
from pandas.errors import AbstractMethodError
3332
from pandas.util._decorators import doc
@@ -61,10 +60,6 @@
6160
from pandas.core.indexers import check_array_indexer
6261
from pandas.core.sorting import nargminmax
6362

64-
NDArrayBackedExtensionArrayT = TypeVar(
65-
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
66-
)
67-
6863
if TYPE_CHECKING:
6964
from pandas._typing import (
7065
NumpySorter,
@@ -153,13 +148,13 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
153148
return arr.view(dtype=dtype) # type: ignore[arg-type]
154149

155150
def take(
156-
self: NDArrayBackedExtensionArrayT,
151+
self,
157152
indices: TakeIndexer,
158153
*,
159154
allow_fill: bool = False,
160155
fill_value: Any = None,
161156
axis: AxisInt = 0,
162-
) -> NDArrayBackedExtensionArrayT:
157+
) -> Self:
163158
if allow_fill:
164159
fill_value = self._validate_scalar(fill_value)
165160

@@ -208,17 +203,17 @@ def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[overri
208203
raise NotImplementedError
209204
return nargminmax(self, "argmax", axis=axis)
210205

211-
def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
206+
def unique(self) -> Self:
212207
new_data = unique(self._ndarray)
213208
return self._from_backing_data(new_data)
214209

215210
@classmethod
216211
@doc(ExtensionArray._concat_same_type)
217212
def _concat_same_type(
218-
cls: type[NDArrayBackedExtensionArrayT],
219-
to_concat: Sequence[NDArrayBackedExtensionArrayT],
213+
cls,
214+
to_concat: Sequence[Self],
220215
axis: AxisInt = 0,
221-
) -> NDArrayBackedExtensionArrayT:
216+
) -> Self:
222217
dtypes = {str(x.dtype) for x in to_concat}
223218
if len(dtypes) != 1:
224219
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
@@ -258,15 +253,15 @@ def __getitem__(self, key: ScalarIndexer) -> Any:
258253

259254
@overload
260255
def __getitem__(
261-
self: NDArrayBackedExtensionArrayT,
256+
self,
262257
key: SequenceIndexer | PositionalIndexerTuple,
263-
) -> NDArrayBackedExtensionArrayT:
258+
) -> Self:
264259
...
265260

266261
def __getitem__(
267-
self: NDArrayBackedExtensionArrayT,
262+
self,
268263
key: PositionalIndexer2D,
269-
) -> NDArrayBackedExtensionArrayT | Any:
264+
) -> Self | Any:
270265
if lib.is_integer(key):
271266
# fast-path
272267
result = self._ndarray[key]
@@ -293,9 +288,7 @@ def _fill_mask_inplace(
293288
func(self._ndarray.T, limit=limit, mask=mask.T)
294289

295290
@doc(ExtensionArray.fillna)
296-
def fillna(
297-
self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
298-
) -> NDArrayBackedExtensionArrayT:
291+
def fillna(self, value=None, method=None, limit=None) -> Self:
299292
value, method = validate_fillna_kwargs(
300293
value, method, validate_scalar_dict_value=False
301294
)
@@ -359,9 +352,7 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
359352

360353
np.putmask(self._ndarray, mask, value)
361354

362-
def _where(
363-
self: NDArrayBackedExtensionArrayT, mask: npt.NDArray[np.bool_], value
364-
) -> NDArrayBackedExtensionArrayT:
355+
def _where(self: Self, mask: npt.NDArray[np.bool_], value) -> Self:
365356
"""
366357
Analogue to np.where(mask, self, value)
367358
@@ -383,9 +374,7 @@ def _where(
383374
# ------------------------------------------------------------------------
384375
# Index compat methods
385376

386-
def insert(
387-
self: NDArrayBackedExtensionArrayT, loc: int, item
388-
) -> NDArrayBackedExtensionArrayT:
377+
def insert(self, loc: int, item) -> Self:
389378
"""
390379
Make new ExtensionArray inserting new item at location. Follows
391380
Python list.append semantics for negative values.
@@ -451,10 +440,10 @@ def value_counts(self, dropna: bool = True) -> Series:
451440
return Series(result._values, index=index, name=result.name)
452441

453442
def _quantile(
454-
self: NDArrayBackedExtensionArrayT,
443+
self,
455444
qs: npt.NDArray[np.float64],
456445
interpolation: str,
457-
) -> NDArrayBackedExtensionArrayT:
446+
) -> Self:
458447
# TODO: disable for Categorical if not ordered?
459448

460449
mask = np.asarray(self.isna())
@@ -478,9 +467,7 @@ def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
478467
# numpy-like methods
479468

480469
@classmethod
481-
def _empty(
482-
cls: type_t[NDArrayBackedExtensionArrayT], shape: Shape, dtype: ExtensionDtype
483-
) -> NDArrayBackedExtensionArrayT:
470+
def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:
484471
"""
485472
Analogous to np.empty(shape, dtype=dtype)
486473

pandas/core/arrays/arrow/array.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
Callable,
1010
Literal,
1111
Sequence,
12-
TypeVar,
1312
cast,
1413
)
1514

@@ -25,6 +24,7 @@
2524
NpDtype,
2625
PositionalIndexer,
2726
Scalar,
27+
Self,
2828
SortKind,
2929
TakeIndexer,
3030
TimeAmbiguous,
@@ -138,8 +138,6 @@ def floordiv_compat(
138138

139139
from pandas import Series
140140

141-
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
142-
143141

144142
def get_unit_from_pa_dtype(pa_dtype):
145143
# https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
@@ -405,16 +403,16 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
405403
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
406404
return self.to_numpy(dtype=dtype)
407405

408-
def __invert__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
406+
def __invert__(self) -> Self:
409407
return type(self)(pc.invert(self._data))
410408

411-
def __neg__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
409+
def __neg__(self) -> Self:
412410
return type(self)(pc.negate_checked(self._data))
413411

414-
def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
412+
def __pos__(self) -> Self:
415413
return type(self)(self._data)
416414

417-
def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
415+
def __abs__(self) -> Self:
418416
return type(self)(pc.abs_checked(self._data))
419417

420418
# GH 42600: __getstate__/__setstate__ not necessary once
@@ -606,7 +604,7 @@ def argmin(self, skipna: bool = True) -> int:
606604
def argmax(self, skipna: bool = True) -> int:
607605
return self._argmin_max(skipna, "max")
608606

609-
def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
607+
def copy(self) -> Self:
610608
"""
611609
Return a shallow copy of the array.
612610
@@ -618,7 +616,7 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
618616
"""
619617
return type(self)(self._data)
620618

621-
def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
619+
def dropna(self) -> Self:
622620
"""
623621
Return ArrowExtensionArray without NA values.
624622
@@ -630,11 +628,11 @@ def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
630628

631629
@doc(ExtensionArray.fillna)
632630
def fillna(
633-
self: ArrowExtensionArrayT,
631+
self,
634632
value: object | ArrayLike | None = None,
635633
method: FillnaOptions | None = None,
636634
limit: int | None = None,
637-
) -> ArrowExtensionArrayT:
635+
) -> Self:
638636
value, method = validate_fillna_kwargs(value, method)
639637

640638
if limit is not None:
@@ -751,9 +749,7 @@ def reshape(self, *args, **kwargs):
751749
f"as backed by a 1D pyarrow.ChunkedArray."
752750
)
753751

754-
def round(
755-
self: ArrowExtensionArrayT, decimals: int = 0, *args, **kwargs
756-
) -> ArrowExtensionArrayT:
752+
def round(self, decimals: int = 0, *args, **kwargs) -> Self:
757753
"""
758754
Round each value in the array a to the given number of decimals.
759755
@@ -926,7 +922,7 @@ def to_numpy(
926922
result[self.isna()] = na_value
927923
return result
928924

929-
def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
925+
def unique(self) -> Self:
930926
"""
931927
Compute the ArrowExtensionArray of unique values.
932928
@@ -998,9 +994,7 @@ def value_counts(self, dropna: bool = True) -> Series:
998994
return Series(counts, index=index, name="count").astype("Int64")
999995

1000996
@classmethod
1001-
def _concat_same_type(
1002-
cls: type[ArrowExtensionArrayT], to_concat
1003-
) -> ArrowExtensionArrayT:
997+
def _concat_same_type(cls, to_concat) -> Self:
1004998
"""
1005999
Concatenate multiple ArrowExtensionArrays.
10061000
@@ -1321,9 +1315,7 @@ def _rank(
13211315

13221316
return type(self)(result)
13231317

1324-
def _quantile(
1325-
self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
1326-
) -> ArrowExtensionArrayT:
1318+
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
13271319
"""
13281320
Compute the quantiles of self for each quantile in `qs`.
13291321
@@ -1360,7 +1352,7 @@ def _quantile(
13601352

13611353
return type(self)(result)
13621354

1363-
def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:
1355+
def _mode(self, dropna: bool = True) -> Self:
13641356
"""
13651357
Returns the mode(s) of the ExtensionArray.
13661358

0 commit comments

Comments
 (0)