Skip to content

Commit 8dac64b

Browse files
Illviljanpre-commit-ci[bot]mathauseandersy005
authored
Add dataarray scatter (#6778)
* allow adding any number of extra coords * Explain how ds will becom darray * Update dataset_plot.py * Update dataset_plot.py * use coords for coords * Explain goal of moving ds plots to da * Update dataset_plot.py * Update dataset_plot.py * Update dataset_plot.py * handle non-existant coords * Update dataset_plot.py * Look through the kwargs to find extra coords * output of legend labels has changed * pop plt, comment out error test * Update dataset_plot.py * Update facetgrid.py * move some funcs to utils * add the funcs to the moved place * various bugfixes * use coords to check if valid * only normalize sizes, hue is not necessary. * Use same scatter parameter order as the dataset version. * Fix tests assuming a list of patchollections is returned. * improve ds to da wrapper * Filter kwargs * normalize args to be able to filter the correct args * Update plot.py * Update plot.py * Update dataset_plot.py * Some fixes to string colorbar * Update plot.py * Check if hue is str * Fix some failing tests * Update dataset_plot.py * Add more relevant params higher up * use hue in facetgrid, normalize data * Update plot.py * Move parts of scatter to a decorator * Update plot.py * Update plot.py * get scatter to work with decorator * use correct name * Add a Normalize class For categoricals to work most of the time a normalization to numerics has to be done. Once shown on the plot it has to be reformatted however with a lookup function * skip use of Literal * remove test code * fix lint errors * more linting fixes * doctests fixing * Update utils.py * Update plot.py * Update utils.py * Update plot.py * Update facetgrid.py * revert some old ideas * Update utils.py * Update plot.py * trim unused code * use to_numpy instead * more pint compats * work on facetgrid legends * facetgrid colorbar tweaks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Categoricals starts on 1 and is bounded 0,2 This makes plt.colorbar return ticks in the center of the color * Handle None in Normalize * Fix labels * Update plot.py * determine guide * fix plt * Update facetgrid.py * Don't be able to plot empty legends * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * try out linecollection so lines behaves similar to scatter * linecollections half working * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * A few variations of linecollection * linecollection can behave as scatter, with hue and size, But which part of the array will be considered a line and how do you filter for that? * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plot.py * line to utils * line plot changes * reshape to get hues working * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * line edits legend not nice on line plots yet * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tutorial.py * doc changes, tuple to dict * nice line plots and working legend * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * comment out some variants * some cleanup * Guess some dims if they weren't defined * None is supposed to pass as well * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make precommit happy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add hist, step * handle step using repeat, remove pint errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * handle pint * fix some tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use isel instead to be independent of categoricals or not * allow multiple primitives and filter duplicates * Update test_plot.py * Copy data inside instead at init. * Histograms has counted values along y, switch around x and y labels. * output as numpy array * histogram outputs primitive only * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update facetgrid.py * use add_labels inputs, explicit indexes now handles attrs * colorbar in correct position * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Avoid always stacking To avoid adding unnecessary NaNs. * linecollection fixes TODO is to make sure the values are plotted the along the same axis. * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add datarray scatter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plot.py * Update plot.py * out of scope stuff * Update test_plot.py * Update plot.py * fix some tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update whats-new.rst * Update utils.py * Update xarray/plot/facetgrid.py Co-authored-by: Mathias Hauser <[email protected]> * Update plot.py * typo * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * Update xarray/plot/utils.py Co-authored-by: Mathias Hauser <[email protected]> * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plot.py * some typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update facetgrid.py * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Convert name to string in label_from_attrs * Update whats-new.rst * Add typing to soem interval funcs * undo parse_size edits, not necessary * ax not needed * Add some typing * Update utils.py * Cleaner retrieval of add_labels and * type hints * Fix facetgrid and normal plot not matching * Update facetgrid.py * Update plot.py * Add typing to dataset funcs + some fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataset_plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add type hints to plot1d * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update facetgrid.py * Update facetgrid.py * Update facetgrid.py * remove sharex for 3d plots, not supported. Add set_lims so all data in plots are shown * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update facetgrid.py * Update facetgrid.py * fix typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Self should be any * more fixes to typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update facetgrid.py * fix some mypy errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plot.py * Update whats-new.rst Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mathias Hauser <[email protected]> Co-authored-by: Anderson Banihirwe <[email protected]>
1 parent 50ea159 commit 8dac64b

File tree

6 files changed

+1176
-590
lines changed

6 files changed

+1176
-590
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ v2022.09.1 (unreleased)
2222
New Features
2323
~~~~~~~~~~~~
2424

25+
- Add scatter plot for datarrays. Scatter plots now also supports 3d plots with
26+
the z argument. (:pull:`6778`)
27+
By `Jimmy Westling <https://github.com/illviljan>`_.
2528

2629
Breaking changes
2730
~~~~~~~~~~~~~~~~

xarray/plot/dataset_plot.py

Lines changed: 106 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,23 @@
11
from __future__ import annotations
22

33
import functools
4-
5-
import numpy as np
6-
import pandas as pd
4+
import inspect
5+
from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping
76

87
from ..core.alignment import broadcast
98
from .facetgrid import _easy_facetgrid
9+
from .plot import _PlotMethods
1010
from .utils import (
1111
_add_colorbar,
1212
_get_nice_quiver_magnitude,
1313
_infer_meta_data,
14-
_parse_size,
1514
_process_cmap_cbar_kwargs,
1615
get_axis,
1716
)
1817

19-
20-
def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None):
21-
22-
broadcast_keys = ["x", "y"]
23-
to_broadcast = [ds[x], ds[y]]
24-
if hue:
25-
to_broadcast.append(ds[hue])
26-
broadcast_keys.append("hue")
27-
if markersize:
28-
to_broadcast.append(ds[markersize])
29-
broadcast_keys.append("size")
30-
31-
broadcasted = dict(zip(broadcast_keys, broadcast(*to_broadcast)))
32-
33-
data = {"x": broadcasted["x"], "y": broadcasted["y"], "hue": None, "sizes": None}
34-
35-
if hue:
36-
data["hue"] = broadcasted["hue"]
37-
38-
if markersize:
39-
size = broadcasted["size"]
40-
41-
if size_mapping is None:
42-
size_mapping = _parse_size(size, size_norm)
43-
44-
data["sizes"] = size.copy(
45-
data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape)
46-
)
47-
48-
return data
18+
if TYPE_CHECKING:
19+
from ..core.dataarray import DataArray
20+
from ..core.types import T_Dataset
4921

