diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh
index 41507fce13e..97ae4c2bbca 100755
--- a/ci/install-upstream-wheels.sh
+++ b/ci/install-upstream-wheels.sh
@@ -45,4 +45,5 @@ python -m pip install \
git+https://github.com/intake/filesystem_spec \
git+https://github.com/SciTools/nc-time-axis \
git+https://github.com/xarray-contrib/flox \
- git+https://github.com/h5netcdf/h5netcdf
+ git+https://github.com/h5netcdf/h5netcdf \
+ git+https://github.com/dgasmith/opt_einsum
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index dd73ef19658..6e93ab7a946 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -26,6 +26,7 @@ dependencies:
- numbagg
- numexpr
- numpy
+ - opt_einsum
- packaging
- pandas
- pint<0.21
diff --git a/doc/conf.py b/doc/conf.py
index 74c41b52ab6..17d150ae6cd 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -236,6 +236,7 @@
use_repository_button=True,
use_issues_button=True,
home_page_in_toc=False,
+ navigation_with_keys=False,
extra_footer="""
Xarray is a fiscally sponsored project of NumFOCUS,
a nonprofit dedicated to supporting the open-source scientific computing community.
Theme by the Executable Book Project
""",
@@ -326,6 +327,7 @@
"sparse": ("https://sparse.pydata.org/en/latest/", None),
"cubed": ("https://tom-e-white.com/cubed/", None),
"datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None),
+ # "opt_einsum": ("https://dgasmith.github.io/opt_einsum/", None),
}
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 9f049b148c7..5bc63269ffd 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -22,6 +22,8 @@ v2023.10.2 (unreleased)
New Features
~~~~~~~~~~~~
+- Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed.
+ By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`).
Breaking changes
~~~~~~~~~~~~~~~~
diff --git a/pyproject.toml b/pyproject.toml
index e7fa7bec5c0..b16063e0370 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ source-code = "https://github.com/pydata/xarray"
dask = "xarray.core.daskmanager:DaskManager"
[project.optional-dependencies]
-accel = ["scipy", "bottleneck", "numbagg", "flox"]
+accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
complete = ["xarray[accel,io,parallel,viz]"]
io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"]
parallel = ["dask[complete]"]
@@ -106,6 +106,7 @@ module = [
"numbagg.*",
"netCDF4.*",
"netcdftime.*",
+ "opt_einsum.*",
"pandas.*",
"pooch.*",
"PseudoNetCDF.*",
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 9cb60e0c424..70b30ae2176 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -1690,8 +1690,8 @@ def dot(
dims: Dims = None,
**kwargs: Any,
):
- """Generalized dot product for xarray objects. Like np.einsum, but
- provides a simpler interface based on array dimensions.
+ """Generalized dot product for xarray objects. Like ``np.einsum``, but
+ provides a simpler interface based on array dimension names.
Parameters
----------
@@ -1701,13 +1701,24 @@ def dot(
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs : dict
- Additional keyword arguments passed to numpy.einsum or
- dask.array.einsum
+ Additional keyword arguments passed to ``numpy.einsum`` or
+ ``dask.array.einsum``
Returns
-------
DataArray
+ See Also
+ --------
+ numpy.einsum
+ dask.array.einsum
+ opt_einsum.contract
+
+ Notes
+ -----
+ We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``,
+ which is passed through to ``np.einsum``, and works for most array backends.
+
Examples
--------
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"])
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 51b6ff5f59b..b9f7db9737f 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -18,7 +18,6 @@
from numpy import any as array_any # noqa
from numpy import ( # noqa
around, # noqa
- einsum,
gradient,
isclose,
isin,
@@ -48,6 +47,17 @@ def get_array_namespace(x):
return np
+def einsum(*args, **kwargs):
+ from xarray.core.options import OPTIONS
+
+ if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"):
+ import opt_einsum
+
+ return opt_einsum.contract(*args, **kwargs)
+ else:
+ return np.einsum(*args, **kwargs)
+
+
def _dask_or_eager_func(
name,
eager_module=np,
diff --git a/xarray/core/options.py b/xarray/core/options.py
index 118a67559ad..d116c350991 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -28,6 +28,7 @@
"warn_for_unclosed_files",
"use_bottleneck",
"use_numbagg",
+ "use_opt_einsum",
"use_flox",
]
@@ -52,6 +53,7 @@ class T_Options(TypedDict):
use_bottleneck: bool
use_flox: bool
use_numbagg: bool
+ use_opt_einsum: bool
OPTIONS: T_Options = {
@@ -75,6 +77,7 @@ class T_Options(TypedDict):
"use_bottleneck": True,
"use_flox": True,
"use_numbagg": True,
+ "use_opt_einsum": True,
}
_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
@@ -102,6 +105,7 @@ def _positive_integer(value: int) -> bool:
"keep_attrs": lambda choice: choice in [True, False, "default"],
"use_bottleneck": lambda value: isinstance(value, bool),
"use_numbagg": lambda value: isinstance(value, bool),
+ "use_opt_einsum": lambda value: isinstance(value, bool),
"use_flox": lambda value: isinstance(value, bool),
"warn_for_unclosed_files": lambda value: isinstance(value, bool),
}
@@ -237,6 +241,8 @@ class set_options:
use_numbagg : bool, default: True
Whether to use ``numbagg`` to accelerate reductions.
Takes precedence over ``use_bottleneck`` when both are True.
+ use_opt_einsum : bool, default: True
+ Whether to use ``opt_einsum`` to accelerate dot products.
warn_for_unclosed_files : bool, default: False
Whether or not to issue a warning when unclosed files are
deallocated. This is mostly useful for debugging.
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index 7e1105e2e5d..14a7a10f734 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -1502,10 +1502,11 @@ def test_dot_dataarray(dtype):
data_array = xr.DataArray(data=array1, dims=("x", "y"))
other = xr.DataArray(data=array2, dims=("y", "z"))
- expected = attach_units(
- xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
- )
- actual = xr.dot(data_array, other)
+ with xr.set_options(use_opt_einsum=False):
+ expected = attach_units(
+ xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
+ )
+ actual = xr.dot(data_array, other)
assert_units_equal(expected, actual)
assert_identical(expected, actual)
@@ -2465,8 +2466,9 @@ def test_binary_operations(self, func, dtype):
data_array = xr.DataArray(data=array)
units = extract_units(func(array))
- expected = attach_units(func(strip_units(data_array)), units)
- actual = func(data_array)
+ with xr.set_options(use_opt_einsum=False):
+ expected = attach_units(func(strip_units(data_array)), units)
+ actual = func(data_array)
assert_units_equal(expected, actual)
assert_identical(expected, actual)
@@ -3829,8 +3831,9 @@ def test_computation(self, func, variant, dtype):
if not isinstance(func, (function, method)):
units.update(extract_units(func(array.reshape(-1))))
- expected = attach_units(func(strip_units(data_array)), units)
- actual = func(data_array)
+ with xr.set_options(use_opt_einsum=False):
+ expected = attach_units(func(strip_units(data_array)), units)
+ actual = func(data_array)
assert_units_equal(expected, actual)
assert_identical(expected, actual)