Skip to content

Do not attempt to broadcast when global option arithmetic_broadcast=False #8784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v2024.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False``
(:issue:`6806`, :pull:`8784`).
By `Etienne Schalk <https://github.com/etienneschalk>`_ and `Deepak Cherian <https://github.com/dcherian>`_.
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

Expand Down
3 changes: 3 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
]

class T_Options(TypedDict):
arithmetic_broadcast: bool
arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
cmap_divergent: str | Colormap
cmap_sequential: str | Colormap
Expand All @@ -59,6 +60,7 @@ class T_Options(TypedDict):


OPTIONS: T_Options = {
"arithmetic_broadcast": True,
"arithmetic_join": "inner",
"cmap_divergent": "RdBu_r",
"cmap_sequential": "viridis",
Expand Down Expand Up @@ -92,6 +94,7 @@ def _positive_integer(value: int) -> bool:


_VALIDATORS = {
"arithmetic_broadcast": lambda value: isinstance(value, bool),
"arithmetic_join": _JOIN_OPTIONS.__contains__,
"display_max_rows": _positive_integer,
"display_values_threshold": _positive_integer,
Expand Down
10 changes: 10 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2871,6 +2871,16 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:


def _broadcast_compat_data(self, other):
if not OPTIONS["arithmetic_broadcast"]:
if (isinstance(other, Variable) and self.dims != other.dims) or (
is_duck_array(other) and self.ndim != other.ndim
):
raise ValueError(
"Broadcasting is necessary but automatic broadcasting is disabled via "
"global option `'arithmetic_broadcast'`. "
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
)

if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
# `other` satisfies the necessary Variable API for broadcast_variables
new_self, new_other = _broadcast_compat_variables(self, other)
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def _importorskip(
has_pynio, requires_pynio = _importorskip("Nio")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The current Dask DataFrame implementation is deprecated.",
category=DeprecationWarning,
)
has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
Expand Down
38 changes: 38 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
requires_bottleneck,
requires_cupy,
requires_dask,
requires_dask_expr,
requires_iris,
requires_numexpr,
requires_pint,
Expand Down Expand Up @@ -3203,6 +3204,42 @@ def test_align_str_dtype(self) -> None:
assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast_on_vs_off_global_option_different_dims(self) -> None:
xda_1 = xr.DataArray([1], dims="x1")
xda_2 = xr.DataArray([1], dims="x2")

with xr.set_options(arithmetic_broadcast=True):
expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2"))
actual_xda = xda_1 / xda_2
assert_identical(actual_xda, expected_xda)

with xr.set_options(arithmetic_broadcast=False):
with pytest.raises(
ValueError,
match=re.escape(
"Broadcasting is necessary but automatic broadcasting is disabled via "
"global option `'arithmetic_broadcast'`. "
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
),
):
xda_1 / xda_2

@pytest.mark.parametrize("arithmetic_broadcast", [True, False])
def test_broadcast_on_vs_off_global_option_same_dims(
self, arithmetic_broadcast: bool
) -> None:
# Ensure that no error is raised when arithmetic broadcasting is disabled,
# when broadcasting is not needed. The two DataArrays have the same
# dimensions of the same size.
xda_1 = xr.DataArray([1], dims="x")
xda_2 = xr.DataArray([1], dims="x")
expected_xda = xr.DataArray([2.0], dims=("x",))

with xr.set_options(arithmetic_broadcast=arithmetic_broadcast):
assert_identical(xda_1 + xda_2, expected_xda)
assert_identical(xda_1 + np.array([1.0]), expected_xda)
assert_identical(np.array([1.0]) + xda_1, expected_xda)

def test_broadcast_arrays(self) -> None:
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
Expand Down Expand Up @@ -3381,6 +3418,7 @@ def test_to_dataframe_0length(self) -> None:
assert len(actual) == 0
assert_array_equal(actual.index.names, list("ABC"))

@requires_dask_expr
@requires_dask
def test_to_dask_dataframe(self) -> None:
arr_np = np.arange(3 * 4).reshape(3, 4)
Expand Down