5022

5123
class _Dataset_PlotMethods:
@@ -352,67 +324,6 @@ def plotmethod(
352324
return newplotfunc
353325

354326

355-
@_dsplot
356-
def scatter(ds, x, y, ax, **kwargs):
357-
"""
358-
Scatter Dataset data variables against each other.
359-
360-
Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`.
361-
"""
362-
363-
if "add_colorbar" in kwargs or "add_legend" in kwargs:
364-
raise ValueError(
365-
"Dataset.plot.scatter does not accept "
366-
"'add_colorbar' or 'add_legend'. "
367-
"Use 'add_guide' instead."
368-
)
369-
370-
cmap_params = kwargs.pop("cmap_params")
371-
hue = kwargs.pop("hue")
372-
hue_style = kwargs.pop("hue_style")
373-
markersize = kwargs.pop("markersize", None)
374-
size_norm = kwargs.pop("size_norm", None)
375-
size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid
376-
377-
# Remove `u` and `v` so they don't get passed to `ax.scatter`
378-
kwargs.pop("u", None)
379-
kwargs.pop("v", None)
380-
381-
# need to infer size_mapping with full dataset
382-
data = _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping)
383-
384-
if hue_style == "discrete":
385-
primitive = []
386-
# use pd.unique instead of np.unique because that keeps the order of the labels,
387-
# which is important to keep them in sync with the ones used in
388-
# FacetGrid.add_legend
389-
for label in pd.unique(data["hue"].values.ravel()):
390-
mask = data["hue"] == label
391-
if data["sizes"] is not None:
392-
kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten())
393-
394-
primitive.append(
395-
ax.scatter(
396-
data["x"].where(mask, drop=True).values.flatten(),
397-
data["y"].where(mask, drop=True).values.flatten(),
398-
label=label,
399-
**kwargs,
400-
)
401-
)
402-
403-
elif hue is None or hue_style == "continuous":
404-
if data["sizes"] is not None:
405-
kwargs.update(s=data["sizes"].values.ravel())
406-
if data["hue"] is not None:
407-
kwargs.update(c=data["hue"].values.ravel())
408-
409-
primitive = ax.scatter(
410-
data["x"].values.ravel(), data["y"].values.ravel(), **cmap_params, **kwargs
411-
)
412-
413-
return primitive
414-
415-
416327
@_dsplot
417328
def quiver(ds, x, y, ax, u, v, **kwargs):
418329
"""Quiver plot of Dataset variables.
@@ -497,3 +408,103 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
497408

498409
# Return .lines so colorbar creation works properly
499410
return hdl.lines
411+
412+
413+
def _attach_to_plot_class(plotfunc: Callable) -> None:
414+
"""
415+
Set the function to the plot class and add a common docstring.
416+
417+
Use this decorator when relying on DataArray.plot methods for
418+
creating the Dataset plot.
419+
420+
TODO: Reduce code duplication.
421+
422+
* The goal is to reduce code duplication by moving all Dataset
423+
specific plots to the DataArray side and use this thin wrapper to
424+
handle the conversion between Dataset and DataArray.
425+
* Improve docstring handling, maybe reword the DataArray versions to
426+
explain Datasets better.
427+
* Consider automatically adding all _PlotMethods to
428+
_Dataset_PlotMethods.
429+
430+
Parameters
431+
----------
432+
plotfunc : function
433+
Function that returns a finished plot primitive.
434+
"""
435+
# Build on the original docstring:
436+
original_doc = getattr(_PlotMethods, plotfunc.__name__, object)
437+
commondoc = original_doc.__doc__
438+
if commondoc is not None:
439+
doc_warning = (
440+
f"This docstring was copied from xr.DataArray.plot.{original_doc.__name__}."
441+
" Some inconsistencies may exist."
442+
)
443+
# Add indentation so it matches the original doc:
444+
commondoc = f"\n\n {doc_warning}\n\n {commondoc}"
445+
else:
446+
commondoc = ""
447+
plotfunc.__doc__ = (
448+
f" {plotfunc.__doc__}\n\n"
449+
" The `y` DataArray will be used as base,"
450+
" any other variables are added as coords.\n\n"
451+
f"{commondoc}"
452+
)
453+
454+
@functools.wraps(plotfunc)
455+
def plotmethod(self, *args, **kwargs):
456+
return plotfunc(self._ds, *args, **kwargs)
457+
458+
# Add to class _PlotMethods
459+
setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod)
460+
461+
462+
def _normalize_args(plotmethod: str, args, kwargs) -> dict[str, Any]:
463+
from ..core.dataarray import DataArray
464+
465+
# Determine positional arguments keyword by inspecting the
466+
# signature of the plotmethod:
467+
locals_ = dict(
468+
inspect.signature(getattr(DataArray().plot, plotmethod))
469+
.bind(*args, **kwargs)
470+
.arguments.items()
471+
)
472+
locals_.update(locals_.pop("kwargs", {}))
473+
474+
return locals_
475+
476+
477+
def _temp_dataarray(ds: T_Dataset, y: Hashable, locals_: Mapping) -> DataArray:
478+
"""Create a temporary datarray with extra coords."""
479+
from ..core.dataarray import DataArray
480+
481+
# Base coords:
482+
coords = dict(ds.coords)
483+
484+
# Add extra coords to the DataArray from valid kwargs, if using all
485+
# kwargs there is a risk that we add unneccessary dataarrays as
486+
# coords straining RAM further for example:
487+
# ds.both and extend="both" would add ds.both to the coords:
488+
valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"}
489+
coord_kwargs = locals_.keys() & valid_coord_kwargs
490+
for k in coord_kwargs:
491+
key = locals_[k]
492+
if ds.data_vars.get(key) is not None:
493+
coords[key] = ds[key]
494+
495+
# The dataarray has to include all the dims. Broadcast to that shape
496+
# and add the additional coords:
497+
_y = ds[y].broadcast_like(ds)
498+
499+
return DataArray(_y, coords=coords)
500+
501+
502+
@_attach_to_plot_class
503+
def scatter(ds: T_Dataset, x: Hashable, y: Hashable, *args, **kwargs):
504+
"""Scatter plot Dataset data variables against each other."""
505+
plotmethod = "scatter"
506+
kwargs.update(x=x)
507+
locals_ = _normalize_args(plotmethod, args, kwargs)
508+
da = _temp_dataarray(ds, y, locals_)
509+
510+
return getattr(da.plot, plotmethod)(*locals_.pop("args", ()), **locals_)

0 commit comments

Comments
 (0)