Skip to content

Commit 089f7f8

Browse files
authored
TYP: def validate_* (#47750)
1 parent f9346a6 commit 089f7f8

File tree

6 files changed

+81
-19
lines changed

6 files changed

+81
-19
lines changed

pandas/_libs/tslibs/offsets.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ class SingleConstructorOffset(BaseOffset):
105105
@overload
106106
def to_offset(freq: None) -> None: ...
107107
@overload
108-
def to_offset(freq: timedelta | BaseOffset | str) -> BaseOffset: ...
108+
def to_offset(freq: _BaseOffsetT) -> _BaseOffsetT: ...
109+
@overload
110+
def to_offset(freq: timedelta | str) -> BaseOffset: ...
109111

110112
class Tick(SingleConstructorOffset):
111113
_reso: int

pandas/compat/numpy/function.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,28 @@
1717
"""
1818
from __future__ import annotations
1919

20-
from typing import Any
20+
from typing import (
21+
Any,
22+
TypeVar,
23+
overload,
24+
)
2125

2226
from numpy import ndarray
2327

2428
from pandas._libs.lib import (
2529
is_bool,
2630
is_integer,
2731
)
32+
from pandas._typing import Axis
2833
from pandas.errors import UnsupportedFunctionCall
2934
from pandas.util._validators import (
3035
validate_args,
3136
validate_args_and_kwargs,
3237
validate_kwargs,
3338
)
3439

40+
AxisNoneT = TypeVar("AxisNoneT", Axis, None)
41+
3542

3643
class CompatValidator:
3744
def __init__(
@@ -84,15 +91,15 @@ def __call__(
8491
)
8592

8693

87-
def process_skipna(skipna, args):
94+
def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]:
8895
if isinstance(skipna, ndarray) or skipna is None:
8996
args = (skipna,) + args
9097
skipna = True
9198

9299
return skipna, args
93100

94101

95-
def validate_argmin_with_skipna(skipna, args, kwargs):
102+
def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
96103
"""
97104
If 'Series.argmin' is called via the 'numpy' library, the third parameter
98105
in its signature is 'out', which takes either an ndarray or 'None', so
@@ -104,7 +111,7 @@ def validate_argmin_with_skipna(skipna, args, kwargs):
104111
return skipna
105112

106113

107-
def validate_argmax_with_skipna(skipna, args, kwargs):
114+
def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
108115
"""
109116
If 'Series.argmax' is called via the 'numpy' library, the third parameter
110117
in its signature is 'out', which takes either an ndarray or 'None', so
@@ -137,7 +144,7 @@ def validate_argmax_with_skipna(skipna, args, kwargs):
137144
)
138145

139146

140-
def validate_argsort_with_ascending(ascending, args, kwargs):
147+
def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool:
141148
"""
142149
If 'Categorical.argsort' is called via the 'numpy' library, the first
143150
parameter in its signature is 'axis', which takes either an integer or
@@ -149,7 +156,8 @@ def validate_argsort_with_ascending(ascending, args, kwargs):
149156
ascending = True
150157

151158
validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
152-
return ascending
159+
# error: Incompatible return value type (got "int", expected "bool")
160+
return ascending # type: ignore[return-value]
153161

154162

155163
CLIP_DEFAULTS: dict[str, Any] = {"out": None}
@@ -158,7 +166,19 @@ def validate_argsort_with_ascending(ascending, args, kwargs):
158166
)
159167

160168

161-
def validate_clip_with_axis(axis, args, kwargs):
169+
@overload
170+
def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None:
171+
...
172+
173+
174+
@overload
175+
def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT:
176+
...
177+
178+
179+
def validate_clip_with_axis(
180+
axis: ndarray | AxisNoneT, args, kwargs
181+
) -> AxisNoneT | None:
162182
"""
163183
If 'NDFrame.clip' is called via the numpy library, the third parameter in
164184
its signature is 'out', which can takes an ndarray, so check if the 'axis'
@@ -167,10 +187,14 @@ def validate_clip_with_axis(axis, args, kwargs):
167187
"""
168188
if isinstance(axis, ndarray):
169189
args = (axis,) + args
170-
axis = None
190+
# error: Incompatible types in assignment (expression has type "None",
191+
# variable has type "Union[ndarray[Any, Any], str, int]")
192+
axis = None # type: ignore[assignment]
171193

172194
validate_clip(args, kwargs)
173-
return axis
195+
# error: Incompatible return value type (got "Union[ndarray[Any, Any],
196+
# str, int]", expected "Union[str, int, None]")
197+
return axis # type: ignore[return-value]
174198

175199

176200
CUM_FUNC_DEFAULTS: dict[str, Any] = {}
@@ -184,7 +208,7 @@ def validate_clip_with_axis(axis, args, kwargs):
184208
)
185209

186210

