From e0f8eb17f02215f1c9a4c02cf9e313911bb71a06 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 26 Oct 2022 17:33:29 -0400 Subject: [PATCH 01/10] better tests, use modified attrs[1] --- xarray/core/computation.py | 2 +- xarray/tests/test_computation.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6ec38453a4b..f300ab9ca72 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1860,7 +1860,7 @@ def where(cond, x, y, keep_attrs=None): if keep_attrs is True: # keep the attributes of x, the second parameter, by default to # be consistent with the `where` method of `DataArray` and `Dataset` - keep_attrs = lambda attrs, context: getattr(x, "attrs", {}) + keep_attrs = lambda attrs, context: attrs[1] if len(attrs) > 1 else {} # alignment for three arguments is complicated, so don't support it yet return apply_ufunc( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index d93adf08474..d3ff6571a9f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1923,16 +1923,36 @@ def test_where() -> None: def test_where_attrs() -> None: - cond = xr.DataArray([True, False], dims="x", attrs={"attr": "cond"}) - x = xr.DataArray([1, 1], dims="x", attrs={"attr": "x"}) - y = xr.DataArray([0, 0], dims="x", attrs={"attr": "y"}) + cond = xr.DataArray([True, False], coords={"x": [0, 1]}, attrs={"attr": "cond_da"}) + cond["x"].attrs = {"attr": "cond_coord"} + x = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) + x["x"].attrs = {"attr": "x_coord"} + y = xr.DataArray([0, 0], coords={"x": [0, 1]}, attrs={"attr": "y_da"}) + y["x"].attrs = {"attr": "y_coord"} + + # 3 DataArrays, takes attrs from x actual = xr.where(cond, x, y, keep_attrs=True) - expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"}) + expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) + expected["x"].attrs = {"attr": "x_coord"} assert_identical(expected, actual) - # ensure keep_attrs can handle scalar values + # x as a scalar, takes attrs from y + actual = xr.where(cond, 0, y, keep_attrs=True) + expected = xr.DataArray([0, 0], coords={"x": [0, 1]}, attrs={"attr": "y_da"}) + expected["x"].attrs = {"attr": "y_coord"} + assert_identical(expected, actual) + + # y as a scalar, takes attrs from x + actual = xr.where(cond, x, 0, keep_attrs=True) + expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) + expected["x"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + + # x and y as a scalar, takes coord attrs only from cond actual = xr.where(cond, 1, 0, keep_attrs=True) - assert actual.attrs == {} + expected = xr.DataArray([1, 0], coords={"x": [0, 1]}) + expected["x"].attrs = {"attr": "cond_coord"} + assert_identical(expected, actual) @pytest.mark.parametrize( From a7d2611841f6a5e19f4b48066a6e5a39ad01d428 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 26 Oct 2022 17:45:45 -0400 Subject: [PATCH 02/10] add whats new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 37ea949ab9d..5adeb2a6326 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -57,6 +57,8 @@ Bug fixes now reopens the file from scratch for h5netcdf and scipy netCDF backends, rather than reusing a cached version (:issue:`4240`, :issue:`4862`). By `Stephan Hoyer `_. +- Fix handling of coordinate attributes in ``xarray.where``. (:issue:`7220`, :pull:`7229`) + By `Sam Levang `_. Documentation ~~~~~~~~~~~~~ From b426425f539d13df510fe3755b0e8779ff34a2a1 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 27 Oct 2022 09:16:15 -0400 Subject: [PATCH 03/10] update keep_attrs docstring --- xarray/core/computation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f300ab9ca72..f8d34b5930a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1789,7 +1789,8 @@ def where(cond, x, y, keep_attrs=None): y : scalar, array, Variable, DataArray or Dataset values to choose from where `cond` is False keep_attrs : bool or str or callable, optional - How to treat attrs. If True, keep the attrs of `x`. + How to treat attrs. If True, keep the attrs of `x`, + unless `x` is a scalar, then keep the attrs of `y`. Returns ------- From 5fc7e3168f7e401f4371c61d053e8c808f86027d Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 7 Nov 2022 20:34:37 -0500 Subject: [PATCH 04/10] cast to DataArray --- xarray/core/computation.py | 14 ++++++++++---- xarray/tests/test_computation.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f8d34b5930a..9347d202398 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -25,7 +25,7 @@ import numpy as np from . import dtypes, duck_array_ops, utils -from .alignment import align, deep_align +from .alignment import align, broadcast, deep_align from .common import zeros_like from .duck_array_ops import datetime_to_numeric from .indexes import Index, filter_indexes_from_coords @@ -1789,8 +1789,7 @@ def where(cond, x, y, keep_attrs=None): y : scalar, array, Variable, DataArray or Dataset values to choose from where `cond` is False keep_attrs : bool or str or callable, optional - How to treat attrs. If True, keep the attrs of `x`, - unless `x` is a scalar, then keep the attrs of `y`. + How to treat attrs. If True, keep the attrs of `x`. Returns ------- @@ -1856,12 +1855,19 @@ def where(cond, x, y, keep_attrs=None): Dataset.where, DataArray.where : equivalent methods """ + from .dataarray import DataArray + if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) if keep_attrs is True: # keep the attributes of x, the second parameter, by default to # be consistent with the `where` method of `DataArray` and `Dataset` - keep_attrs = lambda attrs, context: attrs[1] if len(attrs) > 1 else {} + keep_attrs = lambda attrs, context: attrs[1] + # cast non-xarray objects to DataArray to get empty attrs + cond, x, y = (v if hasattr(v, "attrs") else DataArray(v) for v in [cond, x, y]) + # explicitly broadcast to ensure we also get empty coord attrs + # take coord attrs preferentially from x, then y, then cond + x, y, cond = broadcast(x, y, cond) # alignment for three arguments is complicated, so don't support it yet return apply_ufunc( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index f19c6978979..aee2a4e4840 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1938,9 +1938,9 @@ def test_where_attrs() -> None: expected["x"].attrs = {"attr": "x_coord"} assert_identical(expected, actual) - # x as a scalar, takes attrs from y + # x as a scalar, takes coord attrs only from y actual = xr.where(cond, 0, y, keep_attrs=True) - expected = xr.DataArray([0, 0], coords={"x": [0, 1]}, attrs={"attr": "y_da"}) + expected = xr.DataArray([0, 0], coords={"x": [0, 1]}) expected["x"].attrs = {"attr": "y_coord"} assert_identical(expected, actual) @@ -1956,6 +1956,12 @@ def test_where_attrs() -> None: expected["x"].attrs = {"attr": "cond_coord"} assert_identical(expected, actual) + # cond and y as a scalar, takes attrs from x + actual = xr.where(True, x, y, keep_attrs=True) + expected = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) + expected["x"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + @pytest.mark.parametrize( "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] From 989e5c3c473273a04a93bb314f53c617f359846e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 7 Nov 2022 20:36:19 -0500 Subject: [PATCH 05/10] whats-new --- doc/whats-new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4682e1370b0..c620c9fa48a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,7 +34,8 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Fix handling of coordinate attributes in ``xarray.where``. (:issue:`7220`, :pull:`7229`) + By `Sam Levang `_. Documentation ~~~~~~~~~~~~~ From a6ba8ecd7aa76d83de3252b876aedf5b66c98d79 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 9 Nov 2022 08:12:31 -0500 Subject: [PATCH 06/10] fix whats new --- doc/whats-new.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c620c9fa48a..b4621175014 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -97,8 +97,6 @@ Bug fixes now reopens the file from scratch for h5netcdf and scipy netCDF backends, rather than reusing a cached version (:issue:`4240`, :issue:`4862`). By `Stephan Hoyer `_. -- Fix handling of coordinate attributes in ``xarray.where``. (:issue:`7220`, :pull:`7229`) - By `Sam Levang `_. - Fixed bug where :py:meth:`Dataset.coarsen.construct` would demote non-dimension coordinates to variables. (:pull:`7233`) By `Tom Nicholas `_. - Raise a TypeError when trying to plot empty data (:issue:`7156`, :pull:`7228`). From 3b0336de851a0a140170a9acb0197da517035e9c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 16 Nov 2022 11:42:43 -0700 Subject: [PATCH 07/10] Update doc/whats-new.rst --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b4621175014..6df44a6c8cc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,7 +34,7 @@ Deprecations Bug fixes ~~~~~~~~~ -- Fix handling of coordinate attributes in ``xarray.where``. (:issue:`7220`, :pull:`7229`) +- Fix handling of coordinate attributes in :py:func:`where`. (:issue:`7220`, :pull:`7229`) By `Sam Levang `_. Documentation From 4cd4300a250f2b0f23c91c0e3e304f163341a785 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 16 Nov 2022 22:22:41 -0500 Subject: [PATCH 08/10] rebuild attrs after apply_ufunc --- xarray/core/computation.py | 33 +++++++++++------- xarray/tests/test_computation.py | 57 ++++++++++++++++++++++---------- 2 files changed, 60 insertions(+), 30 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0ff8411ac7c..f66eb424f33 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -25,7 +25,7 @@ import numpy as np from . import dtypes, duck_array_ops, utils -from .alignment import align, broadcast, deep_align +from .alignment import align, deep_align from .common import zeros_like from .duck_array_ops import datetime_to_numeric from .indexes import Index, filter_indexes_from_coords @@ -1855,22 +1855,13 @@ def where(cond, x, y, keep_attrs=None): Dataset.where, DataArray.where : equivalent methods """ - from .dataarray import DataArray + from .dataset import Dataset if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) - if keep_attrs is True: - # keep the attributes of x, the second parameter, by default to - # be consistent with the `where` method of `DataArray` and `Dataset` - keep_attrs = lambda attrs, context: attrs[1] - # cast non-xarray objects to DataArray to get empty attrs - cond, x, y = (v if hasattr(v, "attrs") else DataArray(v) for v in [cond, x, y]) - # explicitly broadcast to ensure we also get empty coord attrs - # take coord attrs preferentially from x, then y, then cond - x, y, cond = broadcast(x, y, cond) # alignment for three arguments is complicated, so don't support it yet - return apply_ufunc( + result = apply_ufunc( duck_array_ops.where, cond, x, @@ -1881,6 +1872,24 @@ def where(cond, x, y, keep_attrs=None): keep_attrs=keep_attrs, ) + # make sure we have the attrs of x across Dataset, DataArray, and coords + if keep_attrs is True: + if isinstance(y, Dataset) and not isinstance(x, Dataset): + # handle special case where x gets promoted to Dataset + result.attrs = {} + if getattr(x, "name", None) in result.data_vars: + result[x.name].attrs = getattr(x, "attrs", {}) + else: + # otherwise, fill in global attrs and variable attrs (if they exist) + result.attrs = getattr(x, "attrs", {}) + for v in getattr(result, "data_vars", []): + result[v].attrs = getattr(getattr(x, v, None), "attrs", {}) + for c in getattr(result, "coords", []): + # always fill coord attrs of x + result[c].attrs = getattr(getattr(x, c, None), "attrs", {}) + + return result + @overload def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index aee2a4e4840..c38eaaa1874 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1925,41 +1925,62 @@ def test_where() -> None: def test_where_attrs() -> None: - cond = xr.DataArray([True, False], coords={"x": [0, 1]}, attrs={"attr": "cond_da"}) - cond["x"].attrs = {"attr": "cond_coord"} - x = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) - x["x"].attrs = {"attr": "x_coord"} - y = xr.DataArray([0, 0], coords={"x": [0, 1]}, attrs={"attr": "y_da"}) - y["x"].attrs = {"attr": "y_coord"} + cond = xr.DataArray([True, False], coords={"a": [0, 1]}, attrs={"attr": "cond_da"}) + cond["a"].attrs = {"attr": "cond_coord"} + x = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + x["a"].attrs = {"attr": "x_coord"} + y = xr.DataArray([0, 0], coords={"a": [0, 1]}, attrs={"attr": "y_da"}) + y["a"].attrs = {"attr": "y_coord"} # 3 DataArrays, takes attrs from x actual = xr.where(cond, x, y, keep_attrs=True) - expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) - expected["x"].attrs = {"attr": "x_coord"} + expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + expected["a"].attrs = {"attr": "x_coord"} assert_identical(expected, actual) - # x as a scalar, takes coord attrs only from y + # x as a scalar, takes no attrs actual = xr.where(cond, 0, y, keep_attrs=True) - expected = xr.DataArray([0, 0], coords={"x": [0, 1]}) - expected["x"].attrs = {"attr": "y_coord"} + expected = xr.DataArray([0, 0], coords={"a": [0, 1]}) assert_identical(expected, actual) # y as a scalar, takes attrs from x actual = xr.where(cond, x, 0, keep_attrs=True) - expected = xr.DataArray([1, 0], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) - expected["x"].attrs = {"attr": "x_coord"} + expected = xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + expected["a"].attrs = {"attr": "x_coord"} assert_identical(expected, actual) - # x and y as a scalar, takes coord attrs only from cond + # x and y as a scalar, takes no attrs actual = xr.where(cond, 1, 0, keep_attrs=True) - expected = xr.DataArray([1, 0], coords={"x": [0, 1]}) - expected["x"].attrs = {"attr": "cond_coord"} + expected = xr.DataArray([1, 0], coords={"a": [0, 1]}) assert_identical(expected, actual) # cond and y as a scalar, takes attrs from x actual = xr.where(True, x, y, keep_attrs=True) - expected = xr.DataArray([1, 1], coords={"x": [0, 1]}, attrs={"attr": "x_da"}) - expected["x"].attrs = {"attr": "x_coord"} + expected = xr.DataArray([1, 1], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + expected["a"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + + # DataArray and 2 Datasets, takes attrs from x + ds_x = xr.Dataset(data_vars={"x": x}, attrs={"attr": "x_ds"}) + ds_y = xr.Dataset(data_vars={"x": y}, attrs={"attr": "y_ds"}) + actual = xr.where(cond, ds_x, ds_y, keep_attrs=True) + expected = xr.Dataset( + data_vars={ + "x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + }, + attrs={"attr": "x_ds"}, + ) + expected["a"].attrs = {"attr": "x_coord"} + assert_identical(expected, actual) + + # 2 DataArrays and 1 Dataset, takes attrs from x + actual = xr.where(cond, x.rename("x"), ds_y, keep_attrs=True) + expected = xr.Dataset( + data_vars={ + "x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) + }, + ) + expected["a"].attrs = {"attr": "x_coord"} assert_identical(expected, actual) From 1247b7bb8c31ae33058ef10bf995126459505213 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 17 Nov 2022 09:44:09 -0500 Subject: [PATCH 09/10] fix mypy --- xarray/tests/test_computation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c38eaaa1874..73889c362fe 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1963,25 +1963,25 @@ def test_where_attrs() -> None: # DataArray and 2 Datasets, takes attrs from x ds_x = xr.Dataset(data_vars={"x": x}, attrs={"attr": "x_ds"}) ds_y = xr.Dataset(data_vars={"x": y}, attrs={"attr": "y_ds"}) - actual = xr.where(cond, ds_x, ds_y, keep_attrs=True) - expected = xr.Dataset( + ds_actual = xr.where(cond, ds_x, ds_y, keep_attrs=True) + ds_expected = xr.Dataset( data_vars={ "x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) }, attrs={"attr": "x_ds"}, ) - expected["a"].attrs = {"attr": "x_coord"} - assert_identical(expected, actual) + ds_expected["a"].attrs = {"attr": "x_coord"} + assert_identical(ds_expected, ds_actual) # 2 DataArrays and 1 Dataset, takes attrs from x - actual = xr.where(cond, x.rename("x"), ds_y, keep_attrs=True) - expected = xr.Dataset( + ds_actual = xr.where(cond, x.rename("x"), ds_y, keep_attrs=True) + ds_expected = xr.Dataset( data_vars={ "x": xr.DataArray([1, 0], coords={"a": [0, 1]}, attrs={"attr": "x_da"}) }, ) - expected["a"].attrs = {"attr": "x_coord"} - assert_identical(expected, actual) + ds_expected["a"].attrs = {"attr": "x_coord"} + assert_identical(ds_expected, ds_actual) @pytest.mark.parametrize( From bfffd7b7d48af7b0381d3d4f535d3f5796af8a01 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 17 Nov 2022 15:22:09 -0500 Subject: [PATCH 10/10] better comment --- xarray/core/computation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f66eb424f33..e7445c0b397 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1872,7 +1872,10 @@ def where(cond, x, y, keep_attrs=None): keep_attrs=keep_attrs, ) - # make sure we have the attrs of x across Dataset, DataArray, and coords + # keep the attributes of x, the second parameter, by default to + # be consistent with the `where` method of `DataArray` and `Dataset` + # rebuild the attrs from x at each level of the output, which could be + # Dataset, DataArray, or Variable, and also handle coords if keep_attrs is True: if isinstance(y, Dataset) and not isinstance(x, Dataset): # handle special case where x gets promoted to Dataset