Skip to content

Commit 2ab12b6

Browse files
committed
TYP: Use Self for type checking (pandas/core/internals/)
1 parent 2070bb8 commit 2ab12b6

File tree

3 files changed

+76
-83
lines changed

3 files changed

+76
-83
lines changed

pandas/core/internals/array_manager.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Callable,
99
Hashable,
1010
Literal,
11-
TypeVar,
1211
)
1312

1413
import numpy as np
@@ -23,6 +22,7 @@
2322
AxisInt,
2423
DtypeObj,
2524
QuantileInterpolation,
25+
Self,
2626
npt,
2727
)
2828
from pandas.util._validators import validate_bool_kwarg
@@ -93,8 +93,6 @@
9393
to_native_types,
9494
)
9595

96-
T = TypeVar("T", bound="BaseArrayManager")
97-
9896

9997
class BaseArrayManager(DataManager):
10098
"""
@@ -129,7 +127,7 @@ def __init__(
129127
) -> None:
130128
raise NotImplementedError
131129

132-
def make_empty(self: T, axes=None) -> T:
130+
def make_empty(self, axes=None) -> Self:
133131
"""Return an empty ArrayManager with the items axis of len 0 (no columns)"""
134132
if axes is None:
135133
axes = [self.axes[1:], Index([])]
@@ -193,11 +191,11 @@ def __repr__(self) -> str:
193191
return output
194192

195193
def apply(
196-
self: T,
194+
self,
197195
f,
198196
align_keys: list[str] | None = None,
199197
**kwargs,
200-
) -> T:
198+
) -> Self:
201199
"""
202200
Iterate over the arrays, collect and create a new ArrayManager.
203201
@@ -255,8 +253,8 @@ def apply(
255253
return type(self)(result_arrays, new_axes) # type: ignore[arg-type]
256254

257255
def apply_with_block(
258-
self: T, f, align_keys=None, swap_axis: bool = True, **kwargs
259-
) -> T:
256+
self, f, align_keys=None, swap_axis: bool = True, **kwargs
257+
) -> Self:
260258
# switch axis to follow BlockManager logic
261259
if swap_axis and "axis" in kwargs and self.ndim == 2:
262260
kwargs["axis"] = 1 if kwargs["axis"] == 0 else 0
@@ -309,7 +307,7 @@ def apply_with_block(
309307

310308
return type(self)(result_arrays, self._axes)
311309

312-
def where(self: T, other, cond, align: bool) -> T:
310+
def where(self, other, cond, align: bool) -> Self:
313311
if align:
314312
align_keys = ["other", "cond"]
315313
else:
@@ -323,10 +321,10 @@ def where(self: T, other, cond, align: bool) -> T:
323321
cond=cond,
324322
)
325323

326-
def setitem(self: T, indexer, value) -> T:
324+
def setitem(self, indexer, value) -> Self:
327325
return self.apply_with_block("setitem", indexer=indexer, value=value)
328326

329-
def putmask(self: T, mask, new, align: bool = True) -> T:
327+
def putmask(self, mask, new, align: bool = True) -> Self:
330328
if align:
331329
align_keys = ["new", "mask"]
332330
else:
@@ -340,14 +338,14 @@ def putmask(self: T, mask, new, align: bool = True) -> T:
340338
new=new,
341339
)
342340

343-
def diff(self: T, n: int, axis: AxisInt) -> T:
341+
def diff(self, n: int, axis: AxisInt) -> Self:
344342
assert self.ndim == 2 and axis == 0 # caller ensures
345343
return self.apply(algos.diff, n=n, axis=axis)
346344

347-
def interpolate(self: T, **kwargs) -> T:
345+
def interpolate(self, **kwargs) -> Self:
348346
return self.apply_with_block("interpolate", swap_axis=False, **kwargs)
349347

350-
def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
348+
def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
351349
if fill_value is lib.no_default:
352350
fill_value = None
353351

@@ -359,7 +357,7 @@ def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
359357
"shift", periods=periods, axis=axis, fill_value=fill_value
360358
)
361359

362-
def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
360+
def fillna(self, value, limit, inplace: bool, downcast) -> Self:
363361
if limit is not None:
364362
# Do this validation even if we go through one of the no-op paths
365363
limit = libalgos.validate_limit(None, limit=limit)
@@ -368,13 +366,13 @@ def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
368366
"fillna", value=value, limit=limit, inplace=inplace, downcast=downcast
369367
)
370368

