Skip to content

Commit 9438390

Browse files
dcherianshoyer
authored andcommitted
Add broadcast_like. (#3086)
* Add broadcast_like. Closes #2885 * lint. * lint2 * Use a helper function * lint
1 parent c4497ff commit 9438390

File tree

7 files changed

+143
-44
lines changed

7 files changed

+143
-44
lines changed

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ Reshaping and reorganizing
212212
Dataset.shift
213213
Dataset.roll
214214
Dataset.sortby
215+
Dataset.broadcast_like
215216

216217
DataArray
217218
=========
@@ -386,6 +387,7 @@ Reshaping and reorganizing
386387
DataArray.shift
387388
DataArray.roll
388389
DataArray.sortby
390+
DataArray.broadcast_like
389391

390392
.. _api.ufuncs:
391393

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ This release increases the minimum required Python version from 3.5.0 to 3.5.3
2424
New functions/methods
2525
~~~~~~~~~~~~~~~~~~~~~
2626

27+
- Added :py:meth:`DataArray.broadcast_like` and :py:meth:`Dataset.broadcast_like`.
28+
By `Deepak Cherian <https://github.com/dcherian>`_.
29+
2730
Enhancements
2831
~~~~~~~~~~~~
2932

@@ -54,6 +57,7 @@ New functions/methods
5457
(:issue:`3026`).
5558
By `Julia Kent <https://github.com/jukent>`_.
5659

60+
5761
Enhancements
5862
~~~~~~~~~~~~
5963

xarray/core/alignment.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,58 @@ def reindex_variables(
391391
return reindexed, new_indexes
392392

393393

394+
def _get_broadcast_dims_map_common_coords(args, exclude):
395+
396+
common_coords = OrderedDict()
397+
dims_map = OrderedDict()
398+
for arg in args:
399+
for dim in arg.dims:
400+
if dim not in common_coords and dim not in exclude:
401+
dims_map[dim] = arg.sizes[dim]
402+
if dim in arg.coords:
403+
common_coords[dim] = arg.coords[dim].variable
404+
405+
return dims_map, common_coords
406+
407+
408+
def _broadcast_helper(arg, exclude, dims_map, common_coords):
409+
410+
from .dataarray import DataArray
411+
from .dataset import Dataset
412+
413+
def _set_dims(var):
414+
# Add excluded dims to a copy of dims_map
415+
var_dims_map = dims_map.copy()
416+
for dim in exclude:
417+
with suppress(ValueError):
418+
# ignore dim not in var.dims
419+
var_dims_map[dim] = var.shape[var.dims.index(dim)]
420+
421+
return var.set_dims(var_dims_map)
422+
423+
def _broadcast_array(array):
424+
data = _set_dims(array.variable)
425+
coords = OrderedDict(array.coords)
426+
coords.update(common_coords)
427+
return DataArray(data, coords, data.dims, name=array.name,
428+
attrs=array.attrs)
429+
430+
def _broadcast_dataset(ds):
431+
data_vars = OrderedDict(
432+
(k, _set_dims(ds.variables[k]))
433+
for k in ds.data_vars)
434+
coords = OrderedDict(ds.coords)
435+
coords.update(common_coords)
436+
return Dataset(data_vars, coords, ds.attrs)
437+
438+
if isinstance(arg, DataArray):
439+
return _broadcast_array(arg)
440+
elif isinstance(arg, Dataset):
441+
return _broadcast_dataset(arg)
442+
else:
443+
raise ValueError('all input must be Dataset or DataArray objects')
444+
445+
394446
def broadcast(*args, exclude=None):
395447
"""Explicitly broadcast any number of DataArray or Dataset objects against
396448
one another.
@@ -463,55 +515,16 @@ def broadcast(*args, exclude=None):
463515
a (x, y) int64 1 1 2 2 3 3
464516
b (x, y) int64 5 6 5 6 5 6
465517
"""
466-
from .dataarray import DataArray
467-
from .dataset import Dataset
468518

469519
if exclude is None:
470520
exclude = set()
471521
args = align(*args, join='outer', copy=False, exclude=exclude)
472522

473-
common_coords = OrderedDict()
474-
dims_map = OrderedDict()
475-
for arg in args:
476-
for dim in arg.dims:
477-
if dim not in common_coords and dim not in exclude:
478-
dims_map[dim] = arg.sizes[dim]
479-
if dim in arg.coords:
480-
common_coords[dim] = arg.coords[dim].variable
481-
482-
def _set_dims(var):
483-
# Add excluded dims to a copy of dims_map
484-
var_dims_map = dims_map.copy()
485-
for dim in exclude:
486-
with suppress(ValueError):
487-
# ignore dim not in var.dims
488-
var_dims_map[dim] = var.shape[var.dims.index(dim)]
489-
490-
return var.set_dims(var_dims_map)
491-
492-
def _broadcast_array(array):
493-
data = _set_dims(array.variable)
494-
coords = OrderedDict(array.coords)
495-
coords.update(common_coords)
496-
return DataArray(data, coords, data.dims, name=array.name,
497-
attrs=array.attrs)
498-
499-
def _broadcast_dataset(ds):
500-
data_vars = OrderedDict(
501-
(k, _set_dims(ds.variables[k]))
502-
for k in ds.data_vars)
503-
coords = OrderedDict(ds.coords)
504-
coords.update(common_coords)
505-
return Dataset(data_vars, coords, ds.attrs)
506-
523+
dims_map, common_coords = _get_broadcast_dims_map_common_coords(
524+
args, exclude)
507525
result = []
508526
for arg in args:
509-
if isinstance(arg, DataArray):
510-
result.append(_broadcast_array(arg))
511-
elif isinstance(arg, Dataset):
512-
result.append(_broadcast_dataset(arg))
513-
else:
514-
raise ValueError('all input must be Dataset or DataArray objects')
527+
result.append(_broadcast_helper(arg, exclude, dims_map, common_coords))
515528

516529
return tuple(result)
517530

xarray/core/dataarray.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
utils)
1616
from .accessor_dt import DatetimeAccessor
1717
from .accessor_str import StringAccessor
18-
from .alignment import align, reindex_like_indexers
18+
from .alignment import (align, _broadcast_helper,
19+
_get_broadcast_dims_map_common_coords,
20+
reindex_like_indexers)
1921
from .common import AbstractArray, DataWithCoords
2022
from .coordinates import (
2123
DataArrayCoordinates, LevelCoordinatesSource, assert_coordinate_consistent,
@@ -993,6 +995,29 @@ def sel_points(self, dim='points', method=None, tolerance=None,
993995
dim=dim, method=method, tolerance=tolerance, **indexers)
994996
return self._from_temp_dataset(ds)
995997

998+
def broadcast_like(self,
999+
other: Union['DataArray', Dataset],
1000+
exclude=None) -> 'DataArray':
1001+
"""Broadcast this DataArray against another Dataset or DataArray.
1002+
This is equivalent to xr.broadcast(other, self)[1]
1003+
1004+
Parameters
1005+
----------
1006+
other : Dataset or DataArray
1007+
Object against which to broadcast this array.
1008+
exclude : sequence of str, optional
1009+
Dimensions that must not be broadcasted
1010+
"""
1011+
1012+
if exclude is None:
1013+
exclude = set()
1014+
args = align(other, self, join='outer', copy=False, exclude=exclude)
1015+
1016+
dims_map, common_coords = _get_broadcast_dims_map_common_coords(
1017+
args, exclude)
1018+
1019+
return _broadcast_helper(self, exclude, dims_map, common_coords)
1020+
9961021
def reindex_like(self, other: Union['DataArray', Dataset],
9971022
method: Optional[str] = None, tolerance=None,
9981023
copy: bool = True, fill_value=dtypes.NA) -> 'DataArray':

xarray/core/dataset.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from ..coding.cftimeindex import _parse_array_of_cftime_strings
1818
from . import (alignment, dtypes, duck_array_ops, formatting, groupby,
1919
indexing, ops, pdcompat, resample, rolling, utils)
20-
from .alignment import align
20+
from .alignment import (align, _broadcast_helper,
21+
_get_broadcast_dims_map_common_coords)
2122
from .common import (ALL_DIMS, DataWithCoords, ImplementsDatasetReduce,
2223
_contains_datetime_like_objects)
2324
from .coordinates import (DatasetCoordinates, LevelCoordinatesSource,
@@ -2016,6 +2017,30 @@ def sel_points(self, dim='points', method=None, tolerance=None,
20162017
)
20172018
return self.isel_points(dim=dim, **pos_indexers)
20182019

2020+
def broadcast_like(self,
2021+
other: Union['Dataset', 'DataArray'],
2022+
exclude=None) -> 'Dataset':
2023+
"""Broadcast this DataArray against another Dataset or DataArray.
2024+
This is equivalent to xr.broadcast(other, self)[1]
2025+
2026+
Parameters
2027+
----------
2028+
other : Dataset or DataArray
2029+
Object against which to broadcast this array.
2030+
exclude : sequence of str, optional
2031+
Dimensions that must not be broadcasted
2032+
2033+
"""
2034+
2035+
if exclude is None:
2036+
exclude = set()
2037+
args = align(other, self, join='outer', copy=False, exclude=exclude)
2038+
2039+
dims_map, common_coords = _get_broadcast_dims_map_common_coords(
2040+
args, exclude)
2041+
2042+
return _broadcast_helper(self, exclude, dims_map, common_coords)
2043+
20192044
def reindex_like(self, other, method=None, tolerance=None, copy=True,
20202045
fill_value=dtypes.NA):
20212046
"""Conform this object onto the indexes of another object, filling in

xarray/tests/test_dataarray.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,21 @@ def test_coords_non_string(self):
12551255
expected = DataArray(2, coords={1: 2}, name=1)
12561256
assert_identical(actual, expected)
12571257

1258+
def test_broadcast_like(self):
1259+
original1 = DataArray(np.random.randn(5),
1260+
[('x', range(5))])
1261+
1262+
original2 = DataArray(np.random.randn(6),
1263+
[('y', range(6))])
1264+
1265+
expected1, expected2 = broadcast(original1, original2)
1266+
1267+
assert_identical(original1.broadcast_like(original2),
1268+
expected1.transpose('y', 'x'))
1269+
1270+
assert_identical(original2.broadcast_like(original1),
1271+
expected2)
1272+
12581273
def test_reindex_like(self):
12591274
foo = DataArray(np.random.randn(5, 6),
12601275
[('x', range(5)), ('y', range(6))])

xarray/tests/test_dataset.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,6 +1560,21 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
15601560
assert_identical(mdata.sel(x={'one': 'a', 'two': 1}),
15611561
mdata.sel(one='a', two=1))
15621562

1563+
def test_broadcast_like(self):
1564+
original1 = DataArray(np.random.randn(5),
1565+
[('x', range(5))], name='a').to_dataset()
1566+
1567+
original2 = DataArray(np.random.randn(6),
1568+
[('y', range(6))], name='b')
1569+
1570+
expected1, expected2 = broadcast(original1, original2)
1571+
1572+
assert_identical(original1.broadcast_like(original2),
1573+
expected1.transpose('y', 'x'))
1574+
1575+
assert_identical(original2.broadcast_like(original1),
1576+
expected2)
1577+
15631578
def test_reindex_like(self):
15641579
data = create_test_data()
15651580
data['letters'] = ('dim3', 10 * ['a'])

0 commit comments

Comments
 (0)