Skip to content

Commit 2fd9697

Browse files
authored
assert types at runtime (#114)
* assert types at runtime * use classes isntead of strings where possible * check dtype * check() * unused imports * check a few unused variables * attr
1 parent d8aef91 commit 2fd9697

File tree

8 files changed

+258
-182
lines changed

8 files changed

+258
-182
lines changed

pandas-stubs/__init__.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
import pandas.testing as testing
22

3+
from . import (
4+
api as api,
5+
arrays as arrays,
6+
errors as errors,
7+
io as io,
8+
plotting as plotting,
9+
testing as testing,
10+
tseries as tseries,
11+
)
312
from ._config import (
413
describe_option as describe_option,
514
get_option as get_option,

tests/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from __future__ import annotations
2+
3+
4+
def check(
5+
actual: object, klass: type, dtype: type | None = None, attr: str = "left"
6+
) -> None:
7+
8+
if not isinstance(actual, klass):
9+
raise RuntimeError(f"Expected type '{klass}' but got '{type(actual)}'")
10+
if dtype is None:
11+
return None
12+
13+
if hasattr(actual, "__iter__"):
14+
value = next(iter(actual)) # type: ignore[call-overload]
15+
else:
16+
assert hasattr(actual, attr)
17+
value = getattr(actual, attr) # type: ignore[attr-defined]
18+
19+
if not isinstance(value, dtype):
20+
raise RuntimeError(f"Expected type '{dtype}' but got '{type(value)}'")
21+
return None

tests/test_frame.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121
from typing_extensions import assert_type
2222

23+
from tests import check
24+
2325
from pandas.io.parsers import TextFileReader
2426

2527

@@ -188,8 +190,8 @@ def test_types_assign() -> None:
188190
def test_types_sample() -> None:
189191
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
190192
# GH 67
191-
assert_type(df.sample(frac=0.5), pd.DataFrame)
192-
assert_type(df.sample(n=1), pd.DataFrame)
193+
check(assert_type(df.sample(frac=0.5), pd.DataFrame), pd.DataFrame)
194+
check(assert_type(df.sample(n=1), pd.DataFrame), pd.DataFrame)
193195

194196

195197
def test_types_nlargest_nsmallest() -> None:
@@ -576,10 +578,18 @@ def test_types_groupby_any() -> None:
576578
"col3": [False, False, False],
577579
}
578580
)
579-
assert_type(df.groupby("col1").any(), "pd.DataFrame")
580-
assert_type(df.groupby("col1").all(), "pd.DataFrame")
581-
assert_type(df.groupby("col1")["col2"].any(), "pd.Series[bool]")
582-
assert_type(df.groupby("col1")["col2"].any(), "pd.Series[bool]")
581+
check(assert_type(df.groupby("col1").any(), pd.DataFrame), pd.DataFrame)
582+
check(assert_type(df.groupby("col1").all(), pd.DataFrame), pd.DataFrame)
583+
check(
584+
assert_type(df.groupby("col1")["col2"].any(), "pd.Series[bool]"),
585+
pd.Series,
586+
bool,
587+
)
588+
check(
589+
assert_type(df.groupby("col1")["col2"].any(), "pd.Series[bool]"),
590+
pd.Series,
591+
bool,
592+
)
583593

584594

585595
def test_types_merge() -> None:
@@ -966,7 +976,7 @@ def test_read_csv() -> None:
966976
def test_groupby_series_methods() -> None:
967977
df = pd.DataFrame({"x": [1, 2, 2, 3, 3], "y": [10, 20, 30, 40, 50]})
968978
gb = df.groupby("x")["y"]
969-
assert_type(gb.describe(), "pd.DataFrame")
979+
check(assert_type(gb.describe(), pd.DataFrame), pd.DataFrame)
970980
gb.count().loc[2]
971981
gb.pct_change().loc[2]
972982
gb.bfill().loc[2]
@@ -1006,9 +1016,9 @@ def test_compute_values():
10061016
def test_sum_get_add() -> None:
10071017
df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]})
10081018
s = df["x"]
1009-
assert_type(s, "pd.Series")
1019+
check(assert_type(s, pd.Series), pd.Series)
10101020
summer = df.sum(axis=1)
1011-
assert_type(summer, "pd.Series")
1021+
check(assert_type(summer, pd.Series), pd.Series)
10121022

10131023
s2: pd.Series = s + summer
10141024
s3: pd.Series = s + df["y"]
@@ -1036,7 +1046,7 @@ def test_getmultiindex_columns() -> None:
10361046

10371047
def test_frame_getitem_isin() -> None:
10381048
df = pd.DataFrame({"x": [1, 2, 3, 4, 5]}, index=[1, 2, 3, 4, 5])
1039-
assert_type(df[df.index.isin([1, 3, 5])], "pd.DataFrame")
1049+
check(assert_type(df[df.index.isin([1, 3, 5])], pd.DataFrame), pd.DataFrame)
10401050

10411051