371-
def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") -> T:
369+
def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
372370
if copy is None:
373371
copy = True
374372

375373
return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors)
376374

377-
def convert(self: T, copy: bool | None) -> T:
375+
def convert(self, copy: bool | None) -> Self:
378376
if copy is None:
379377
copy = True
380378

@@ -397,10 +395,10 @@ def _convert(arr):
397395

398396
return self.apply(_convert)
399397

400-
def replace_regex(self: T, **kwargs) -> T:
398+
def replace_regex(self, **kwargs) -> Self:
401399
return self.apply_with_block("_replace_regex", **kwargs)
402400

403-
def replace(self: T, to_replace, value, inplace: bool) -> T:
401+
def replace(self, to_replace, value, inplace: bool) -> Self:
404402
inplace = validate_bool_kwarg(inplace, "inplace")
405403
assert np.ndim(value) == 0, value
406404
# TODO "replace" is right now implemented on the blocks, we should move
@@ -410,12 +408,12 @@ def replace(self: T, to_replace, value, inplace: bool) -> T:
410408
)
411409

412410
def replace_list(
413-
self: T,
411+
self,
414412
src_list: list[Any],
415413
dest_list: list[Any],
416414
inplace: bool = False,
417415
regex: bool = False,
418-
) -> T:
416+
) -> Self:
419417
"""do a list replace"""
420418
inplace = validate_bool_kwarg(inplace, "inplace")
421419

@@ -427,7 +425,7 @@ def replace_list(
427425
regex=regex,
428426
)
429427

430-
def to_native_types(self: T, **kwargs) -> T:
428+
def to_native_types(self, **kwargs) -> Self:
431429
return self.apply(to_native_types, **kwargs)
432430

433431
@property
@@ -453,7 +451,7 @@ def is_view(self) -> bool:
453451
def is_single_block(self) -> bool:
454452
return len(self.arrays) == 1
455453

456-
def _get_data_subset(self: T, predicate: Callable) -> T:
454+
def _get_data_subset(self, predicate: Callable) -> Self:
457455
indices = [i for i, arr in enumerate(self.arrays) if predicate(arr)]
458456
arrays = [self.arrays[i] for i in indices]
459457
# TODO copy?
@@ -464,7 +462,7 @@ def _get_data_subset(self: T, predicate: Callable) -> T:
464462
new_axes = [self._axes[0], new_cols]
465463
return type(self)(arrays, new_axes, verify_integrity=False)
466464

467-
def get_bool_data(self: T, copy: bool = False) -> T:
465+
def get_bool_data(self, copy: bool = False) -> Self:
468466
"""
469467
Select columns that are bool-dtype and object-dtype columns that are all-bool.
470468
@@ -475,7 +473,7 @@ def get_bool_data(self: T, copy: bool = False) -> T:
475473
"""
476474
return self._get_data_subset(lambda x: x.dtype == np.dtype(bool))
477475

478-
def get_numeric_data(self: T, copy: bool = False) -> T:
476+
def get_numeric_data(self, copy: bool = False) -> Self:
479477
"""
480478
Select columns that have a numeric dtype.
481479
@@ -489,7 +487,7 @@ def get_numeric_data(self: T, copy: bool = False) -> T:
489487
or getattr(arr.dtype, "_is_numeric", False)
490488
)
491489

492-
def copy(self: T, deep: bool | Literal["all"] | None = True) -> T:
490+
def copy(self, deep: bool | Literal["all"] | None = True) -> Self:
493491
"""
494492
Make deep or shallow copy of ArrayManager
495493
@@ -526,7 +524,7 @@ def copy_func(ax):
526524
return type(self)(new_arrays, new_axes, verify_integrity=False)
527525

