Skip to content

Commit 95a47af

Browse files
dcherianmax-sixty
andauthored
Support dask arrays in datetime_to_numeric (#6556)
Co-authored-by: Maximilian Roos <[email protected]>
1 parent 4615074 commit 95a47af

File tree

2 files changed

+58
-13
lines changed

2 files changed

+58
-13
lines changed

xarray/core/duck_array_ops.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,14 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
431431
# Compute timedelta object.
432432
# For np.datetime64, this can silently yield garbage due to overflow.
433433
# One option is to enforce 1970-01-01 as the universal offset.
434-
array = array - offset
434+
435+
# This map_blocks call is for backwards compatibility.
436+
# dask == 2021.04.1 does not support subtracting object arrays
437+
# which is required for cftime
438+
if is_duck_dask_array(array) and np.issubdtype(array.dtype, np.object):
439+
array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
440+
else:
441+
array = array - offset
435442

436443
# Scalar is converted to 0d-array
437444
if not hasattr(array, "dtype"):
@@ -517,10 +524,19 @@ def pd_timedelta_to_float(value, datetime_unit):
517524
return np_timedelta64_to_float(value, datetime_unit)
518525

519526

527+
def _timedelta_to_seconds(array):
528+
return np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6
529+
530+
520531
def py_timedelta_to_float(array, datetime_unit):
521532
"""Convert a timedelta object to a float, possibly at a loss of resolution."""
522-
array = np.asarray(array)
523-
array = np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6
533+
array = asarray(array)
534+
if is_duck_dask_array(array):
535+
array = array.map_blocks(
536+
_timedelta_to_seconds, meta=np.array([], dtype=np.float64)
537+
)
538+
else:
539+
array = _timedelta_to_seconds(array)
524540
conversion_factor = np.timedelta64(1, "us") / np.timedelta64(1, datetime_unit)
525541
return conversion_factor * array
526542

xarray/tests/test_duck_array_ops.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -675,39 +675,68 @@ def test_multiple_dims(dtype, dask, skipna, func):
675675
assert_allclose(actual, expected)
676676

677677

678-
def test_datetime_to_numeric_datetime64():
678+
@pytest.mark.parametrize("dask", [True, False])
679+
def test_datetime_to_numeric_datetime64(dask):
680+
if dask and not has_dask:
681+
pytest.skip("requires dask")
682+
679683
times = pd.date_range("2000", periods=5, freq="7D").values
680-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h")
684+
if dask:
685+
import dask.array
686+
687+
times = dask.array.from_array(times, chunks=-1)
688+
689+
with raise_if_dask_computes():
690+
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h")
681691
expected = 24 * np.arange(0, 35, 7)
682692
np.testing.assert_array_equal(result, expected)
683693

684694
offset = times[1]
685-
result = duck_array_ops.datetime_to_numeric(times, offset=offset, datetime_unit="h")
695+
with raise_if_dask_computes():
696+
result = duck_array_ops.datetime_to_numeric(
697+
times, offset=offset, datetime_unit="h"
698+
)
686699
expected = 24 * np.arange(-7, 28, 7)
687700
np.testing.assert_array_equal(result, expected)
688701

689702
dtype = np.float32
690-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype)
703+
with raise_if_dask_computes():
704+
result = duck_array_ops.datetime_to_numeric(
705+
times, datetime_unit="h", dtype=dtype
706+
)
691707
expected = 24 * np.arange(0, 35, 7).astype(dtype)
692708
np.testing.assert_array_equal(result, expected)
693709

694710

695711
@requires_cftime
696-
def test_datetime_to_numeric_cftime():
712+
@pytest.mark.parametrize("dask", [True, False])
713+
def test_datetime_to_numeric_cftime(dask):
714+
if dask and not has_dask:
715+
pytest.skip("requires dask")
716+
697717
times = cftime_range("2000", periods=5, freq="7D", calendar="standard").values
698-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=int)
718+
if dask:
719+
import dask.array
720+
721+
times = dask.array.from_array(times, chunks=-1)
722+
with raise_if_dask_computes():
723+
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=int)
699724
expected = 24 * np.arange(0, 35, 7)
700725
np.testing.assert_array_equal(result, expected)
701726

702727
offset = times[1]
703-
result = duck_array_ops.datetime_to_numeric(
704-
times, offset=offset, datetime_unit="h", dtype=int
705-
)
728+
with raise_if_dask_computes():
729+
result = duck_array_ops.datetime_to_numeric(
730+
times, offset=offset, datetime_unit="h", dtype=int
731+
)
706732
expected = 24 * np.arange(-7, 28, 7)
707733
np.testing.assert_array_equal(result, expected)
708734

709735
dtype = np.float32
710-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype)
736+
with raise_if_dask_computes():
737+
result = duck_array_ops.datetime_to_numeric(
738+
times, datetime_unit="h", dtype=dtype
739+
)
711740
expected = 24 * np.arange(0, 35, 7).astype(dtype)
712741
np.testing.assert_array_equal(result, expected)
713742

0 commit comments

Comments
 (0)