10421052
def test_read_excel() -> None:
@@ -1072,40 +1082,40 @@ def test_join() -> None:
10721082
def test_types_ffill() -> None:
10731083
# GH 44
10741084
df = pd.DataFrame([[1, 2, 3]])
1075-
assert_type(df.ffill(), pd.DataFrame)
1076-
assert_type(df.ffill(inplace=False), pd.DataFrame)
1077-
assert_type(df.ffill(inplace=True), None)
1085+
check(assert_type(df.ffill(), pd.DataFrame), pd.DataFrame)
1086+
check(assert_type(df.ffill(inplace=False), pd.DataFrame), pd.DataFrame)
1087+
assert assert_type(df.ffill(inplace=True), None) is None
10781088

10791089

10801090
def test_types_bfill() -> None:
10811091
# GH 44
10821092
df = pd.DataFrame([[1, 2, 3]])
1083-
assert_type(df.bfill(), pd.DataFrame)
1084-
assert_type(df.bfill(inplace=False), pd.DataFrame)
1085-
assert_type(df.bfill(inplace=True), None)
1093+
check(assert_type(df.bfill(), pd.DataFrame), pd.DataFrame)
1094+
check(assert_type(df.bfill(inplace=False), pd.DataFrame), pd.DataFrame)
1095+
assert assert_type(df.bfill(inplace=True), None) is None
10861096

10871097

10881098
def test_types_replace() -> None:
10891099
# GH 44
10901100
df = pd.DataFrame([[1, 2, 3]])
1091-
assert_type(df.replace(1, 2), pd.DataFrame)
1092-
assert_type(df.replace(1, 2, inplace=False), pd.DataFrame)
1093-
assert_type(df.replace(1, 2, inplace=True), None)
1101+
check(assert_type(df.replace(1, 2), pd.DataFrame), pd.DataFrame)
1102+
check(assert_type(df.replace(1, 2, inplace=False), pd.DataFrame), pd.DataFrame)
1103+
assert assert_type(df.replace(1, 2, inplace=True), None) is None
10941104

10951105

10961106
def test_loop_dataframe() -> None:
10971107
# GH 70
10981108
df = pd.DataFrame({"x": [1, 2, 3]})
10991109
for c in df:
1100-
assert_type(df[c], pd.Series)
1110+
check(assert_type(df[c], pd.Series), pd.Series)
11011111

11021112

11031113
def test_groupby_index() -> None:
11041114
# GH 42
11051115
df = pd.DataFrame(
11061116
data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]}
11071117
).set_index("col1")
1108-
assert_type(df.groupby(df.index).min(), pd.DataFrame)
1118+
check(assert_type(df.groupby(df.index).min(), pd.DataFrame), pd.DataFrame)
11091119

11101120

11111121
def test_iloc_npint() -> None:

tests/test_indexes.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,50 +3,52 @@
33
import pandas as pd
44
from typing_extensions import assert_type
55

6+
from tests import check
7+
68

79
def test_index_unique() -> None:
810

911
df = pd.DataFrame({"x": [1, 2, 3, 4]}, index=pd.Index([1, 2, 3, 2]))
1012
ind = df.index
11-
assert_type(ind, "pd.Index")
13+
check(assert_type(ind, pd.Index), pd.Index)
1214
i2 = ind.unique()
13-
assert_type(i2, "pd.Index")
15+
check(assert_type(i2, pd.Index), pd.Index)
1416

1517

1618
def test_index_isin() -> None:
1719
ind = pd.Index([1, 2, 3, 4, 5])
1820
isin = ind.isin([2, 4])
19-
assert_type(isin, npt.NDArray[np.bool_])
21+
check(assert_type(isin, npt.NDArray[np.bool_]), np.ndarray, np.bool_)
2022

2123

2224
def test_index_astype() -> None:
2325
indi = pd.Index([1, 2, 3])
2426
inds = pd.Index(["a", "b", "c"])
2527
indc = indi.astype(inds.dtype)
26-
assert_type(indc, "pd.Index")
28+
check(assert_type(indc, pd.Index), pd.Index)
2729
mi = pd.MultiIndex.from_product([["a", "b"], ["c", "d"]], names=["ab", "cd"])
2830
mia = mi.astype(object) # object is only valid parameter for MultiIndex.astype()
29-
assert_type(mia, "pd.MultiIndex")
31+
check(assert_type(mia, pd.MultiIndex), pd.MultiIndex)
3032

3133

3234
def test_multiindex_get_level_values() -> None:
3335
mi = pd.MultiIndex.from_product([["a", "b"], ["c", "d"]], names=["ab", "cd"])
3436
i1 = mi.get_level_values("ab")
35-
assert_type(i1, "pd.Index")
37+
check(assert_type(i1, pd.Index), pd.Index)
3638

3739

3840
def test_index_tolist() -> None:
3941
i1 = pd.Index([1, 2, 3])
40-
l1 = i1.tolist()
41-
i2 = i1.to_list()
42+
check(assert_type(i1.tolist(), list), list, int)
43+
check(assert_type(i1.to_list(), list), list, int)
4244

4345

4446
def test_column_getitem() -> None:
4547
# https://github.com/microsoft/python-type-stubs/issues/199#issuecomment-1132806594
4648
df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"])
4749