528526
def reindex_indexer(
529-
self: T,
527+
self,
530528
new_axis,
531529
indexer,
532530
axis: AxisInt,
@@ -537,7 +535,7 @@ def reindex_indexer(
537535
only_slice: bool = False,
538536
# ArrayManager specific keywords
539537
use_na_proxy: bool = False,
540-
) -> T:
538+
) -> Self:
541539
axis = self._normalize_axis(axis)
542540
return self._reindex_indexer(
543541
new_axis,
@@ -550,15 +548,15 @@ def reindex_indexer(
550548
)
551549

552550
def _reindex_indexer(
553-
self: T,
551+
self,
554552
new_axis,
555553
indexer: npt.NDArray[np.intp] | None,
556554
axis: AxisInt,
557555
fill_value=None,
558556
allow_dups: bool = False,
559557
copy: bool | None = True,
560558
use_na_proxy: bool = False,
561-
) -> T:
559+
) -> Self:
562560
"""
563561
Parameters
564562
----------
@@ -629,12 +627,12 @@ def _reindex_indexer(
629627
return type(self)(new_arrays, new_axes, verify_integrity=False)
630628

631629
def take(
632-
self: T,
630+
self,
633631
indexer,
634632
axis: AxisInt = 1,
635633
verify: bool = True,
636634
convert_indices: bool = True,
637-
) -> T:
635+
) -> Self:
638636
"""
639637
Take items along any axis.
640638
"""
@@ -926,7 +924,7 @@ def idelete(self, indexer) -> ArrayManager:
926924
# --------------------------------------------------------------------
927925
# Array-wise Operation
928926

929-
def grouped_reduce(self: T, func: Callable) -> T:
927+
def grouped_reduce(self, func: Callable) -> Self:
930928
"""
931929
Apply grouped reduction function columnwise, returning a new ArrayManager.
932930
@@ -965,7 +963,7 @@ def grouped_reduce(self: T, func: Callable) -> T:
965963
# expected "List[Union[ndarray, ExtensionArray]]"
966964
return type(self)(result_arrays, [index, columns]) # type: ignore[arg-type]
967965

968-
def reduce(self: T, func: Callable) -> T:
966+
def reduce(self, func: Callable) -> Self:
969967
"""
970968
Apply reduction function column-wise, returning a single-row ArrayManager.
971969

pandas/core/internals/base.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from typing import (
88
Literal,
9-
TypeVar,
109
final,
1110
)
1211

@@ -16,6 +15,7 @@
1615
ArrayLike,
1716
AxisInt,
1817
DtypeObj,
18+
Self,
1919
Shape,
2020
)
2121
from pandas.errors import AbstractMethodError
@@ -31,8 +31,6 @@
3131
default_index,
3232
)
3333

34-
T = TypeVar("T", bound="DataManager")
35-
3634

3735
class DataManager(PandasObject):
3836
# TODO share more methods/attributes
@@ -73,25 +71,25 @@ def _validate_set_axis(self, axis: AxisInt, new_labels: Index) -> None:
7371
)
7472

7573
def reindex_indexer(
76-
self: T,
74+
self,
7775
new_axis,
7876
indexer,
7977
axis: AxisInt,
8078
fill_value=None,
8179
allow_dups: bool = False,
8280
copy: bool = True,
8381
only_slice: bool = False,
84-
) -> T:
82+
) -> Self:
8583
raise AbstractMethodError(self)
8684

8785
@final
8886
def reindex_axis(
89-
self: T,
87+
self,
9088
new_index: Index,
9189
axis: AxisInt,
9290
fill_value=None,
9391
only_slice: bool = False,
94-
) -> T:
92+
) -> Self:
9593
"""
9694
Conform data manager to new index.
9795
"""
@@ -106,7 +104,7 @@ def reindex_axis(
106104
only_slice=only_slice,
107105
)
108106

109-
def _equal_values(self: T, other: T) -> bool:
107+
def _equal_values(self, other: Self) -> bool:
110108
"""
111109
To be implemented by the subclasses. Only check the column values
112110
assuming shape and indexes have already been checked.
@@ -130,15 +128,15 @@ def equals(self, other: object) -> bool:
130128
return self._equal_values(other)
131129

132130
def apply(
133-
self: T,
131+
self,
134132
f,
135133
align_keys: list[str] | None = None,
136134
**kwargs,
137-
) -> T:
135+
) -> Self:
138136
raise AbstractMethodError(self)
139137

140138
@final
141-
def isna(self: T, func) -> T:
139+
def isna(self, func) -> Self:
142140
return self.apply("apply", func=func)
143141

144142
# --------------------------------------------------------------------
@@ -147,7 +145,7 @@ def isna(self: T, func) -> T:
147145
def is_consolidated(self) -> bool:
148146
return True
149147

150-
def consolidate(self: T) -> T:
148+
def consolidate(self) -> Self:
151149
return self
152150

153151
def _consolidate_inplace(self) -> None:

0 commit comments

Comments
 (0)