diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bc0e5092d5b..37d04cf7f1a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -93,6 +93,12 @@ Internal Changes - Changed test_open_mfdataset_list_attr to only run with dask installed (:issue:`3777`, :pull:`3780`). By `Bruno Pagani `_. +- Preserved the ability to index with ``method="nearest"`` with a + :py:class:`CFTimeIndex` with pandas versions greater than 1.0.1 + (:issue:`3751`). By `Spencer Clark `_. +- Greater flexibility and improved test coverage of subtracting various types + of objects from a :py:class:`CFTimeIndex`. By `Spencer Clark + `_. - Updated Azure CI MacOS image, given pending removal. By `Maximilian Roos `_ - Removed xfails for scipy 1.0.1 for tests that append to netCDF files (:pull:`3805`). diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 99f90430e91..1ea5d3a7d11 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -49,6 +49,7 @@ from xarray.core.utils import is_scalar +from ..core.common import _contains_cftime_datetimes from .times import _STANDARD_CALENDARS, cftime_to_nptime, infer_calendar_name @@ -326,6 +327,32 @@ def _get_string_slice(self, key): raise KeyError(key) return loc + def _get_nearest_indexer(self, target, limit, tolerance): + """Adapted from pandas.Index._get_nearest_indexer""" + left_indexer = self.get_indexer(target, "pad", limit=limit) + right_indexer = self.get_indexer(target, "backfill", limit=limit) + left_distances = abs(self.values[left_indexer] - target.values) + right_distances = abs(self.values[right_indexer] - target.values) + + if self.is_monotonic_increasing: + condition = (left_distances < right_distances) | (right_indexer == -1) + else: + condition = (left_distances <= right_distances) | (right_indexer == -1) + indexer = np.where(condition, left_indexer, right_indexer) + + if tolerance is not None: + indexer = self._filter_indexer_tolerance(target, indexer, tolerance) + return indexer + + def _filter_indexer_tolerance(self, target, indexer, tolerance): + """Adapted from pandas.Index._filter_indexer_tolerance""" + if isinstance(target, pd.Index): + distance = abs(self.values[indexer] - target.values) + else: + distance = abs(self.values[indexer] - target) + indexer = np.where(distance <= tolerance, indexer, -1) + return indexer + def get_loc(self, key, method=None, tolerance=None): """Adapted from pandas.tseries.index.DatetimeIndex.get_loc""" if isinstance(key, str): @@ -427,9 +454,11 @@ def __radd__(self, other): return CFTimeIndex(other + np.array(self)) def __sub__(self, other): - import cftime - - if isinstance(other, (CFTimeIndex, cftime.datetime)): + if _contains_datetime_timedeltas(other): + return CFTimeIndex(np.array(self) - other) + elif isinstance(other, pd.TimedeltaIndex): + return CFTimeIndex(np.array(self) - other.to_pytimedelta()) + elif _contains_cftime_datetimes(np.array(other)): try: return pd.TimedeltaIndex(np.array(self) - np.array(other)) except OverflowError: @@ -437,14 +466,17 @@ def __sub__(self, other): "The time difference exceeds the range of values " "that can be expressed at the nanosecond resolution." ) - - elif isinstance(other, pd.TimedeltaIndex): - return CFTimeIndex(np.array(self) - other.to_pytimedelta()) else: - return CFTimeIndex(np.array(self) - other) + return NotImplemented def __rsub__(self, other): - return pd.TimedeltaIndex(other - np.array(self)) + try: + return pd.TimedeltaIndex(other - np.array(self)) + except OverflowError: + raise ValueError( + "The time difference exceeds the range of values " + "that can be expressed at the nanosecond resolution." + ) def to_datetimeindex(self, unsafe=False): """If possible, convert this index to a pandas.DatetimeIndex. @@ -633,6 +665,12 @@ def _parse_array_of_cftime_strings(strings, date_type): ).reshape(strings.shape) +def _contains_datetime_timedeltas(array): + """Check if an input array contains datetime.timedelta objects.""" + array = np.atleast_1d(array) + return isinstance(array[0], timedelta) + + def _cftimeindex_from_i8(values, date_type, name): """Construct a CFTimeIndex from an array of integers. diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 8d83b833ca3..43d6d7b068e 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -451,10 +451,21 @@ def test_sel_date_scalar(da, date_type, index): @pytest.mark.xfail(reason="https://github.com/pydata/xarray/issues/3751") +@requires_cftime +def test_sel_date_distant_date(da, date_type, index): + expected = xr.DataArray(4).assign_coords(time=index[3]) + result = da.sel(time=date_type(2000, 1, 1), method="nearest") + assert_identical(result, expected) + + @requires_cftime @pytest.mark.parametrize( "sel_kwargs", - [{"method": "nearest"}, {"method": "nearest", "tolerance": timedelta(days=70)}], + [ + {"method": "nearest"}, + {"method": "nearest", "tolerance": timedelta(days=70)}, + {"method": "nearest", "tolerance": timedelta(days=1800000)}, + ], ) def test_sel_date_scalar_nearest(da, date_type, index, sel_kwargs): expected = xr.DataArray(2).assign_coords(time=index[1]) @@ -738,7 +749,7 @@ def test_timedeltaindex_add_cftimeindex(calendar): @requires_cftime -def test_cftimeindex_sub(index): +def test_cftimeindex_sub_timedelta(index): date_type = index.date_type expected_dates = [ date_type(1, 1, 2), @@ -753,6 +764,27 @@ def test_cftimeindex_sub(index): assert isinstance(result, CFTimeIndex) +@requires_cftime +@pytest.mark.parametrize( + "other", + [np.array(4 * [timedelta(days=1)]), np.array(timedelta(days=1))], + ids=["1d-array", "scalar-array"], +) +def test_cftimeindex_sub_timedelta_array(index, other): + date_type = index.date_type + expected_dates = [ + date_type(1, 1, 2), + date_type(1, 2, 2), + date_type(2, 1, 2), + date_type(2, 2, 2), + ] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - other + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + @requires_cftime @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftimeindex_sub_cftimeindex(calendar): @@ -784,6 +816,14 @@ def test_cftime_datetime_sub_cftimeindex(calendar): assert isinstance(result, pd.TimedeltaIndex) +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_distant_cftime_datetime_sub_cftimeindex(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + with pytest.raises(ValueError, match="difference exceeds"): + a.date_type(1, 1, 1) - a + + @requires_cftime @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) def test_cftimeindex_sub_timedeltaindex(calendar): @@ -795,6 +835,25 @@ def test_cftimeindex_sub_timedeltaindex(calendar): assert isinstance(result, CFTimeIndex) +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_index_of_cftime_datetimes(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + b = pd.Index(a.values) + expected = a - a + result = a - b + assert result.equals(expected) + assert isinstance(result, pd.TimedeltaIndex) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_sub_not_implemented(calendar): + a = xr.cftime_range("2000", periods=5, calendar=calendar) + with pytest.raises(TypeError, match="unsupported operand"): + a - 1 + + @requires_cftime def test_cftimeindex_rsub(index): with pytest.raises(TypeError):