187-
def validate_cum_func_with_skipna(skipna, args, kwargs, name):
211+
def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
188212
"""
189213
If this function is called via the 'numpy' library, the third parameter in
190214
its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
@@ -288,7 +312,7 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
288312
validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")
289313

290314

291-
def validate_take_with_convert(convert, args, kwargs):
315+
def validate_take_with_convert(convert: ndarray | bool | None, args, kwargs) -> bool:
292316
"""
293317
If this function is called via the 'numpy' library, the third parameter in
294318
its signature is 'axis', which takes either an ndarray or 'None', so check

pandas/core/arrays/datetimelike.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,7 +2164,17 @@ def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str):
21642164
return data, copy
21652165

21662166

2167-
def validate_periods(periods):
2167+
@overload
2168+
def validate_periods(periods: None) -> None:
2169+
...
2170+
2171+
2172+
@overload
2173+
def validate_periods(periods: int | float) -> int:
2174+
...
2175+
2176+
2177+
def validate_periods(periods: int | float | None) -> int | None:
21682178
"""
21692179
If a `periods` argument is passed to the Datetime/Timedelta Array/Index
21702180
constructor, cast it to an integer.
@@ -2187,7 +2197,9 @@ def validate_periods(periods):
21872197
periods = int(periods)
21882198
elif not lib.is_integer(periods):
21892199
raise TypeError(f"periods must be a number, got {periods}")
2190-
return periods
2200+
# error: Incompatible return value type (got "Optional[float]",
2201+
# expected "Optional[int]")
2202+
return periods # type: ignore[return-value]
21912203

21922204

21932205
def validate_inferred_freq(freq, inferred_freq, freq_infer):

pandas/core/arrays/datetimes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _scalar_type(self) -> type[Timestamp]:
251251
# Constructors
252252

253253
_dtype: np.dtype | DatetimeTZDtype
254-
_freq = None
254+
_freq: BaseOffset | None = None
255255
_default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__
256256

257257
@classmethod

pandas/core/arrays/period.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
Callable,
99
Literal,
1010
Sequence,
11+
TypeVar,
12+
overload,
1113
)
1214

1315
import numpy as np
@@ -92,6 +94,8 @@
9294
TimedeltaArray,
9395
)
9496

97+
BaseOffsetT = TypeVar("BaseOffsetT", bound=BaseOffset)
98+
9599

96100
_shared_doc_kwargs = {
97101
"klass": "PeriodArray",
@@ -976,7 +980,19 @@ def period_array(
976980
return PeriodArray._from_sequence(data, dtype=dtype)
977981

978982

979-
def validate_dtype_freq(dtype, freq):
983+
@overload
984+
def validate_dtype_freq(dtype, freq: BaseOffsetT) -> BaseOffsetT:
985+
...
986+
987+
988+
@overload
989+
def validate_dtype_freq(dtype, freq: timedelta | str | None) -> BaseOffset:
990+
...
991+
992+
993+
def validate_dtype_freq(
994+
dtype, freq: BaseOffsetT | timedelta | str | None
995+
) -> BaseOffsetT:
980996
"""
981997
If both a dtype and a freq are available, ensure they match. If only
982998
dtype is available, extract the implied freq.
@@ -996,7 +1012,10 @@ def validate_dtype_freq(dtype, freq):
9961012
IncompatibleFrequency : mismatch between dtype and freq
9971013
"""
9981014
if freq is not None:
999-
freq = to_offset(freq)
1015+
# error: Incompatible types in assignment (expression has type
1016+
# "BaseOffset", variable has type "Union[BaseOffsetT, timedelta,
1017+
# str, None]")
1018+
freq = to_offset(freq) # type: ignore[assignment]
10001019

10011020
if dtype is not None:
10021021
dtype = pandas_dtype(dtype)
@@ -1006,7 +1025,9 @@ def validate_dtype_freq(dtype, freq):
10061025
freq = dtype.freq
10071026
elif freq != dtype.freq:
10081027
raise IncompatibleFrequency("specified freq and dtype are different")
1009-
return freq
1028+
# error: Incompatible return value type (got "Union[BaseOffset, Any, None]",
1029+
# expected "BaseOffset")
1030+
return freq # type: ignore[return-value]
10101031

10111032

10121033
def dt64arr_to_periodarr(

pandas/util/_validators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
from typing import (
8+
Any,
89
Iterable,
910
Sequence,
1011
TypeVar,
@@ -265,7 +266,9 @@ def validate_bool_kwarg(
265266
return value
266267

267268

268-
def validate_axis_style_args(data, args, kwargs, arg_name, method_name):
269+
def validate_axis_style_args(
270+
data, args, kwargs, arg_name, method_name
271+
) -> dict[str, Any]:
269272
"""
270273
Argument handler for mixed index, columns / axis functions
271274

0 commit comments

Comments
 (0)