Skip to content

Commit 05424c2

Browse files
committed
TYP: Use Self for type checking (pandas/core/internals/)
1 parent d082266 commit 05424c2

File tree

4 files changed

+124
-83
lines changed

4 files changed

+124
-83
lines changed

.startup.ipy

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
>>> from itertools import product
2+
>>> import numpy as np
3+
>>> import pandas as pd
4+
>>> from pandas.core.reshape.concat import _Concatenator
5+
>>>
6+
def manual_concat(df_list: list[pd.DataFrame]) -> pd.DataFrame:
    """Row-wise concatenation of *df_list* done by hand with NumPy stacking.

    Reference implementation used to sanity-check ``pd.concat``: columns are
    the first-seen-order union of all frames' columns, rows are stacked in
    input order, and the result takes the dtype of the first aligned frame's
    first column.
    """
    # Union of column labels, preserving first-seen order (dict keys are ordered).
    all_cols = list(dict.fromkeys(col for df in df_list for col in df.columns))
    # Row labels stacked in input order.
    stacked_index = np.hstack([df.index.values for df in df_list])
    # Align every frame to the full column set before stacking the values.
    aligned = [df.reindex(columns=all_cols) for df in df_list]
    stacked_values = np.vstack([df.values for df in aligned])
    return pd.DataFrame(
        stacked_values,
        index=stacked_index,
        columns=all_cols,
        dtype=aligned[0].dtypes[0],
    )
13+
>>>
14+
def compare_frames(df_list: list[pd.DataFrame]) -> None:
    """Raise ValueError unless ``pd.concat`` and ``manual_concat`` agree on *df_list*."""
    expected = pd.concat(df_list)
    actual = manual_concat(df_list)
    if not expected.equals(actual):
        raise ValueError("different concatenations!")
>>>
20+
def make_dataframes(num_dfs, num_idx, num_cols, dtype=np.int32, drop_column=False) -> list[pd.DataFrame]:
    """Build *num_dfs* copies of one random frame for concat benchmarking.

    Each frame has ``num_idx`` rows and ``num_cols`` randomly-ordered columns.
    With ``drop_column=True`` the i-th copy loses its i-th column, so the
    copies no longer share an identical column set.
    """
    data = np.random.randint(-100, 100, size=[num_idx, num_cols])
    row_labels = [f"i{i}" for i in range(num_idx)]
    # Shuffle the column labels so column order differs from label order.
    col_labels = np.random.choice([f"c{i}" for i in range(num_cols)], num_cols, replace=False)
    base = pd.DataFrame(data, index=row_labels, columns=col_labels, dtype=dtype)

    frames: list[pd.DataFrame] = []
    for i in range(num_dfs):
        frame = base.copy()
        if drop_column:
            # Drop a different column from each copy.
            frame = frame.drop(frame.columns[i], axis=1)
        frames.append(frame)
    return frames
>>>
35+
>>> test_data = [ # num_idx, num_cols, num_dfs
36+
... [100, 1_000, 3],
37+
... ]
38+
>>> for i, (num_idx, num_cols, num_dfs) in enumerate(test_data):
39+
... print(f"\n{i}: {num_dfs=}, {num_idx=}, {num_cols=}")
40+
... df_list = make_dataframes(num_dfs, num_idx, num_cols, drop_column=False)
41+
... df_list_dropped = make_dataframes(num_dfs, num_idx, num_cols, drop_column=True)
42+
... print("manual:")
43+
... %timeit manual_concat(df_list)
44+
... compare_frames(df_list)
45+
... for use_dropped in [False, True]:
46+
... print(f"pd.concat: {use_dropped=}")
47+
... this_df_list = df_list if not use_dropped else df_list_dropped
48+
... %timeit pd.concat(this_df_list)

pandas/core/internals/array_manager.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Callable,
99
Hashable,
1010
Literal,
11-
TypeVar,
1211
)
1312

1413
import numpy as np
@@ -23,6 +22,7 @@
2322
AxisInt,
2423
DtypeObj,
2524
QuantileInterpolation,
25+
Self,
2626
npt,
2727
)
2828
from pandas.util._validators import validate_bool_kwarg
@@ -93,8 +93,6 @@
9393
to_native_types,
9494
)
9595

96-
T = TypeVar("T", bound="BaseArrayManager")
97-
9896

