Skip to content

ENH: dt64/td64 comparison support non-nano #47691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pandas/_libs/tslibs/np_datetime.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from pandas._typing import npt

class OutOfBoundsDatetime(ValueError): ...
class OutOfBoundsTimedelta(ValueError): ...

Expand All @@ -10,3 +12,6 @@ def astype_overflowsafe(
arr: np.ndarray, dtype: np.dtype, copy: bool = ...
) -> np.ndarray: ...
def is_unitless(dtype: np.dtype) -> bool: ...
def compare_mismatched_resolutions(
left: np.ndarray, right: np.ndarray, op
) -> npt.NDArray[np.bool_]: ...
80 changes: 80 additions & 0 deletions pandas/_libs/tslibs/np_datetime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ from cpython.object cimport (
import_datetime()

import numpy as np

cimport numpy as cnp

cnp.import_array()
from numpy cimport (
int64_t,
ndarray,
uint8_t,
)

from pandas._libs.tslibs.util cimport get_c_string_buf_and_size
Expand Down Expand Up @@ -370,3 +372,81 @@ cpdef ndarray astype_overflowsafe(
cnp.PyArray_MultiIter_NEXT(mi)

return iresult.view(dtype)


# TODO: try to upstream this fix to numpy
def compare_mismatched_resolutions(ndarray left, ndarray right, op):
"""
Overflow-safe comparison of timedelta64/datetime64 with mismatched resolutions.

>>> left = np.array([500], dtype="M8[Y]")
>>> right = np.array([0], dtype="M8[ns]")
>>> left < right # <- wrong!
array([ True])
"""

if left.dtype.kind != right.dtype.kind or left.dtype.kind not in ["m", "M"]:
raise ValueError("left and right must both be timedelta64 or both datetime64")

cdef:
int op_code = op_to_op_code(op)
NPY_DATETIMEUNIT left_unit = get_unit_from_dtype(left.dtype)
NPY_DATETIMEUNIT right_unit = get_unit_from_dtype(right.dtype)

# equiv: result = np.empty((<object>left).shape, dtype="bool")
ndarray result = cnp.PyArray_EMPTY(
left.ndim, left.shape, cnp.NPY_BOOL, 0
)

ndarray lvalues = left.view("i8")
ndarray rvalues = right.view("i8")

cnp.broadcast mi = cnp.PyArray_MultiIterNew3(result, lvalues, rvalues)
int64_t lval, rval
bint res_value

Py_ssize_t i, N = left.size
npy_datetimestruct ldts, rdts


for i in range(N):
# Analogous to: lval = lvalues[i]
lval = (<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 1))[0]

# Analogous to: rval = rvalues[i]
rval = (<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 2))[0]

if lval == NPY_DATETIME_NAT or rval == NPY_DATETIME_NAT:
res_value = op_code == Py_NE

else:
pandas_datetime_to_datetimestruct(lval, left_unit, &ldts)
pandas_datetime_to_datetimestruct(rval, right_unit, &rdts)

res_value = cmp_dtstructs(&ldts, &rdts, op_code)

# Analogous to: result[i] = res_value
(<uint8_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_value

cnp.PyArray_MultiIter_NEXT(mi)

return result


import operator


cdef int op_to_op_code(op):
# TODO: should exist somewhere?
if op is operator.eq:
return Py_EQ
if op is operator.ne:
return Py_NE
if op is operator.le:
return Py_LE
if op is operator.lt:
return Py_LT
if op is operator.ge:
return Py_GE
if op is operator.gt:
return Py_GT
19 changes: 19 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
RoundTo,
round_nsint64,
)
from pandas._libs.tslibs.np_datetime import compare_mismatched_resolutions
from pandas._libs.tslibs.timestamps import integer_op_not_supported
from pandas._typing import (
ArrayLike,
Expand Down Expand Up @@ -1065,6 +1066,24 @@ def _cmp_method(self, other, op):
)
return result

if other is NaT:
if op is operator.ne:
result = np.ones(self.shape, dtype=bool)
else:
result = np.zeros(self.shape, dtype=bool)
return result

if not is_period_dtype(self.dtype):
self = cast(TimelikeOps, self)
if self._reso != other._reso:
if not isinstance(other, type(self)):
# i.e. Timedelta/Timestamp, cast to ndarray and let
# compare_mismatched_resolutions handle broadcasting
other_arr = np.array(other.asm8)
else:
other_arr = other._ndarray
return compare_mismatched_resolutions(self._ndarray, other_arr, op)

other_vals = self._unbox(other)
# GH#37462 comparison on i8 values is almost 2x faster than M8/m8
result = op(self._ndarray.view("i8"), other_vals.view("i8"))
Expand Down
38 changes: 38 additions & 0 deletions pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Tests for DatetimeArray
"""
import operator

import numpy as np
import pytest

Expand Down Expand Up @@ -169,6 +171,42 @@ def test_repr(self, dta_dti, unit):

assert repr(dta) == repr(dti._data).replace("[ns", f"[{unit}")

# TODO: tests with td64
def test_compare_mismatched_resolutions(self, comparison_op):
# comparison that numpy gets wrong bc of silent overflows
op = comparison_op

iinfo = np.iinfo(np.int64)
vals = np.array([iinfo.min, iinfo.min + 1, iinfo.max], dtype=np.int64)

# Construct so that arr2[1] < arr[1] < arr[2] < arr2[2]
arr = np.array(vals).view("M8[ns]")
arr2 = arr.view("M8[s]")

left = DatetimeArray._simple_new(arr, dtype=arr.dtype)
right = DatetimeArray._simple_new(arr2, dtype=arr2.dtype)

if comparison_op is operator.eq:
expected = np.array([False, False, False])
elif comparison_op is operator.ne:
expected = np.array([True, True, True])
elif comparison_op in [operator.lt, operator.le]:
expected = np.array([False, False, True])
else:
expected = np.array([False, True, False])

result = op(left, right)
tm.assert_numpy_array_equal(result, expected)

result = op(left[1], right)
tm.assert_numpy_array_equal(result, expected)

if op not in [operator.eq, operator.ne]:
# check that numpy still gets this wrong; if it is fixed we may be
# able to remove compare_mismatched_resolutions
np_res = op(left._ndarray, right._ndarray)
tm.assert_numpy_array_equal(np_res[1:], ~expected[1:])


class TestDatetimeArrayComparisons:
# TODO: merge this into tests/arithmetic/test_datetime64 once it is
Expand Down