|
7 | 7 | from . import (
|
8 | 8 | assert_array_equal,
|
9 | 9 | assert_equal,
|
| 10 | + assert_identical, |
10 | 11 | raises_regex,
|
11 | 12 | requires_cftime,
|
12 | 13 | requires_dask,
|
@@ -435,3 +436,106 @@ def test_seasons(cftime_date_type):
|
435 | 436 | seasons = xr.DataArray(seasons)
|
436 | 437 |
|
437 | 438 | assert_array_equal(seasons.values, dates.dt.season.values)
|
| 439 | + |
| 440 | + |
| 441 | +@pytest.fixture |
| 442 | +def cftime_rounding_dataarray(cftime_date_type): |
| 443 | + return xr.DataArray( |
| 444 | + [ |
| 445 | + [cftime_date_type(1, 1, 1, 1), cftime_date_type(1, 1, 1, 15)], |
| 446 | + [cftime_date_type(1, 1, 1, 23), cftime_date_type(1, 1, 2, 1)], |
| 447 | + ] |
| 448 | + ) |
| 449 | + |
| 450 | + |
| 451 | +@requires_cftime |
| 452 | +@requires_dask |
| 453 | +@pytest.mark.parametrize("use_dask", [False, True]) |
| 454 | +def test_cftime_floor_accessor(cftime_rounding_dataarray, cftime_date_type, use_dask): |
| 455 | + import dask.array as da |
| 456 | + |
| 457 | + freq = "D" |
| 458 | + expected = xr.DataArray( |
| 459 | + [ |
| 460 | + [cftime_date_type(1, 1, 1, 0), cftime_date_type(1, 1, 1, 0)], |
| 461 | + [cftime_date_type(1, 1, 1, 0), cftime_date_type(1, 1, 2, 0)], |
| 462 | + ], |
| 463 | + name="floor", |
| 464 | + ) |
| 465 | + |
| 466 | + if use_dask: |
| 467 | + chunks = {"dim_0": 1} |
| 468 | + # Currently a compute is done to inspect a single value of the array |
| 469 | + # if it is of object dtype to check if it is a cftime.datetime (if not |
| 470 | + # we raise an error when using the dt accessor). |
| 471 | + with raise_if_dask_computes(max_computes=1): |
| 472 | + result = cftime_rounding_dataarray.chunk(chunks).dt.floor(freq) |
| 473 | + expected = expected.chunk(chunks) |
| 474 | + assert isinstance(result.data, da.Array) |
| 475 | + assert result.chunks == expected.chunks |
| 476 | + else: |
| 477 | + result = cftime_rounding_dataarray.dt.floor(freq) |
| 478 | + |
| 479 | + assert_identical(result, expected) |
| 480 | + |
| 481 | + |
| 482 | +@requires_cftime |
| 483 | +@requires_dask |
| 484 | +@pytest.mark.parametrize("use_dask", [False, True]) |
| 485 | +def test_cftime_ceil_accessor(cftime_rounding_dataarray, cftime_date_type, use_dask): |
| 486 | + import dask.array as da |
| 487 | + |
| 488 | + freq = "D" |
| 489 | + expected = xr.DataArray( |
| 490 | + [ |
| 491 | + [cftime_date_type(1, 1, 2, 0), cftime_date_type(1, 1, 2, 0)], |
| 492 | + [cftime_date_type(1, 1, 2, 0), cftime_date_type(1, 1, 3, 0)], |
| 493 | + ], |
| 494 | + name="ceil", |
| 495 | + ) |
| 496 | + |
| 497 | + if use_dask: |
| 498 | + chunks = {"dim_0": 1} |
| 499 | + # Currently a compute is done to inspect a single value of the array |
| 500 | + # if it is of object dtype to check if it is a cftime.datetime (if not |
| 501 | + # we raise an error when using the dt accessor). |
| 502 | + with raise_if_dask_computes(max_computes=1): |
| 503 | + result = cftime_rounding_dataarray.chunk(chunks).dt.ceil(freq) |
| 504 | + expected = expected.chunk(chunks) |
| 505 | + assert isinstance(result.data, da.Array) |
| 506 | + assert result.chunks == expected.chunks |
| 507 | + else: |
| 508 | + result = cftime_rounding_dataarray.dt.ceil(freq) |
| 509 | + |
| 510 | + assert_identical(result, expected) |
| 511 | + |
| 512 | + |
| 513 | +@requires_cftime |
| 514 | +@requires_dask |
| 515 | +@pytest.mark.parametrize("use_dask", [False, True]) |
| 516 | +def test_cftime_round_accessor(cftime_rounding_dataarray, cftime_date_type, use_dask): |
| 517 | + import dask.array as da |
| 518 | + |
| 519 | + freq = "D" |
| 520 | + expected = xr.DataArray( |
| 521 | + [ |
| 522 | + [cftime_date_type(1, 1, 1, 0), cftime_date_type(1, 1, 2, 0)], |
| 523 | + [cftime_date_type(1, 1, 2, 0), cftime_date_type(1, 1, 2, 0)], |
| 524 | + ], |
| 525 | + name="round", |
| 526 | + ) |
| 527 | + |
| 528 | + if use_dask: |
| 529 | + chunks = {"dim_0": 1} |
| 530 | + # Currently a compute is done to inspect a single value of the array |
| 531 | + # if it is of object dtype to check if it is a cftime.datetime (if not |
| 532 | + # we raise an error when using the dt accessor). |
| 533 | + with raise_if_dask_computes(max_computes=1): |
| 534 | + result = cftime_rounding_dataarray.chunk(chunks).dt.round(freq) |
| 535 | + expected = expected.chunk(chunks) |
| 536 | + assert isinstance(result.data, da.Array) |
| 537 | + assert result.chunks == expected.chunks |
| 538 | + else: |
| 539 | + result = cftime_rounding_dataarray.dt.round(freq) |
| 540 | + |
| 541 | + assert_identical(result, expected) |
0 commit comments