4850
column = df.columns[0]
49-
a = df[column]
51+
check(assert_type(df[column], pd.Series), pd.Series, int)
5052

5153

5254
def test_column_contains() -> None:
@@ -63,4 +65,4 @@ def test_column_contains() -> None:
6365
def test_difference_none() -> None:
6466
# https://github.com/pandas-dev/pandas-stubs/issues/17
6567
ind = pd.Index([1, 2, 3])
66-
id = ind.difference([1, None])
68+
check(assert_type(ind.difference([1, None]), "pd.Index"), pd.Index, int)

tests/test_interval.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# flake8: noqa: F841
22
from typing import TYPE_CHECKING
33

4+
import numpy as np
45
import pandas as pd
56
from typing_extensions import assert_type
67

8+
from tests import check
9+
710

811
def test_interval_init() -> None:
912
i1: pd.Interval = pd.Interval(1, 2, closed="both")
@@ -36,49 +39,49 @@ def test_interval_length() -> None:
3639
i1 = pd.Interval(
3740
pd.Timestamp("2000-01-01"), pd.Timestamp("2000-01-03"), closed="both"
3841
)
39-
assert_type(i1.length, "pd.Timedelta")
40-
assert_type(i1.left, "pd.Timestamp")
41-
assert_type(i1.right, "pd.Timestamp")
42-
assert_type(i1.mid, "pd.Timestamp")
42+
check(assert_type(i1.length, pd.Timedelta), pd.Timedelta)
43+
check(assert_type(i1.left, pd.Timestamp), pd.Timestamp)
44+
check(assert_type(i1.right, pd.Timestamp), pd.Timestamp)
45+
check(assert_type(i1.mid, pd.Timestamp), pd.Timestamp)
4346
i1.length.total_seconds()
4447
inres = pd.Timestamp("2001-01-02") in i1
45-
assert_type(inres, "bool")
48+
check(assert_type(inres, bool), bool)
4649
idres = i1 + pd.Timedelta(seconds=20)
4750

48-
assert_type(idres, "pd.Interval[pd.Timestamp]")
51+
check(assert_type(idres, "pd.Interval[pd.Timestamp]"), pd.Interval, pd.Timestamp)
4952
if TYPE_CHECKING:
5053
20 in i1 # type: ignore[operator]
5154
i1 + pd.Timestamp("2000-03-03") # type: ignore[operator]
5255
i1 * 3 # type: ignore[operator]
5356
i1 * pd.Timedelta(seconds=20) # type: ignore[operator]
5457

5558
i2 = pd.Interval(10, 20)
56-
assert_type(i2.length, "int")
57-
assert_type(i2.left, "int")
58-
assert_type(i2.right, "int")
59-
assert_type(i2.mid, "float")
59+
check(assert_type(i2.length, int), int)
60+
check(assert_type(i2.left, int), int)
61+
check(assert_type(i2.right, int), int)
62+
check(assert_type(i2.mid, float), float)
6063

6164
i2inres = 15 in i2
62-
assert_type(i2inres, "bool")
63-
assert_type(i2 + 3, "pd.Interval[int]")
64-
assert_type(i2 + 3.2, "pd.Interval[float]")
65-
assert_type(i2 * 4, "pd.Interval[int]")
66-
assert_type(i2 * 4.2, "pd.Interval[float]")
65+
check(assert_type(i2inres, bool), bool)
66+
check(assert_type(i2 + 3, "pd.Interval[int]"), pd.Interval, int)
67+
check(assert_type(i2 + 3.2, "pd.Interval[float]"), pd.Interval, float)
68+
check(assert_type(i2 * 4, "pd.Interval[int]"), pd.Interval, int)
69+
check(assert_type(i2 * 4.2, "pd.Interval[float]"), pd.Interval, float)
6770

6871
if TYPE_CHECKING:
6972
pd.Timestamp("2001-01-02") in i2 # type: ignore[operator]
7073
i2 + pd.Timedelta(seconds=20) # type: ignore[operator]
7174

7275
i3 = pd.Interval(13.2, 19.5)
73-
assert_type(i3.length, "float")
74-
assert_type(i3.left, "float")
75-
assert_type(i3.right, "float")
76-
assert_type(i3.mid, "float")
76+
check(assert_type(i3.length, float), float)
77+
check(assert_type(i3.left, float), float)
78+
check(assert_type(i3.right, float), float)
79+
check(assert_type(i3.mid, float), float)
7780

7881
i3inres = 15.4 in i3
79-
assert_type(i3inres, "bool")
80-
assert_type(i3 + 3, "pd.Interval[float]")
81-
assert_type(i3 * 3, "pd.Interval[float]")
82+
check(assert_type(i3inres, bool), bool)
83+
check(assert_type(i3 + 3, "pd.Interval[float]"), pd.Interval, float)
84+
check(assert_type(i3 * 3, "pd.Interval[float]"), pd.Interval, float)
8285
if TYPE_CHECKING:
8386
pd.Timestamp("2001-01-02") in i3 # type: ignore[operator]
8487
i3 + pd.Timedelta(seconds=20) # type: ignore[operator]

0 commit comments

Comments
 (0)