9997
class BaseArrayManager(DataManager):
10098
"""
@@ -129,7 +127,7 @@ def __init__(
129127
) -> None:
130128
raise NotImplementedError
131129

132-
def make_empty(self: T, axes=None) -> T:
130+
def make_empty(self, axes=None) -> Self:
133131
"""Return an empty ArrayManager with the items axis of len 0 (no columns)"""
134132
if axes is None:
135133
axes = [self.axes[1:], Index([])]
@@ -193,11 +191,11 @@ def __repr__(self) -> str:
193191
return output
194192

195193
def apply(
196-
self: T,
194+
self,
197195
f,
198196
align_keys: list[str] | None = None,
199197
**kwargs,
200-
) -> T:
198+
) -> Self:
201199
"""
202200
Iterate over the arrays, collect and create a new ArrayManager.
203201
@@ -255,8 +253,8 @@ def apply(
255253
return type(self)(result_arrays, new_axes) # type: ignore[arg-type]
256254

257255
def apply_with_block(
258-
self: T, f, align_keys=None, swap_axis: bool = True, **kwargs
259-
) -> T:
256+
self, f, align_keys=None, swap_axis: bool = True, **kwargs
257+
) -> Self:
260258
# switch axis to follow BlockManager logic
261259
if swap_axis and "axis" in kwargs and self.ndim == 2:
262260
kwargs["axis"] = 1 if kwargs["axis"] == 0 else 0
@@ -309,7 +307,7 @@ def apply_with_block(
309307

310308
return type(self)(result_arrays, self._axes)
311309

312-
def where(self: T, other, cond, align: bool) -> T:
310+
def where(self, other, cond, align: bool) -> Self:
313311
if align:
314312
align_keys = ["other", "cond"]
315313
else:
@@ -323,10 +321,10 @@ def where(self: T, other, cond, align: bool) -> T:
323321
cond=cond,
324322
)
325323

326-
def setitem(self: T, indexer, value) -> T:
324+
def setitem(self, indexer, value) -> Self:
327325
return self.apply_with_block("setitem", indexer=indexer, value=value)
328326

329-
def putmask(self: T, mask, new, align: bool = True) -> T:
327+
def putmask(self, mask, new, align: bool = True) -> Self:
330328
if align:
331329
align_keys = ["new", "mask"]
332330
else:
@@ -340,14 +338,14 @@ def putmask(self: T, mask, new, align: bool = True) -> T:
340338
new=new,
341339
)
342340

343-
def diff(self: T, n: int, axis: AxisInt) -> T:
341+
def diff(self, n: int, axis: AxisInt) -> Self:
344342
assert self.ndim == 2 and axis == 0 # caller ensures
345343
return self.apply(algos.diff, n=n, axis=axis)
346344

347-
def interpolate(self: T, **kwargs) -> T:
345+
def interpolate(self, **kwargs) -> Self:
348346
return self.apply_with_block("interpolate", swap_axis=False, **kwargs)
349347

350-
def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
348+
def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
351349
if fill_value is lib.no_default:
352350
fill_value = None
353351

@@ -359,7 +357,7 @@ def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
359357
"shift", periods=periods, axis=axis, fill_value=fill_value
360358
)
361359

362-
def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
360+
def fillna(self, value, limit, inplace: bool, downcast) -> Self:
363361
if limit is not None:
364362
# Do this validation even if we go through one of the no-op paths
365363
limit = libalgos.validate_limit(None, limit=limit)
@@ -368,13 +366,13 @@ def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
368366
"fillna", value=value, limit=limit, inplace=inplace, downcast=downcast
369367
)
370368

371-
def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") -> T:
369+
def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
372370
if copy is None:
373371
copy = True
374372

375373
return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors)
376374

377-
def convert(self: T, copy: bool | None) -> T:
375+
def convert(self, copy: bool | None) -> Self:
378376
if copy is None:
379377
copy = True
380378

@@ -397,10 +395,10 @@ def _convert(arr):
397395

398396
return self.apply(_convert)
399397

400-
def replace_regex(self: T, **kwargs) -> T:
398+
def replace_regex(self, **kwargs) -> Self:
401399
return self.apply_with_block("_replace_regex", **kwargs)
402400

403-
def replace(self: T, to_replace, value, inplace: bool) -> T:
401+
def replace(self, to_replace, value, inplace: bool) -> Self:
404402
inplace = validate_bool_kwarg(inplace, "inplace")
405403
assert np.ndim(value) == 0, value
406404
# TODO "replace" is right now implemented on the blocks, we should move
@@ -410,12 +408,12 @@ def replace(self: T, to_replace, value, inplace: bool) -> T:
410408
)
411409

412410
def replace_list(
413-
self: T,
411+
self,
414412
src_list: list[Any],
415413
dest_list: list[Any],
416414
inplace: bool = False,
417415
regex: bool = False,
418-
) -> T:
416+
) -> Self:
419417
"""do a list replace"""
420418
inplace = validate_bool_kwarg(inplace, "inplace")
421419

