diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 03898ae1d2a..3f1b2c59b0b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,7 +25,8 @@ New functions/methods ~~~~~~~~~~~~~~~~~~~~~ - Added :py:meth:`DataArray.broadcast_like` and :py:meth:`Dataset.broadcast_like`. - By `Deepak Cherian `_. + By `Deepak Cherian `_ and `David Mertz + `_. Enhancements ~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index aa16a04ec12..99014e9efe1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -998,17 +998,58 @@ def sel_points(self, dim='points', method=None, tolerance=None, def broadcast_like(self, other: Union['DataArray', Dataset], exclude=None) -> 'DataArray': - """Broadcast this DataArray against another Dataset or DataArray. + """Broadcast a DataArray to the shape of another DataArray or Dataset + This is equivalent to xr.broadcast(other, self)[1] + xarray objects are broadcast against each other in arithmetic + operations, so this method is not be necessary for most uses. + + If no change is needed, the input data is returned to the output + without being copied. + + If new coords are added by the broadcast, their values are + NaN filled. + Parameters ---------- other : Dataset or DataArray Object against which to broadcast this array. + exclude : sequence of str, optional Dimensions that must not be broadcasted - """ + Returns + ------- + new_da: xr.DataArray + + Examples + -------- + + >>> arr1 + + array([[0.840235, 0.215216, 0.77917 ], + [0.726351, 0.543824, 0.875115]]) + Coordinates: + * x (x) >> arr2 + + array([[0.612611, 0.125753], + [0.853181, 0.948818], + [0.180885, 0.33363 ]]) + Coordinates: + * x (x) >> arr1.broadcast_like(arr2) + + array([[0.840235, 0.215216, 0.77917 ], + [0.726351, 0.543824, 0.875115], + [ nan, nan, nan]]) + Coordinates: + * x (x) object 'a' 'b' 'c' + * y (y) object 'a' 'b' 'c' + """ if exclude is None: exclude = set() args = align(other, self, join='outer', copy=False, exclude=exclude) @@ -1016,7 +1057,7 @@ def broadcast_like(self, dims_map, common_coords = _get_broadcast_dims_map_common_coords( args, exclude) - return _broadcast_helper(self, exclude, dims_map, common_coords) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def reindex_like(self, other: Union['DataArray', Dataset], method: Optional[str] = None, tolerance=None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 775fdd497b6..d00ea1e4acd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2039,7 +2039,7 @@ def broadcast_like(self, dims_map, common_coords = _get_broadcast_dims_map_common_coords( args, exclude) - return _broadcast_helper(self, exclude, dims_map, common_coords) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def reindex_like(self, other, method=None, tolerance=None, copy=True, fill_value=dtypes.NA): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 63317519bc7..54b22fd336d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1256,19 +1256,23 @@ def test_coords_non_string(self): assert_identical(actual, expected) def test_broadcast_like(self): - original1 = DataArray(np.random.randn(5), - [('x', range(5))]) - - original2 = DataArray(np.random.randn(6), - [('y', range(6))]) - - expected1, expected2 = broadcast(original1, original2) - - assert_identical(original1.broadcast_like(original2), - expected1.transpose('y', 'x')) - - assert_identical(original2.broadcast_like(original1), - expected2) + arr1 = DataArray(np.ones((2, 3)), dims=['x', 'y'], + coords={'x': ['a', 'b'], 'y': ['a', 'b', 'c']}) + arr2 = DataArray(np.ones((3, 2)), dims=['x', 'y'], + coords={'x': ['a', 'b', 'c'], 'y': ['a', 'b']}) + orig1, orig2 = broadcast(arr1, arr2) + new1 = arr1.broadcast_like(arr2) + new2 = arr2.broadcast_like(arr1) + + assert orig1.identical(new1) + assert orig2.identical(new2) + + orig3 = DataArray(np.random.randn(5), [('x', range(5))]) + orig4 = DataArray(np.random.randn(6), [('y', range(6))]) + new3, new4 = broadcast(orig3, orig4) + + assert_identical(orig3.broadcast_like(orig4), new3.transpose('y', 'x')) + assert_identical(orig4.broadcast_like(orig3), new4) def test_reindex_like(self): foo = DataArray(np.random.randn(5, 6),