From 99b297552101ff01dbc8ead1672cd5e56d2b6b09 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 12 Jul 2022 16:30:26 -0700 Subject: [PATCH 1/2] ENH: dt64/td64 comparison support non-nano --- pandas/_libs/tslibs/np_datetime.pyx | 80 +++++++++++++++++++++++++++ pandas/core/arrays/datetimelike.py | 18 ++++++ pandas/tests/arrays/test_datetimes.py | 38 +++++++++++++ 3 files changed, 136 insertions(+) diff --git a/pandas/_libs/tslibs/np_datetime.pyx b/pandas/_libs/tslibs/np_datetime.pyx index 1aab5dcd6f70b..692b4430fa577 100644 --- a/pandas/_libs/tslibs/np_datetime.pyx +++ b/pandas/_libs/tslibs/np_datetime.pyx @@ -20,12 +20,14 @@ from cpython.object cimport ( import_datetime() import numpy as np + cimport numpy as cnp cnp.import_array() from numpy cimport ( int64_t, ndarray, + uint8_t, ) from pandas._libs.tslibs.util cimport get_c_string_buf_and_size @@ -370,3 +372,81 @@ cpdef ndarray astype_overflowsafe( cnp.PyArray_MultiIter_NEXT(mi) return iresult.view(dtype) + + +# TODO: try to upstream this fix to numpy +def compare_mismatched_resolutions(ndarray left, ndarray right, op): + """ + Overflow-safe comparison of timedelta64/datetime64 with mismatched resolutions. + + >>> left = np.array([500], dtype="M8[Y]") + >>> right = np.array([0], dtype="M8[ns]") + >>> left < right # <- wrong! + array([ True]) + """ + + if left.dtype.kind != right.dtype.kind or left.dtype.kind not in ["m", "M"]: + raise ValueError("left and right must both be timedelta64 or both datetime64") + + cdef: + int op_code = op_to_op_code(op) + NPY_DATETIMEUNIT left_unit = get_unit_from_dtype(left.dtype) + NPY_DATETIMEUNIT right_unit = get_unit_from_dtype(right.dtype) + + # equiv: result = np.empty((left).shape, dtype="bool") + ndarray result = cnp.PyArray_EMPTY( + left.ndim, left.shape, cnp.NPY_BOOL, 0 + ) + + ndarray lvalues = left.view("i8") + ndarray rvalues = right.view("i8") + + cnp.broadcast mi = cnp.PyArray_MultiIterNew3(result, lvalues, rvalues) + int64_t lval, rval + bint res_value + + Py_ssize_t i, N = left.size + npy_datetimestruct ldts, rdts + + + for i in range(N): + # Analogous to: lval = lvalues[i] + lval = (cnp.PyArray_MultiIter_DATA(mi, 1))[0] + + # Analogous to: rval = rvalues[i] + rval = (cnp.PyArray_MultiIter_DATA(mi, 2))[0] + + if lval == NPY_DATETIME_NAT or rval == NPY_DATETIME_NAT: + res_value = op_code == Py_NE + + else: + pandas_datetime_to_datetimestruct(lval, left_unit, &ldts) + pandas_datetime_to_datetimestruct(rval, right_unit, &rdts) + + res_value = cmp_dtstructs(&ldts, &rdts, op_code) + + # Analogous to: result[i] = res_value + (cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_value + + cnp.PyArray_MultiIter_NEXT(mi) + + return result + + +import operator + + +cdef int op_to_op_code(op): + # TODO: should exist somewhere? + if op is operator.eq: + return Py_EQ + if op is operator.ne: + return Py_NE + if op is operator.le: + return Py_LE + if op is operator.lt: + return Py_LT + if op is operator.ge: + return Py_GE + if op is operator.gt: + return Py_GT diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index eadf47b36d7fc..1aad48ce83d97 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -46,6 +46,7 @@ RoundTo, round_nsint64, ) +from pandas._libs.tslibs.np_datetime import compare_mismatched_resolutions from pandas._libs.tslibs.timestamps import integer_op_not_supported from pandas._typing import ( ArrayLike, @@ -1065,6 +1066,23 @@ def _cmp_method(self, other, op): ) return result + if other is NaT: + if op is operator.ne: + result = np.ones(self.shape, dtype=bool) + else: + result = np.zeros(self.shape, dtype=bool) + return result + + if not is_period_dtype(self.dtype): + if self._reso != other._reso: + if not isinstance(other, type(self)): + # i.e. Timedelta/Timestamp, cast to ndarray and let + # compare_mismatched_resolutions handle broadcasting + other_arr = np.array(other.asm8) + else: + other_arr = other._ndarray + return compare_mismatched_resolutions(self._ndarray, other_arr, op) + other_vals = self._unbox(other) # GH#37462 comparison on i8 values is almost 2x faster than M8/m8 result = op(self._ndarray.view("i8"), other_vals.view("i8")) diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index 63601ff963609..af1a292a2975a 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -1,6 +1,8 @@ """ Tests for DatetimeArray """ +import operator + import numpy as np import pytest @@ -169,6 +171,42 @@ def test_repr(self, dta_dti, unit): assert repr(dta) == repr(dti._data).replace("[ns", f"[{unit}") + # TODO: tests with td64 + def test_compare_mismatched_resolutions(self, comparison_op): + # comparison that numpy gets wrong bc of silent overflows + op = comparison_op + + iinfo = np.iinfo(np.int64) + vals = np.array([iinfo.min, iinfo.min + 1, iinfo.max], dtype=np.int64) + + # Construct so that arr2[1] < arr[1] < arr[2] < arr2[2] + arr = np.array(vals).view("M8[ns]") + arr2 = arr.view("M8[s]") + + left = DatetimeArray._simple_new(arr, dtype=arr.dtype) + right = DatetimeArray._simple_new(arr2, dtype=arr2.dtype) + + if comparison_op is operator.eq: + expected = np.array([False, False, False]) + elif comparison_op is operator.ne: + expected = np.array([True, True, True]) + elif comparison_op in [operator.lt, operator.le]: + expected = np.array([False, False, True]) + else: + expected = np.array([False, True, False]) + + result = op(left, right) + tm.assert_numpy_array_equal(result, expected) + + result = op(left[1], right) + tm.assert_numpy_array_equal(result, expected) + + if op not in [operator.eq, operator.ne]: + # check that numpy still gets this wrong; if it is fixed we may be + # able to remove compare_mismatched_resolutions + np_res = op(left._ndarray, right._ndarray) + tm.assert_numpy_array_equal(np_res[1:], ~expected[1:]) + class TestDatetimeArrayComparisons: # TODO: merge this into tests/arithmetic/test_datetime64 once it is From 8bea96c5dd5d2cfddf91bacefb7a8a1bca26bdca Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 12 Jul 2022 19:11:17 -0700 Subject: [PATCH 2/2] mypy fixup --- pandas/_libs/tslibs/np_datetime.pyi | 5 +++++ pandas/core/arrays/datetimelike.py | 1 + 2 files changed, 6 insertions(+) diff --git a/pandas/_libs/tslibs/np_datetime.pyi b/pandas/_libs/tslibs/np_datetime.pyi index 27871a78f8aaf..757165fbad268 100644 --- a/pandas/_libs/tslibs/np_datetime.pyi +++ b/pandas/_libs/tslibs/np_datetime.pyi @@ -1,5 +1,7 @@ import numpy as np +from pandas._typing import npt + class OutOfBoundsDatetime(ValueError): ... class OutOfBoundsTimedelta(ValueError): ... @@ -10,3 +12,6 @@ def astype_overflowsafe( arr: np.ndarray, dtype: np.dtype, copy: bool = ... ) -> np.ndarray: ... def is_unitless(dtype: np.dtype) -> bool: ... +def compare_mismatched_resolutions( + left: np.ndarray, right: np.ndarray, op +) -> npt.NDArray[np.bool_]: ... diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 1aad48ce83d97..0f88ad9811bf0 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1074,6 +1074,7 @@ def _cmp_method(self, other, op): return result if not is_period_dtype(self.dtype): + self = cast(TimelikeOps, self) if self._reso != other._reso: if not isinstance(other, type(self)): # i.e. Timedelta/Timestamp, cast to ndarray and let