@@ -427,7 +425,7 @@ def replace_list(
427425
regex=regex,
428426
)
429427

430-
def to_native_types(self: T, **kwargs) -> T:
428+
def to_native_types(self, **kwargs) -> Self:
431429
return self.apply(to_native_types, **kwargs)
432430

433431
@property
@@ -453,7 +451,7 @@ def is_view(self) -> bool:
453451
def is_single_block(self) -> bool:
454452
return len(self.arrays) == 1
455453

456-
def _get_data_subset(self: T, predicate: Callable) -> T:
454+
def _get_data_subset(self, predicate: Callable) -> Self:
457455
indices = [i for i, arr in enumerate(self.arrays) if predicate(arr)]
458456
arrays = [self.arrays[i] for i in indices]
459457
# TODO copy?
@@ -464,7 +462,7 @@ def _get_data_subset(self: T, predicate: Callable) -> T:
464462
new_axes = [self._axes[0], new_cols]
465463
return type(self)(arrays, new_axes, verify_integrity=False)
466464

467-
def get_bool_data(self: T, copy: bool = False) -> T:
465+
def get_bool_data(self, copy: bool = False) -> Self:
468466
"""
469467
Select columns that are bool-dtype and object-dtype columns that are all-bool.
470468
@@ -475,7 +473,7 @@ def get_bool_data(self: T, copy: bool = False) -> T:
475473
"""
476474
return self._get_data_subset(lambda x: x.dtype == np.dtype(bool))
477475

478-
def get_numeric_data(self: T, copy: bool = False) -> T:
476+
def get_numeric_data(self, copy: bool = False) -> Self:
479477
"""
480478
Select columns that have a numeric dtype.
481479
@@ -489,7 +487,7 @@ def get_numeric_data(self: T, copy: bool = False) -> T:
489487
or getattr(arr.dtype, "_is_numeric", False)
490488
)
491489

492-
def copy(self: T, deep: bool | Literal["all"] | None = True) -> T:
490+
def copy(self, deep: bool | Literal["all"] | None = True) -> Self:
493491
"""
494492
Make deep or shallow copy of ArrayManager
495493
@@ -526,7 +524,7 @@ def copy_func(ax):
526524
return type(self)(new_arrays, new_axes, verify_integrity=False)
527525

528526
def reindex_indexer(
529-
self: T,
527+
self,
530528
new_axis,
531529
indexer,
532530
axis: AxisInt,
@@ -537,7 +535,7 @@ def reindex_indexer(
537535
only_slice: bool = False,
538536
# ArrayManager specific keywords
539537
use_na_proxy: bool = False,
540-
) -> T:
538+
) -> Self:
541539
axis = self._normalize_axis(axis)
542540
return self._reindex_indexer(
543541
new_axis,
@@ -550,15 +548,15 @@ def reindex_indexer(
550548
)
551549

552550
def _reindex_indexer(
553-
self: T,
551+
self,
554552
new_axis,
555553
indexer: npt.NDArray[np.intp] | None,
556554
axis: AxisInt,
557555
fill_value=None,
558556
allow_dups: bool = False,
559557
copy: bool | None = True,
560558
use_na_proxy: bool = False,
561-
) -> T:
559+
) -> Self:
562560
"""
563561
Parameters
564562
----------
@@ -629,12 +627,12 @@ def _reindex_indexer(
629627
return type(self)(new_arrays, new_axes, verify_integrity=False)
630628

631629
def take(
632-
self: T,
630+
self,
633631
indexer: npt.NDArray[np.intp],
634632
axis: AxisInt = 1,
635633
verify: bool = True,
636634
convert_indices: bool = True,
637-
) -> T:
635+
) -> Self:
638636
"""
639637
Take items along any axis.
640638
"""
@@ -923,7 +921,7 @@ def idelete(self, indexer) -> ArrayManager:
923921
# --------------------------------------------------------------------
924922
# Array-wise Operation
925923

926-
def grouped_reduce(self: T, func: Callable) -> T:
924+
def grouped_reduce(self, func: Callable) -> Self:
927925
"""
928926
Apply grouped reduction function columnwise, returning a new ArrayManager.
929927
@@ -962,7 +960,7 @@ def grouped_reduce(self: T, func: Callable) -> T:
962960
# expected "List[Union[ndarray, ExtensionArray]]"
963961
return type(self)(result_arrays, [index, columns]) # type: ignore[arg-type]
964962

965-
def reduce(self: T, func: Callable) -> T:
963+
def reduce(self, func: Callable) -> Self:
966964
"""
967965
Apply reduction function column-wise, returning a single-row ArrayManager.
968966

0 commit comments

Comments
 (0)