|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import functools
|
4 |
| - |
5 |
| -import numpy as np |
6 |
| -import pandas as pd |
| 4 | +import inspect |
| 5 | +from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping |
7 | 6 |
|
8 | 7 | from ..core.alignment import broadcast
|
9 | 8 | from .facetgrid import _easy_facetgrid
|
| 9 | +from .plot import _PlotMethods |
10 | 10 | from .utils import (
|
11 | 11 | _add_colorbar,
|
12 | 12 | _get_nice_quiver_magnitude,
|
13 | 13 | _infer_meta_data,
|
14 |
| - _parse_size, |
15 | 14 | _process_cmap_cbar_kwargs,
|
16 | 15 | get_axis,
|
17 | 16 | )
|
18 | 17 |
|
19 |
| - |
20 |
| -def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None): |
21 |
| - |
22 |
| - broadcast_keys = ["x", "y"] |
23 |
| - to_broadcast = [ds[x], ds[y]] |
24 |
| - if hue: |
25 |
| - to_broadcast.append(ds[hue]) |
26 |
| - broadcast_keys.append("hue") |
27 |
| - if markersize: |
28 |
| - to_broadcast.append(ds[markersize]) |
29 |
| - broadcast_keys.append("size") |
30 |
| - |
31 |
| - broadcasted = dict(zip(broadcast_keys, broadcast(*to_broadcast))) |
32 |
| - |
33 |
| - data = {"x": broadcasted["x"], "y": broadcasted["y"], "hue": None, "sizes": None} |
34 |
| - |
35 |
| - if hue: |
36 |
| - data["hue"] = broadcasted["hue"] |
37 |
| - |
38 |
| - if markersize: |
39 |
| - size = broadcasted["size"] |
40 |
| - |
41 |
| - if size_mapping is None: |
42 |
| - size_mapping = _parse_size(size, size_norm) |
43 |
| - |
44 |
| - data["sizes"] = size.copy( |
45 |
| - data=np.reshape(size_mapping.loc[size.values.ravel()].values, size.shape) |
46 |
| - ) |
47 |
| - |
48 |
| - return data |
| 18 | +if TYPE_CHECKING: |
| 19 | + from ..core.dataarray import DataArray |
| 20 | + from ..core.types import T_Dataset |
49 | 21 |
|
50 | 22 |
|
51 | 23 | class _Dataset_PlotMethods:
|
@@ -352,67 +324,6 @@ def plotmethod(
|
352 | 324 | return newplotfunc
|
353 | 325 |
|
354 | 326 |
|
355 |
| -@_dsplot |
356 |
| -def scatter(ds, x, y, ax, **kwargs): |
357 |
| - """ |
358 |
| - Scatter Dataset data variables against each other. |
359 |
| -
|
360 |
| - Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`. |
361 |
| - """ |
362 |
| - |
363 |
| - if "add_colorbar" in kwargs or "add_legend" in kwargs: |
364 |
| - raise ValueError( |
365 |
| - "Dataset.plot.scatter does not accept " |
366 |
| - "'add_colorbar' or 'add_legend'. " |
367 |
| - "Use 'add_guide' instead." |
368 |
| - ) |
369 |
| - |
370 |
| - cmap_params = kwargs.pop("cmap_params") |
371 |
| - hue = kwargs.pop("hue") |
372 |
| - hue_style = kwargs.pop("hue_style") |
373 |
| - markersize = kwargs.pop("markersize", None) |
374 |
| - size_norm = kwargs.pop("size_norm", None) |
375 |
| - size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid |
376 |
| - |
377 |
| - # Remove `u` and `v` so they don't get passed to `ax.scatter` |
378 |
| - kwargs.pop("u", None) |
379 |
| - kwargs.pop("v", None) |
380 |
| - |
381 |
| - # need to infer size_mapping with full dataset |
382 |
| - data = _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping) |
383 |
| - |
384 |
| - if hue_style == "discrete": |
385 |
| - primitive = [] |
386 |
| - # use pd.unique instead of np.unique because that keeps the order of the labels, |
387 |
| - # which is important to keep them in sync with the ones used in |
388 |
| - # FacetGrid.add_legend |
389 |
| - for label in pd.unique(data["hue"].values.ravel()): |
390 |
| - mask = data["hue"] == label |
391 |
| - if data["sizes"] is not None: |
392 |
| - kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten()) |
393 |
| - |
394 |
| - primitive.append( |
395 |
| - ax.scatter( |
396 |
| - data["x"].where(mask, drop=True).values.flatten(), |
397 |
| - data["y"].where(mask, drop=True).values.flatten(), |
398 |
| - label=label, |
399 |
| - **kwargs, |
400 |
| - ) |
401 |
| - ) |
402 |
| - |
403 |
| - elif hue is None or hue_style == "continuous": |
404 |
| - if data["sizes"] is not None: |
405 |
| - kwargs.update(s=data["sizes"].values.ravel()) |
406 |
| - if data["hue"] is not None: |
407 |
| - kwargs.update(c=data["hue"].values.ravel()) |
408 |
| - |
409 |
| - primitive = ax.scatter( |
410 |
| - data["x"].values.ravel(), data["y"].values.ravel(), **cmap_params, **kwargs |
411 |
| - ) |
412 |
| - |
413 |
| - return primitive |
414 |
| - |
415 |
| - |
416 | 327 | @_dsplot
|
417 | 328 | def quiver(ds, x, y, ax, u, v, **kwargs):
|
418 | 329 | """Quiver plot of Dataset variables.
|
@@ -497,3 +408,103 @@ def streamplot(ds, x, y, ax, u, v, **kwargs):
|
497 | 408 |
|
498 | 409 | # Return .lines so colorbar creation works properly
|
499 | 410 | return hdl.lines
|
| 411 | + |
| 412 | + |
| 413 | +def _attach_to_plot_class(plotfunc: Callable) -> None: |
| 414 | + """ |
| 415 | + Set the function to the plot class and add a common docstring. |
| 416 | +
|
| 417 | + Use this decorator when relying on DataArray.plot methods for |
| 418 | + creating the Dataset plot. |
| 419 | +
|
| 420 | + TODO: Reduce code duplication. |
| 421 | +
|
| 422 | + * The goal is to reduce code duplication by moving all Dataset |
| 423 | + specific plots to the DataArray side and use this thin wrapper to |
| 424 | + handle the conversion between Dataset and DataArray. |
| 425 | + * Improve docstring handling, maybe reword the DataArray versions to |
| 426 | + explain Datasets better. |
| 427 | + * Consider automatically adding all _PlotMethods to |
| 428 | + _Dataset_PlotMethods. |
| 429 | +
|
| 430 | + Parameters |
| 431 | + ---------- |
| 432 | + plotfunc : function |
| 433 | + Function that returns a finished plot primitive. |
| 434 | + """ |
| 435 | + # Build on the original docstring: |
| 436 | + original_doc = getattr(_PlotMethods, plotfunc.__name__, object) |
| 437 | + commondoc = original_doc.__doc__ |
| 438 | + if commondoc is not None: |
| 439 | + doc_warning = ( |
| 440 | + f"This docstring was copied from xr.DataArray.plot.{original_doc.__name__}." |
| 441 | + " Some inconsistencies may exist." |
| 442 | + ) |
| 443 | + # Add indentation so it matches the original doc: |
| 444 | + commondoc = f"\n\n {doc_warning}\n\n {commondoc}" |
| 445 | + else: |
| 446 | + commondoc = "" |
| 447 | + plotfunc.__doc__ = ( |
| 448 | + f" {plotfunc.__doc__}\n\n" |
| 449 | + " The `y` DataArray will be used as base," |
| 450 | + " any other variables are added as coords.\n\n" |
| 451 | + f"{commondoc}" |
| 452 | + ) |
| 453 | + |
| 454 | + @functools.wraps(plotfunc) |
| 455 | + def plotmethod(self, *args, **kwargs): |
| 456 | + return plotfunc(self._ds, *args, **kwargs) |
| 457 | + |
| 458 | + # Add to class _PlotMethods |
| 459 | + setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod) |
| 460 | + |
| 461 | + |
| 462 | +def _normalize_args(plotmethod: str, args, kwargs) -> dict[str, Any]: |
| 463 | + from ..core.dataarray import DataArray |
| 464 | + |
| 465 | + # Determine positional arguments keyword by inspecting the |
| 466 | + # signature of the plotmethod: |
| 467 | + locals_ = dict( |
| 468 | + inspect.signature(getattr(DataArray().plot, plotmethod)) |
| 469 | + .bind(*args, **kwargs) |
| 470 | + .arguments.items() |
| 471 | + ) |
| 472 | + locals_.update(locals_.pop("kwargs", {})) |
| 473 | + |
| 474 | + return locals_ |
| 475 | + |
| 476 | + |
| 477 | +def _temp_dataarray(ds: T_Dataset, y: Hashable, locals_: Mapping) -> DataArray: |
| 478 | + """Create a temporary datarray with extra coords.""" |
| 479 | + from ..core.dataarray import DataArray |
| 480 | + |
| 481 | + # Base coords: |
| 482 | + coords = dict(ds.coords) |
| 483 | + |
| 484 | + # Add extra coords to the DataArray from valid kwargs, if using all |
| 485 | + # kwargs there is a risk that we add unneccessary dataarrays as |
| 486 | + # coords straining RAM further for example: |
| 487 | + # ds.both and extend="both" would add ds.both to the coords: |
| 488 | + valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"} |
| 489 | + coord_kwargs = locals_.keys() & valid_coord_kwargs |
| 490 | + for k in coord_kwargs: |
| 491 | + key = locals_[k] |
| 492 | + if ds.data_vars.get(key) is not None: |
| 493 | + coords[key] = ds[key] |
| 494 | + |
| 495 | + # The dataarray has to include all the dims. Broadcast to that shape |
| 496 | + # and add the additional coords: |
| 497 | + _y = ds[y].broadcast_like(ds) |
| 498 | + |
| 499 | + return DataArray(_y, coords=coords) |
| 500 | + |
| 501 | + |
| 502 | +@_attach_to_plot_class |
| 503 | +def scatter(ds: T_Dataset, x: Hashable, y: Hashable, *args, **kwargs): |
| 504 | + """Scatter plot Dataset data variables against each other.""" |
| 505 | + plotmethod = "scatter" |
| 506 | + kwargs.update(x=x) |
| 507 | + locals_ = _normalize_args(plotmethod, args, kwargs) |
| 508 | + da = _temp_dataarray(ds, y, locals_) |
| 509 | + |
| 510 | + return getattr(da.plot, plotmethod)(*locals_.pop("args", ()), **locals_) |
0 commit comments