Skip to content

Commit 11f89ec

Browse files
etienneschalkdcherianpre-commit-ci[bot]
authored
Do not attempt to broadcast when global option arithmetic_broadcast=False (#8784)
* increase plot size * added old tests * Keep relevant test * what's new * PR comment * remove unnecessary (?) check * unnecessary line removal * removal of variable reassignment to avoid type issue * Update xarray/core/variable.py Co-authored-by: Deepak Cherian <[email protected]> * Update xarray/core/variable.py Co-authored-by: Deepak Cherian <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests * what's new * Update doc/whats-new.rst --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 14fe7e0 commit 11f89ec

File tree

5 files changed

+61
-0
lines changed

5 files changed

+61
-0
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ v2024.03.0 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False``
27+
(:issue:`6806`, :pull:`8784`).
28+
By `Etienne Schalk <https://github.com/etienneschalk>`_ and `Deepak Cherian <https://github.com/dcherian>`_.
2629
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
2730
By `Anderson Banihirwe <https://github.com/andersy005>`_.
2831

xarray/core/options.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
]
3535

3636
class T_Options(TypedDict):
37+
arithmetic_broadcast: bool
3738
arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
3839
cmap_divergent: str | Colormap
3940
cmap_sequential: str | Colormap
@@ -59,6 +60,7 @@ class T_Options(TypedDict):
5960

6061

6162
OPTIONS: T_Options = {
63+
"arithmetic_broadcast": True,
6264
"arithmetic_join": "inner",
6365
"cmap_divergent": "RdBu_r",
6466
"cmap_sequential": "viridis",
@@ -92,6 +94,7 @@ def _positive_integer(value: int) -> bool:
9294

9395

9496
_VALIDATORS = {
97+
"arithmetic_broadcast": lambda value: isinstance(value, bool),
9598
"arithmetic_join": _JOIN_OPTIONS.__contains__,
9699
"display_max_rows": _positive_integer,
97100
"display_values_threshold": _positive_integer,

xarray/core/variable.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,6 +2871,16 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:
28712871

28722872

28732873
def _broadcast_compat_data(self, other):
2874+
if not OPTIONS["arithmetic_broadcast"]:
2875+
if (isinstance(other, Variable) and self.dims != other.dims) or (
2876+
is_duck_array(other) and self.ndim != other.ndim
2877+
):
2878+
raise ValueError(
2879+
"Broadcasting is necessary but automatic broadcasting is disabled via "
2880+
"global option `'arithmetic_broadcast'`. "
2881+
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
2882+
)
2883+
28742884
if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
28752885
# `other` satisfies the necessary Variable API for broadcast_variables
28762886
new_self, new_other = _broadcast_compat_variables(self, other)

xarray/tests/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ def _importorskip(
8989
has_pynio, requires_pynio = _importorskip("Nio")
9090
has_cftime, requires_cftime = _importorskip("cftime")
9191
has_dask, requires_dask = _importorskip("dask")
92+
with warnings.catch_warnings():
93+
warnings.filterwarnings(
94+
"ignore",
95+
message="The current Dask DataFrame implementation is deprecated.",
96+
category=DeprecationWarning,
97+
)
98+
has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
9299
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
93100
has_rasterio, requires_rasterio = _importorskip("rasterio")
94101
has_zarr, requires_zarr = _importorskip("zarr")

xarray/tests/test_dataarray.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
requires_bottleneck,
5252
requires_cupy,
5353
requires_dask,
54+
requires_dask_expr,
5455
requires_iris,
5556
requires_numexpr,
5657
requires_pint,
@@ -3203,6 +3204,42 @@ def test_align_str_dtype(self) -> None:
32033204
assert_identical(expected_b, actual_b)
32043205
assert expected_b.x.dtype == actual_b.x.dtype
32053206

3207+
def test_broadcast_on_vs_off_global_option_different_dims(self) -> None:
3208+
xda_1 = xr.DataArray([1], dims="x1")
3209+
xda_2 = xr.DataArray([1], dims="x2")
3210+
3211+
with xr.set_options(arithmetic_broadcast=True):
3212+
expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2"))
3213+
actual_xda = xda_1 / xda_2
3214+
assert_identical(actual_xda, expected_xda)
3215+
3216+
with xr.set_options(arithmetic_broadcast=False):
3217+
with pytest.raises(
3218+
ValueError,
3219+
match=re.escape(
3220+
"Broadcasting is necessary but automatic broadcasting is disabled via "
3221+
"global option `'arithmetic_broadcast'`. "
3222+
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
3223+
),
3224+
):
3225+
xda_1 / xda_2
3226+
3227+
@pytest.mark.parametrize("arithmetic_broadcast", [True, False])
3228+
def test_broadcast_on_vs_off_global_option_same_dims(
3229+
self, arithmetic_broadcast: bool
3230+
) -> None:
3231+
# Ensure that no error is raised when arithmetic broadcasting is disabled,
3232+
# when broadcasting is not needed. The two DataArrays have the same
3233+
# dimensions of the same size.
3234+
xda_1 = xr.DataArray([1], dims="x")
3235+
xda_2 = xr.DataArray([1], dims="x")
3236+
expected_xda = xr.DataArray([2.0], dims=("x",))
3237+
3238+
with xr.set_options(arithmetic_broadcast=arithmetic_broadcast):
3239+
assert_identical(xda_1 + xda_2, expected_xda)
3240+
assert_identical(xda_1 + np.array([1.0]), expected_xda)
3241+
assert_identical(np.array([1.0]) + xda_1, expected_xda)
3242+
32063243
def test_broadcast_arrays(self) -> None:
32073244
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
32083245
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
@@ -3381,6 +3418,7 @@ def test_to_dataframe_0length(self) -> None:
33813418
assert len(actual) == 0
33823419
assert_array_equal(actual.index.names, list("ABC"))
33833420

3421+
@requires_dask_expr
33843422
@requires_dask
33853423
def test_to_dask_dataframe(self) -> None:
33863424
arr_np = np.arange(3 * 4).reshape(3, 4)

0 commit comments

Comments
 (0)