Skip to content

Commit d42a148

Browse files
authored
TYP: overload maybe_downcast_numeric and maybe_downcast_to_dtype (#46929)
1 parent d4ba130 commit d42a148

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

pandas/core/dtypes/cast.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,16 @@ def _disallow_mismatched_datetimelike(value, dtype: DtypeObj):
246246
raise TypeError(f"Cannot cast {repr(value)} to {dtype}")
247247

248248

249+
@overload
250+
def maybe_downcast_to_dtype(result: np.ndarray, dtype: str | np.dtype) -> np.ndarray:
251+
...
252+
253+
254+
@overload
255+
def maybe_downcast_to_dtype(result: ExtensionArray, dtype: str | np.dtype) -> ArrayLike:
256+
...
257+
258+
249259
def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLike:
250260
"""
251261
try to cast to the specified dtype (e.g. convert back to bool/int
@@ -301,6 +311,20 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
301311
return result
302312

303313

314+
@overload
315+
def maybe_downcast_numeric(
316+
result: np.ndarray, dtype: np.dtype, do_round: bool = False
317+
) -> np.ndarray:
318+
...
319+
320+
321+
@overload
322+
def maybe_downcast_numeric(
323+
result: ExtensionArray, dtype: DtypeObj, do_round: bool = False
324+
) -> ArrayLike:
325+
...
326+
327+
304328
def maybe_downcast_numeric(
305329
result: ArrayLike, dtype: DtypeObj, do_round: bool = False
306330
) -> ArrayLike:

pandas/core/groupby/ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,7 @@ def _call_cython_op(
637637
else:
638638
op_result = result
639639

640-
# error: Incompatible return value type (got "Union[ExtensionArray, ndarray]",
641-
# expected "ndarray")
642-
return op_result # type: ignore[return-value]
640+
return op_result
643641

644642
@final
645643
def cython_operation(

pandas/core/indexes/interval.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,11 +1141,7 @@ def interval_range(
11411141
if all(is_integer(x) for x in com.not_none(start, end, freq)):
11421142
# np.linspace always produces float output
11431143

1144-
# error: Incompatible types in assignment (expression has type
1145-
# "Union[ExtensionArray, ndarray]", variable has type "ndarray")
1146-
breaks = maybe_downcast_numeric( # type: ignore[assignment]
1147-
breaks, np.dtype("int64")
1148-
)
1144+
breaks = maybe_downcast_numeric(breaks, np.dtype("int64"))
11491145
else:
11501146
# delegate to the appropriate range function
11511147
if isinstance(endpoint, Timestamp):

0 commit comments

Comments
 (0)