Skip to content

Commit cb0bcac

Browse files
bashtageKevin Sheppard
and
Kevin Sheppard
authored
ENH: Improve DataFrame loc indexing (#138)
* ENH: Improve DataFrame loc indexing Allow any scalar types in .loc to return a Series xref #133 * TST: Add check Co-authored-by: Kevin Sheppard <[email protected]>
1 parent b5d2e7a commit cb0bcac

File tree

2 files changed

+75
-13
lines changed

2 files changed

+75
-13
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,10 @@ class _LocIndexerFrame(_LocIndexer):
138138
@overload
139139
def __getitem__(
140140
self,
141-
idx: Union[int, StrLike],
142-
) -> Series: ...
143-
@overload
144-
def __getitem__(
145-
self,
146-
idx: Tuple[Union[IndexType, MaskType], StrLike],
147-
) -> Series: ...
148-
@overload
149-
def __getitem__(
150-
self,
151-
idx: Tuple[Tuple[slice, ...], StrLike],
141+
idx: Union[
142+
Union[ScalarT, None],
143+
Tuple[Union[IndexType, MaskType, Tuple[slice, ...]], Union[ScalarT, None]],
144+
],
152145
) -> Series: ...
153146
@overload
154147
def __setitem__(

tests/test_frame.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# flake8: noqa: F841
2-
from datetime import date
2+
import datetime
33
import io
44
from pathlib import Path
55
import tempfile
@@ -943,7 +943,7 @@ def test_types_regressions() -> None:
943943
ss2: pd.Series = pd.concat([s1, s2])
944944

945945
# https://github.com/microsoft/python-type-stubs/issues/110
946-
d: date = pd.Timestamp("2021-01-01")
946+
d: datetime.date = pd.Timestamp("2021-01-01")
947947
tslist: List[pd.Timestamp] = list(pd.to_datetime(["2022-01-01", "2022-01-02"]))
948948
sseries: pd.Series = pd.Series(tslist)
949949
sseries_plus1: pd.Series = sseries + pd.Timedelta(1, "d")
@@ -1180,3 +1180,72 @@ def test_columns_mixlist() -> None:
11801180
key: List[Union[int, str]]
11811181
key = [1]
11821182
check(assert_type(df[key], pd.DataFrame), pd.DataFrame)
1183+
1184+
1185+
def test_frame_scalars_slice() -> None:
1186+
# GH 133
1187+
# scalars:
1188+
# str, bytes, datetime.date, datetime.datetime, datetime.timedelta, bool, int,
1189+
# float, complex, Timestamp, Timedelta
1190+
1191+
str_ = "a"
1192+
bytes_ = b"7"
1193+
date = datetime.date(1999, 12, 31)
1194+
datetime_ = datetime.datetime(1999, 12, 31)
1195+
timedelta = datetime.datetime(2000, 1, 1) - datetime.datetime(1999, 12, 31)
1196+
bool_ = True
1197+
int_ = 2
1198+
float_ = 3.14
1199+
complex_ = 1.0 + 3.0j
1200+
timestamp = pd.Timestamp(0)
1201+
pd_timedelta = pd.Timedelta(0, unit="D")
1202+
none = None
1203+
idx = [
1204+
str_,
1205+
bytes_,
1206+
date,
1207+
datetime_,
1208+
timedelta,
1209+
bool_,
1210+
int_,
1211+
float_,
1212+
complex_,
1213+
timestamp,
1214+
pd_timedelta,
1215+
none,
1216+
]
1217+
values = np.arange(len(idx))[:, None] + np.arange(len(idx))
1218+
df = pd.DataFrame(values, columns=idx, index=idx)
1219+
1220+
# Note: bool_ cannot be tested since the index is object and pandas does not
1221+
# support boolean access using loc except when the index is boolean
1222+
check(assert_type(df.loc[str_], pd.Series), pd.Series)
1223+
check(assert_type(df.loc[bytes_], pd.Series), pd.Series)
1224+
check(assert_type(df.loc[date], pd.Series), pd.Series)
1225+
check(assert_type(df.loc[datetime_], pd.Series), pd.Series)
1226+
check(assert_type(df.loc[timedelta], pd.Series), pd.Series)
1227+
check(assert_type(df.loc[int_], pd.Series), pd.Series)
1228+
check(assert_type(df.loc[float_], pd.Series), pd.Series)
1229+
check(assert_type(df.loc[complex_], pd.Series), pd.Series)
1230+
check(assert_type(df.loc[timestamp], pd.Series), pd.Series)
1231+
check(assert_type(df.loc[pd_timedelta], pd.Series), pd.Series)
1232+
check(assert_type(df.loc[none], pd.Series), pd.Series)
1233+
1234+
check(assert_type(df.loc[:, str_], pd.Series), pd.Series)
1235+
check(assert_type(df.loc[:, bytes_], pd.Series), pd.Series)
1236+
check(assert_type(df.loc[:, date], pd.Series), pd.Series)
1237+
check(assert_type(df.loc[:, datetime_], pd.Series), pd.Series)
1238+
check(assert_type(df.loc[:, timedelta], pd.Series), pd.Series)
1239+
check(assert_type(df.loc[:, int_], pd.Series), pd.Series)
1240+
check(assert_type(df.loc[:, float_], pd.Series), pd.Series)
1241+
check(assert_type(df.loc[:, complex_], pd.Series), pd.Series)
1242+
check(assert_type(df.loc[:, timestamp], pd.Series), pd.Series)
1243+
check(assert_type(df.loc[:, pd_timedelta], pd.Series), pd.Series)
1244+
check(assert_type(df.loc[:, none], pd.Series), pd.Series)
1245+
1246+
1247+
def test_boolean_loc() -> None:
1248+
# Booleans can only be used in loc when the index is boolean
1249+
df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False])
1250+
check(assert_type(df.loc[True], pd.Series), pd.Series)
1251+
check(assert_type(df.loc[:, False], pd.Series), pd.Series)

0 commit comments

Comments
 (0)