diff --git a/pandas/_libs/lib.pyi b/pandas/_libs/lib.pyi index 077d2e60cc3a4..d39c5ac5d2967 100644 --- a/pandas/_libs/lib.pyi +++ b/pandas/_libs/lib.pyi @@ -219,8 +219,7 @@ def array_equivalent_object( left: np.ndarray, # object[:] right: np.ndarray, # object[:] ) -> bool: ... -def has_infs_f8(arr: np.ndarray) -> bool: ... # const float64_t[:] -def has_infs_f4(arr: np.ndarray) -> bool: ... # const float32_t[:] +def has_infs(arr: np.ndarray) -> bool: ... # const floating[:] def get_reverse_indexer( indexer: np.ndarray, # const intp_t[:] length: int, diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 506ad0102e157..8d476489bffb3 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -25,6 +25,7 @@ from cpython.tuple cimport ( PyTuple_New, PyTuple_SET_ITEM, ) +from cython cimport floating PyDateTime_IMPORT @@ -519,36 +520,22 @@ def get_reverse_indexer(const intp_t[:] indexer, Py_ssize_t length) -> ndarray: @cython.wraparound(False) @cython.boundscheck(False) -def has_infs_f4(const float32_t[:] arr) -> bool: +# Can add const once https://github.com/cython/cython/issues/1772 resolved +def has_infs(floating[:] arr) -> bool: cdef: Py_ssize_t i, n = len(arr) - float32_t inf, neginf, val + floating inf, neginf, val + bint ret = False inf = np.inf neginf = -inf - - for i in range(n): - val = arr[i] - if val == inf or val == neginf: - return True - return False - - -@cython.wraparound(False) -@cython.boundscheck(False) -def has_infs_f8(const float64_t[:] arr) -> bool: - cdef: - Py_ssize_t i, n = len(arr) - float64_t inf, neginf, val - - inf = np.inf - neginf = -inf - - for i in range(n): - val = arr[i] - if val == inf or val == neginf: - return True - return False + with nogil: + for i in range(n): + val = arr[i] + if val == inf or val == neginf: + ret = True + break + return ret def maybe_indices_to_slice(ndarray[intp_t] indices, int max_len): diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index c34944985f2b6..fb6580bbb7ea6 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -177,10 +177,8 @@ def _bn_ok_dtype(dtype: DtypeObj, name: str) -> bool: def _has_infs(result) -> bool: if isinstance(result, np.ndarray): - if result.dtype == "f8": - return lib.has_infs_f8(result.ravel("K")) - elif result.dtype == "f4": - return lib.has_infs_f4(result.ravel("K")) + if result.dtype == "f8" or result.dtype == "f4": + return lib.has_infs(result.ravel("K")) try: return np.isinf(result).any() except (TypeError, NotImplementedError):