diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 2b812a17f..203190707 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -138,17 +138,10 @@ class _LocIndexerFrame(_LocIndexer): @overload def __getitem__( self, - idx: Union[int, StrLike], - ) -> Series: ... - @overload - def __getitem__( - self, - idx: Tuple[Union[IndexType, MaskType], StrLike], - ) -> Series: ... - @overload - def __getitem__( - self, - idx: Tuple[Tuple[slice, ...], StrLike], + idx: Union[ + Union[ScalarT, None], + Tuple[Union[IndexType, MaskType, Tuple[slice, ...]], Union[ScalarT, None]], + ], ) -> Series: ... @overload def __setitem__( diff --git a/tests/test_frame.py b/tests/test_frame.py index 6d7731f01..b76df3b4e 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1,5 +1,5 @@ # flake8: noqa: F841 -from datetime import date +import datetime import io from pathlib import Path import tempfile @@ -943,7 +943,7 @@ def test_types_regressions() -> None: ss2: pd.Series = pd.concat([s1, s2]) # https://github.com/microsoft/python-type-stubs/issues/110 - d: date = pd.Timestamp("2021-01-01") + d: datetime.date = pd.Timestamp("2021-01-01") tslist: List[pd.Timestamp] = list(pd.to_datetime(["2022-01-01", "2022-01-02"])) sseries: pd.Series = pd.Series(tslist) sseries_plus1: pd.Series = sseries + pd.Timedelta(1, "d") @@ -1180,3 +1180,72 @@ def test_columns_mixlist() -> None: key: List[Union[int, str]] key = [1] check(assert_type(df[key], pd.DataFrame), pd.DataFrame) + + +def test_frame_scalars_slice() -> None: + # GH 133 + # scalars: + # str, bytes, datetime.date, datetime.datetime, datetime.timedelta, bool, int, + # float, complex, Timestamp, Timedelta + + str_ = "a" + bytes_ = b"7" + date = datetime.date(1999, 12, 31) + datetime_ = datetime.datetime(1999, 12, 31) + timedelta = datetime.datetime(2000, 1, 1) - datetime.datetime(1999, 12, 31) + bool_ = True + int_ = 2 + float_ = 3.14 + complex_ = 1.0 + 3.0j + timestamp = pd.Timestamp(0) + pd_timedelta = pd.Timedelta(0, unit="D") + none = None + idx = [ + str_, + bytes_, + date, + datetime_, + timedelta, + bool_, + int_, + float_, + complex_, + timestamp, + pd_timedelta, + none, + ] + values = np.arange(len(idx))[:, None] + np.arange(len(idx)) + df = pd.DataFrame(values, columns=idx, index=idx) + + # Note: bool_ cannot be tested since the index is object and pandas does not + # support boolean access using loc except when the index is boolean + check(assert_type(df.loc[str_], pd.Series), pd.Series) + check(assert_type(df.loc[bytes_], pd.Series), pd.Series) + check(assert_type(df.loc[date], pd.Series), pd.Series) + check(assert_type(df.loc[datetime_], pd.Series), pd.Series) + check(assert_type(df.loc[timedelta], pd.Series), pd.Series) + check(assert_type(df.loc[int_], pd.Series), pd.Series) + check(assert_type(df.loc[float_], pd.Series), pd.Series) + check(assert_type(df.loc[complex_], pd.Series), pd.Series) + check(assert_type(df.loc[timestamp], pd.Series), pd.Series) + check(assert_type(df.loc[pd_timedelta], pd.Series), pd.Series) + check(assert_type(df.loc[none], pd.Series), pd.Series) + + check(assert_type(df.loc[:, str_], pd.Series), pd.Series) + check(assert_type(df.loc[:, bytes_], pd.Series), pd.Series) + check(assert_type(df.loc[:, date], pd.Series), pd.Series) + check(assert_type(df.loc[:, datetime_], pd.Series), pd.Series) + check(assert_type(df.loc[:, timedelta], pd.Series), pd.Series) + check(assert_type(df.loc[:, int_], pd.Series), pd.Series) + check(assert_type(df.loc[:, float_], pd.Series), pd.Series) + check(assert_type(df.loc[:, complex_], pd.Series), pd.Series) + check(assert_type(df.loc[:, timestamp], pd.Series), pd.Series) + check(assert_type(df.loc[:, pd_timedelta], pd.Series), pd.Series) + check(assert_type(df.loc[:, none], pd.Series), pd.Series) + + +def test_boolean_loc() -> None: + # Booleans can only be used in loc when the index is boolean + df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False]) + check(assert_type(df.loc[True], pd.Series), pd.Series) + check(assert_type(df.loc[:, False], pd.Series), pd.Series)