Skip to content

Commit 6f82178

Browse files
committed
refactor: simplify typing
1 parent 5b684a3 commit 6f82178

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

tests/test_frame.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
TYPE_CHECKING,
2626
Any,
2727
Generic,
28-
Protocol,
2928
TypeAlias,
3029
TypedDict,
3130
TypeVar,
@@ -72,6 +71,8 @@
7271

7372
if TYPE_CHECKING:
7473
from pandas.core.frame import _PandasNamedTuple
74+
75+
from pandas._typing import S1
7576
else:
7677
_PandasNamedTuple: TypeAlias = tuple
7778

@@ -828,7 +829,9 @@ def test_dataframe_clip() -> None:
828829
df.clip(lower=pd.Series([1, 2]), upper=pd.Series([4, 5]), axis=None) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
829830
df.copy().clip(lower=pd.Series([1, 2]), upper=None, axis=None, inplace=True) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
830831
df.copy().clip(lower=None, upper=pd.Series([1, 2]), axis=None, inplace=True) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
831-
df.copy().clip(lower=pd.Series([4, 5]), upper=pd.Series([1, 2]), axis=None, inplace=True) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
832+
df.copy().clip(
833+
lower=pd.Series([4, 5]), upper=pd.Series([1, 2]), axis=None, inplace=True
834+
) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
832835

833836
check(
834837
assert_type(df.clip(lower=None, upper=None, axis=None), pd.DataFrame),
@@ -1713,12 +1716,7 @@ def test_types_groupby_agg() -> None:
17131716
agg_dict1 = {"col2": "min", "col3": "max", 0: "sum"}
17141717
check(assert_type(df.groupby("col1").agg(agg_dict1), pd.DataFrame), pd.DataFrame)
17151718

1716-
T_co = TypeVar("T_co", covariant=True)
1717-
1718-
class SupportsMin(Protocol[T_co]):
1719-
def min(self) -> T_co: ...
1720-
1721-
def wrapped_min(x: SupportsMin[T_co]) -> T_co:
1719+
def wrapped_min(x: pd.Series[S1]) -> S1:
17221720
return x.min()
17231721

17241722
with pytest_warns_bounded(
@@ -4313,10 +4311,24 @@ def test_to_dict_index() -> None:
43134311
assert_type(df.to_dict(orient="split", index=False), dict[str, list]), dict, str
43144312
)
43154313
if TYPE_CHECKING_INVALID_USAGE:
4316-
check(assert_type(df.to_dict(orient="records", index=False), list[dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4317-
check(assert_type(df.to_dict(orient="dict", index=False), dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4318-
check(assert_type(df.to_dict(orient="series", index=False), dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4319-
check(assert_type(df.to_dict(orient="index", index=False), dict[Hashable, Any]), dict) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4314+
check(
4315+
assert_type(
4316+
df.to_dict(orient="records", index=False), list[dict[Hashable, Any]]
4317+
),
4318+
list,
4319+
) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4320+
check(
4321+
assert_type(df.to_dict(orient="dict", index=False), dict[Hashable, Any]),
4322+
dict,
4323+
) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4324+
check(
4325+
assert_type(df.to_dict(orient="series", index=False), dict[Hashable, Any]),
4326+
dict,
4327+
) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
4328+
check(
4329+
assert_type(df.to_dict(orient="index", index=False), dict[Hashable, Any]),
4330+
dict,
4331+
) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]
43204332

43214333

43224334
def test_suffix_prefix_index() -> None:
@@ -4420,7 +4432,9 @@ def test_interpolate_inplace() -> None:
44204432

44214433
def test_getitem_generator() -> None:
44224434
# GH 685
4423-
check(assert_type(DF[(f"col{i+1}" for i in range(2))], pd.DataFrame), pd.DataFrame)
4435+
check(
4436+
assert_type(DF[(f"col{i + 1}" for i in range(2))], pd.DataFrame), pd.DataFrame
4437+
)
44244438

44254439

44264440
def test_getitem_dict_keys() -> None:
@@ -4566,7 +4580,6 @@ def test_hashable_args() -> None:
45664580
test = ["test"]
45674581

45684582
with ensure_clean() as path:
4569-
45704583
df.to_stata(path, version=117, convert_strl=test)
45714584
df.to_stata(path, version=117, convert_strl=["test"])
45724585

@@ -4691,7 +4704,6 @@ def test_unstack() -> None:
46914704

46924705

46934706
def test_from_records() -> None:
4694-
46954707
# test with np.ndarray
46964708
arr = np.array([[1, "a"], [2, "b"]], dtype=object).reshape(2, 2)
46974709
check(assert_type(pd.DataFrame.from_records(arr), pd.DataFrame), pd.DataFrame)

0 commit comments

Comments
 (0)