Skip to content

Commit 9f5e31b

Browse files
committed
backcompat: vendor np.broadcast_shapes
1 parent 8df0c2a commit 9f5e31b

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

xarray/core/npcompat.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,66 @@ def sliding_window_view(
245245
"midpoint",
246246
"nearest",
247247
]
248+
249+
250+
if Version(np.__version__) < Version("1.20"):
251+
252+
def _broadcast_shape(*args):
253+
"""Returns the shape of the arrays that would result from broadcasting the
254+
supplied arrays against each other.
255+
"""
256+
# use the old-iterator because np.nditer does not handle size 0 arrays
257+
# consistently
258+
b = np.broadcast(*args[:32])
259+
# unfortunately, it cannot handle 32 or more arguments directly
260+
for pos in range(32, len(args), 31):
261+
# ironically, np.broadcast does not properly handle np.broadcast
262+
# objects (it treats them as scalars)
263+
# use broadcasting to avoid allocating the full array
264+
b = np.broadcast_to(0, b.shape)
265+
b = np.broadcast(b, *args[pos : (pos + 31)])
266+
return b.shape
267+
268+
def broadcast_shapes(*args):
269+
"""
270+
Broadcast the input shapes into a single shape.
271+
272+
:ref:`Learn more about broadcasting here <basics.broadcasting>`.
273+
274+
.. versionadded:: 1.20.0
275+
276+
Parameters
277+
----------
278+
`*args` : tuples of ints, or ints
279+
The shapes to be broadcast against each other.
280+
281+
Returns
282+
-------
283+
tuple
284+
Broadcasted shape.
285+
286+
Raises
287+
------
288+
ValueError
289+
If the shapes are not compatible and cannot be broadcast according
290+
to NumPy's broadcasting rules.
291+
292+
See Also
293+
--------
294+
broadcast
295+
broadcast_arrays
296+
broadcast_to
297+
298+
Examples
299+
--------
300+
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
301+
(3, 2)
302+
303+
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
304+
(5, 6, 7)
305+
"""
306+
arrays = [np.empty(x, dtype=[]) for x in args]
307+
return _broadcast_shape(*arrays)
308+
309+
else:
310+
from numpy import broadcast_shapes # noqa

xarray/core/nputils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
88

9+
from .npcompat import broadcast_shapes
910
from .options import OPTIONS
1011

1112
try:
@@ -109,7 +110,7 @@ def _advanced_indexer_subspaces(key):
109110
return (), ()
110111

111112
non_slices = [k for k in key if not isinstance(k, slice)]
112-
ndim = len(np.broadcast_shapes(*[item.shape for item in non_slices]))
113+
ndim = len(broadcast_shapes(*[item.shape for item in non_slices]))
113114
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
114115
vindex_positions = np.arange(ndim)
115116
return mixed_positions, vindex_positions

0 commit comments

Comments
 (0)