Recommend optimize=True in xr.dot docstring #8017

Closed
4 tasks done
cgahr opened this issue Jul 25, 2023 · 7 comments · Fixed by #8373

Comments

@cgahr

cgahr commented Jul 25, 2023

What happened?

If I multiply several arrays with each other, using da0 @ da1 @ da2 @ ... is much faster compared to using xr.dot(da0, da1, da2, ...).

It was not at all obvious to me that this could be the case; I only discovered it by accident.

What did you expect to happen?

I would expect that xr.dot is equivalent to da.dot in the sense that xr.dot(da0, da1, ...) is comparable to da0 @ da1 @ ....

Minimal Complete Verifiable Example

import numpy as np
import xarray as xr
import time

gen = np.random.default_rng()

da0 = xr.DataArray(gen.normal(0, 1, (100, 10)), dims=('x', 'rx'))
da1 = xr.DataArray(gen.normal(0, 1, (100, 10)), dims=('y', 'ry'))
# da2 shares the 'y' dim with da1 but gets its own second dim, ry'
da2 = (da1 - da1.roll(y=1)).rename(ry="ry'")
da3 = da0.rename(rx="rx'")

step0 = time.time()
_ = da0 @ da1 @ da2 @ da3
step1 = time.time()
_ = xr.dot(da0, da1, da2, da3)
step2 = time.time()

print(f'@:      {step1 - step0:.3f}s')  # 0.063s
print(f'xr.dot: {step2 - step1:.3f}s')  # 1.249s

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.

Relevant log output

No response

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS

commit: None
python: 3.11.3 | packaged by conda-forge | (main, Apr 6 2023, 08:57:19) [GCC 11.3.0]
python-bits: 64
OS: Linux
OS-release: 4.12.14-122.162-default
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.0
libnetcdf: None

xarray: 2023.6.0
pandas: 1.5.3
numpy: 1.25.0
scipy: 1.10.1
netCDF4: None
pydap: None
h5netcdf: 1.2.0
h5py: 3.8.0
Nio: None
zarr: None
cftime: None
nc_time_axis: None
PseudoNetCDF: None
iris: None
bottleneck: None
dask: 2023.4.1
distributed: 2023.4.1
matplotlib: 3.7.1
cartopy: None
seaborn: None
numbagg: None
fsspec: 2023.4.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 67.7.2
pip: 23.1.2
conda: None
pytest: 7.4.0
mypy: 1.4.1
IPython: 8.13.2
sphinx: None

@cgahr added the "bug" and "needs triage" labels Jul 25, 2023
@dcherian
Contributor

dcherian commented Jul 25, 2023

We use numpy for such things, so this is really up to them (and which BLAS kernel they end up calling).

xr.dot uses einsum. We could speed this up in Xarray by using opt_einsum by default: #7764. This would be a relatively easy and impactful PR if you're up for it.

EDIT: Funnily enough @ dispatches to einsum too =) This would make a good numpy bug report.
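
As a rough illustration of the einsum connection, the sketch below writes the four-array product from the example as a single `np.einsum` call on plain NumPy arrays and checks it against three pairwise contractions, which is roughly what the chained `@` amounts to. The subscript letters (`a`, `b`, `c`, `d` standing in for `rx`, `ry`, `ry'`, `rx'`) and the reduced sizes are choices made for this example, not something taken from xarray's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
nx, ny, r = 20, 20, 4  # much smaller than the sizes in the issue, to keep this quick

a0 = rng.normal(size=(nx, r))      # dims (x, rx)  -> "xa"
a1 = rng.normal(size=(ny, r))      # dims (y, ry)  -> "yb"
a2 = a1 - np.roll(a1, 1, axis=0)   # dims (y, ry') -> "yc"
a3 = a0.copy()                     # dims (x, rx') -> "xd"

# All four operands in one einsum call: x and y are summed, a, b, c, d are kept.
one_shot = np.einsum("xa,yb,yc,xd->abcd", a0, a1, a2, a3)

# The same result built from two-operand contractions, left to right.
step1 = np.einsum("xa,yb->xayb", a0, a1)        # no shared index yet: outer product
step2 = np.einsum("xayb,yc->xabc", step1, a2)   # sums over y
chained = np.einsum("xabc,xd->abcd", step2, a3) # sums over x

np.testing.assert_allclose(one_shot, chained)
```

Both forms give the same numbers; the difference reported in this issue is only in how the work is ordered.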

@dcherian removed the "bug" and "needs triage" labels Jul 25, 2023
@Illviljan
Contributor

Illviljan commented Jul 25, 2023

xr.dot(xr.dot(xr.dot(da0, da1), da2), da3) is the same speed as da0 @ da1 @ da2 @ da3
[image: timing plot, see the script below]

I suppose we could try a recursive version instead of sending a bunch of arrays to einsum.

import numpy as np
import xarray as xr
import time
import matplotlib.pyplot as plt

gen = np.random.default_rng()


dot_a = []
dot_b = []
dot_c = []
for s in np.arange(1, 4):
    da0 = xr.DataArray(gen.normal(0, 1, (s * 100, 10)), dims=("x", "rx"))
    da1 = xr.DataArray(gen.normal(0, 1, (s * 100, 10)), dims=("y", "ry"))
    da2 = (da1 - da1.roll(y=1)).rename(ry="ry'")
    da3 = da0.rename(rx="rx'")

    step0 = time.time()
    a = da0 @ da1 @ da2 @ da3
    step1 = time.time()
    b = xr.dot(da0, da1, da2, da3)
    step2 = time.time()
    c = xr.dot(xr.dot(xr.dot(da0, da1), da2), da3)
    step3 = time.time()
    np.testing.assert_allclose(a, b)
    np.testing.assert_allclose(a, c)

    dot_a.append(step1 - step0)
    dot_b.append(step2 - step1)
    dot_c.append(step3 - step2)


fig, ax = plt.subplots(1, 1)
ax.plot(dot_a, label="da0 @ da1 @ da2 @ da3")
ax.plot(dot_b, label="xr.dot(da0, da1, da2, da3)")
ax.plot(dot_c, label="xr.dot(xr.dot(xr.dot(da0, da1), da2), da3)")

ax.legend()
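
The "recursive version" idea could look something like the following minimal sketch, which folds the two-argument `xr.dot` over the inputs from left to right. `pairwise_dot` is a hypothetical helper written for this illustration and is not part of xarray; the array sizes are reduced from the benchmark above.

```python
from functools import reduce

import numpy as np
import xarray as xr


def pairwise_dot(*arrays):
    """Hypothetical helper: fold two-argument xr.dot over the inputs, left to right."""
    return reduce(xr.dot, arrays)


gen = np.random.default_rng(0)
da0 = xr.DataArray(gen.normal(0, 1, (30, 4)), dims=("x", "rx"))
da1 = xr.DataArray(gen.normal(0, 1, (30, 4)), dims=("y", "ry"))
da2 = (da1 - da1.roll(y=1)).rename(ry="ry'")
da3 = da0.rename(rx="rx'")

folded = pairwise_dot(da0, da1, da2, da3)
direct = xr.dot(da0, da1, da2, da3)

# Align the dimension order before comparing, in case the two routes differ.
xr.testing.assert_allclose(folded, direct.transpose(*folded.dims))
```

One caveat with this approach: a pairwise fold sums each shared dimension at the first pair that shares it, so for a dimension that appears in three or more inputs it would not generally match the n-ary `xr.dot`. In the example here every summed dimension is shared by exactly two arrays, so the results agree.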

@Illviljan
Contributor

Related issues:
numpy/numpy#22604
pytorch/pytorch#32591

@dcherian
Contributor

dcherian commented Jul 25, 2023

Ah just pass optimize=True to xr.dot. We should just set that by default. Easy PR!

@:      0.028s
xr.dot: 0.001s
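
To see why `optimize=True` can matter this much, `np.einsum_path` reports the contraction order NumPy would pick and the size of the intermediates. This is a sketch on plain NumPy arrays, assuming (as stated earlier in the thread) that `xr.dot` ends up in an einsum call; the subscripts and the reduced sizes are chosen for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
a0 = rng.normal(size=(20, 4))      # (x, rx)
a1 = rng.normal(size=(20, 4))      # (y, ry)
a2 = a1 - np.roll(a1, 1, axis=0)   # (y, ry')
a3 = a0.copy()                     # (x, rx')

expr = "xa,yb,yc,xd->abcd"

# Ask NumPy for a contraction order without computing the result.
path, report = np.einsum_path(expr, a0, a1, a2, a3, optimize="greedy")
print(report)  # human-readable summary: chosen path, estimated speedup, intermediates

# Same expression with and without path optimization; the values agree.
fast = np.einsum(expr, a0, a1, a2, a3, optimize=True)
slow = np.einsum(expr, a0, a1, a2, a3, optimize=False)
np.testing.assert_allclose(fast, slow)
```

Without optimization the four operands are contracted in a single nested loop over every index. With a good order, the two operands sharing `x` and the two sharing `y` can each be reduced to a small matrix before anything large is formed.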

@Illviljan
Contributor

Yeah, optimize made quite the difference:
[image: timing plot, see the script below]

import numpy as np
import xarray as xr
import time
import matplotlib.pyplot as plt

gen = np.random.default_rng()


dot_a = []
dot_b = []
dot_c = []
for s in np.arange(1, 4):
    da0 = xr.DataArray(gen.normal(0, 1, (s * 100, 10)), dims=("x", "rx"))
    da1 = xr.DataArray(gen.normal(0, 1, (s * 100, 10)), dims=("y", "ry"))
    da2 = (da1 - da1.roll(y=1)).rename(ry="ry'")
    da3 = da0.rename(rx="rx'")

    step0 = time.time()
    a = da0 @ da1 @ da2 @ da3
    step1 = time.time()
    b = xr.dot(da0, da1, da2, da3, optimize=True)
    step2 = time.time()
    c = xr.dot(xr.dot(xr.dot(da0, da1), da2), da3)
    step3 = time.time()
    np.testing.assert_allclose(a, b)
    np.testing.assert_allclose(a, c)

    dot_a.append(step1 - step0)
    dot_b.append(step2 - step1)
    dot_c.append(step3 - step2)


fig, ax = plt.subplots(1, 1)
ax.plot(dot_a, marker="x", label="da0 @ da1 @ da2 @ da3")
ax.plot(dot_b, marker="o", label="xr.dot(da0, da1, da2, da3, optimize=True)")
ax.plot(dot_c, marker=".", label="xr.dot(xr.dot(xr.dot(da0, da1), da2), da3)")

ax.legend()

@dcherian
Contributor

just pass optimize=True to xr.dot

I now remember why we can't do this. Not all array types support it, e.g. sparse. Perhaps make a recommendation in the docstring?
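
Until the docstring says more, one way a user might work around the array-type limitation is to request optimization only when every input is backed by a plain NumPy array. The helper below, `dot_maybe_optimized`, is a hypothetical sketch for this thread and not an xarray API; it assumes that `xr.dot` forwards extra keyword arguments such as `optimize` to the underlying einsum call.

```python
import numpy as np
import xarray as xr


def dot_maybe_optimized(*arrays):
    """Hypothetical helper: pass optimize=True only for NumPy-backed inputs."""
    all_numpy = all(isinstance(a.data, np.ndarray) for a in arrays)
    # Other array types (e.g. sparse) may not accept the keyword, so leave it out.
    kwargs = {"optimize": True} if all_numpy else {}
    return xr.dot(*arrays, **kwargs)


gen = np.random.default_rng(0)
da0 = xr.DataArray(gen.normal(0, 1, (30, 4)), dims=("x", "rx"))
da1 = xr.DataArray(gen.normal(0, 1, (30, 4)), dims=("y", "ry"))
da2 = (da1 - da1.roll(y=1)).rename(ry="ry'")
da3 = da0.rename(rx="rx'")

result = dot_maybe_optimized(da0, da1, da2, da3)
reference = xr.dot(da0, da1, da2, da3)
xr.testing.assert_allclose(result, reference.transpose(*result.dims))
```

The `isinstance` check is deliberately conservative: it also skips optimization for array types that might support it, which costs speed but should not change the result.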

@cgahr
Author

cgahr commented Jul 27, 2023

I can confirm that optimize=True speeds up my computations by a factor of about 6000.

just pass optimize=True to xr.dot

I now remember why we can't do this. Not all array types support it, e.g. sparse. Perhaps make a recommendation in the docstring?

I'd recommend putting a note into the documentation, so that future users are aware of this.

From my side, you can close this issue. Alternatively, you can leave it open to track whether the docstring has been updated.

@dcherian changed the title from "@ about 20x faster than xr.dot" to "Recommend optimize=True in xr.dot" Jul 27, 2023
@dcherian changed the title from "Recommend optimize=True in xr.dot" to "Recommend optimize=True in xr.dot docstring" Jul 27, 2023
dcherian added a commit to dcherian/xarray that referenced this issue Oct 25, 2023
dcherian added a commit that referenced this issue Oct 28, 2023
* Use `opt_einsum` by default if installed.

Closes #7764
Closes #8017

* docstring update

* _

* _

Co-authored-by: Maximilian Roos <[email protected]>

* Update xarray/core/computation.py

Co-authored-by: Maximilian Roos <[email protected]>

* Fix docs?

* Add use_opt_einsum option.

* mypy ignore

* one more test ignore

* Disable navigation_with_keys

* remove intersphinx

* One more skip

---------

Co-authored-by: Maximilian Roos <[email protected]>