Skip to content

Commit 7978238

Browse files
authored
Fix Series#get and DataFrame#get (#867)
* Remove NDFrame#get method * Add get method to Series * Add get method to DataFrame * Add test cases for {DataFrame,Series}#get(..., default=None) would return None * Remove `= ...` from {DataFrame,Series}#get for cases where default parameter is given * Use _typing.T instead of locally defined type var
1 parent 40d9636 commit 7978238

File tree

5 files changed

+72
-3
lines changed

5 files changed

+72
-3
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ from pandas._typing import (
114114
StorageOptions,
115115
StrLike,
116116
Suffixes,
117+
T as _T,
117118
TimestampConvention,
118119
ValidationOptions,
119120
WriteBuffer,
@@ -1696,7 +1697,14 @@ class DataFrame(NDFrame, OpsMixin):
16961697
# def from_dict
16971698
# def from_records
16981699
def ge(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
1699-
# def get
1700+
@overload
1701+
def get(self, key: Hashable, default: None = ...) -> Series | None: ...
1702+
@overload
1703+
def get(self, key: Hashable, default: _T) -> Series | _T: ...
1704+
@overload
1705+
def get(self, key: list[Hashable], default: None = ...) -> DataFrame | None: ...
1706+
@overload
1707+
def get(self, key: list[Hashable], default: _T) -> DataFrame | _T: ...
17001708
def gt(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
17011709
def head(self, n: int = ...) -> DataFrame: ...
17021710
def infer_objects(self) -> DataFrame: ...

pandas-stubs/core/generic.pyi

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ from pandas._typing import (
3434
AxisIndex,
3535
CompressionOptions,
3636
CSVQuoting,
37-
Dtype,
3837
DtypeArg,
3938
DtypeBackend,
4039
FilePath,
@@ -299,7 +298,6 @@ class NDFrame(indexing.IndexingMixin):
299298
self, indices, axis=..., is_copy: _bool | None = ..., **kwargs
300299
) -> Self: ...
301300
def __delitem__(self, idx: Hashable) -> None: ...
302-
def get(self, key: object, default: Dtype | None = ...) -> Dtype: ...
303301
def reindex_like(
304302
self,
305303
other,

pandas-stubs/core/series.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ from pandas._typing import (
141141
SortKind,
142142
StrDtypeArg,
143143
StrLike,
144+
T,
144145
TimedeltaDtypeArg,
145146
TimestampConvention,
146147
TimestampDtypeArg,
@@ -381,6 +382,12 @@ class Series(IndexOpsMixin[S1], NDFrame):
381382
@overload
382383
def __getitem__(self, idx: Scalar) -> S1: ...
383384
def __setitem__(self, key, value) -> None: ...
385+
@overload
386+
def get(self, key: Hashable, default: None = ...) -> S1 | None: ...
387+
@overload
388+
def get(self, key: Hashable, default: S1) -> S1: ...
389+
@overload
390+
def get(self, key: Hashable, default: T) -> S1 | T: ...
384391
def repeat(
385392
self, repeats: int | list[int], axis: AxisIndex | None = ...
386393
) -> Series[S1]: ...

tests/test_frame.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,3 +3107,38 @@ def test_itertuples() -> None:
31073107
for item in df.itertuples():
31083108
check(assert_type(item, _PandasNamedTuple), tuple)
31093109
assert_type(item.a, Scalar)
3110+
3111+
3112+
def test_get() -> None:
3113+
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
3114+
3115+
# Get single column
3116+
check(assert_type(df.get("a"), Union[pd.Series, None]), pd.Series, np.int64)
3117+
check(assert_type(df.get("z"), Union[pd.Series, None]), type(None))
3118+
check(
3119+
assert_type(df.get("a", default=None), Union[pd.Series, None]),
3120+
pd.Series,
3121+
np.int64,
3122+
)
3123+
check(assert_type(df.get("z", default=None), Union[pd.Series, None]), type(None))
3124+
check(
3125+
assert_type(df.get("a", default=1), Union[pd.Series, int]), pd.Series, np.int64
3126+
)
3127+
check(assert_type(df.get("z", default=1), Union[pd.Series, int]), int)
3128+
3129+
# Get multiple columns
3130+
check(assert_type(df.get(["a"]), Union[pd.DataFrame, None]), pd.DataFrame)
3131+
check(assert_type(df.get(["a", "b"]), Union[pd.DataFrame, None]), pd.DataFrame)
3132+
check(assert_type(df.get(["z"]), Union[pd.DataFrame, None]), type(None))
3133+
check(
3134+
assert_type(df.get(["a", "b"], default=None), Union[pd.DataFrame, None]),
3135+
pd.DataFrame,
3136+
)
3137+
check(
3138+
assert_type(df.get(["z"], default=None), Union[pd.DataFrame, None]), type(None)
3139+
)
3140+
check(
3141+
assert_type(df.get(["a", "b"], default=1), Union[pd.DataFrame, int]),
3142+
pd.DataFrame,
3143+
)
3144+
check(assert_type(df.get(["z"], default=1), Union[pd.DataFrame, int]), int)

tests/test_series.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Any,
1717
Generic,
1818
TypeVar,
19+
Union,
1920
cast,
2021
)
2122

@@ -2875,6 +2876,26 @@ def test_round() -> None:
28752876
check(assert_type(round(pd.Series([1], dtype=int)), "pd.Series[int]"), pd.Series)
28762877

28772878

2879+
def test_get() -> None:
2880+
s_int = pd.Series([1, 2, 3], index=[1, 2, 3])
2881+
2882+
check(assert_type(s_int.get(1), Union[int, None]), np.int64)
2883+
check(assert_type(s_int.get(99), Union[int, None]), type(None))
2884+
check(assert_type(s_int.get(1, default=None), Union[int, None]), np.int64)
2885+
check(assert_type(s_int.get(99, default=None), Union[int, None]), type(None))
2886+
check(assert_type(s_int.get(1, default=2), int), np.int64)
2887+
check(assert_type(s_int.get(99, default="a"), Union[int, str]), str)
2888+
2889+
s_str = pd.Series(list("abc"), index=list("abc"))
2890+
2891+
check(assert_type(s_str.get("a"), Union[str, None]), str)
2892+
check(assert_type(s_str.get("z"), Union[str, None]), type(None))
2893+
check(assert_type(s_str.get("a", default=None), Union[str, None]), str)
2894+
check(assert_type(s_str.get("z", default=None), Union[str, None]), type(None))
2895+
check(assert_type(s_str.get("a", default="b"), str), str)
2896+
check(assert_type(s_str.get("z", default=True), Union[str, bool]), bool)
2897+
2898+
28782899
def test_series_new_empty() -> None:
28792900
# GH 826
28802901
check(assert_type(pd.Series(), "pd.Series[Any]"), pd.Series)

0 commit comments

Comments
 (0)