Skip to content

Import nc_time_axis when needed #7276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- Import ``nc_time_axis`` when needed (:issue:`7275`, :pull:`7276`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
64 changes: 34 additions & 30 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,8 @@ def _resolve_intervals_1dplot(
remove_drawstyle = False

# Convert intervals to double points
x_is_interval = _valid_other_type(xval, [pd.Interval])
y_is_interval = _valid_other_type(yval, [pd.Interval])
x_is_interval = _valid_other_type(xval, pd.Interval)
y_is_interval = _valid_other_type(yval, pd.Interval)
if x_is_interval and y_is_interval:
raise TypeError("Can't step plot intervals against intervals.")
elif x_is_interval:
Expand All @@ -628,10 +628,10 @@ def _resolve_intervals_1dplot(
else:

# Convert intervals to mid points and adjust labels
if _valid_other_type(xval, [pd.Interval]):
if _valid_other_type(xval, pd.Interval):
xval = _interval_to_mid_points(xval)
x_suffix = "_center"
if _valid_other_type(yval, [pd.Interval]):
if _valid_other_type(yval, pd.Interval):
yval = _interval_to_mid_points(yval)
y_suffix = "_center"

Expand All @@ -646,7 +646,7 @@ def _resolve_intervals_2dplot(val, func_name):
increases length by 1.
"""
label_extra = ""
if _valid_other_type(val, [pd.Interval]):
if _valid_other_type(val, pd.Interval):
if func_name == "pcolormesh":
val = _interval_to_bound_points(val)
else:
Expand All @@ -656,11 +656,13 @@ def _resolve_intervals_2dplot(val, func_name):
return val, label_extra


def _valid_other_type(x, types):
def _valid_other_type(
x: ArrayLike, types: type[object] | tuple[type[object], ...]
) -> bool:
"""
Do all elements of x have a type from types?
"""
return all(any(isinstance(el, t) for t in types) for el in np.ravel(x))
return all(isinstance(el, types) for el in np.ravel(x))


def _valid_numpy_subdtype(x, numpy_types):
Expand All @@ -675,47 +677,49 @@ def _valid_numpy_subdtype(x, numpy_types):
return any(np.issubdtype(x.dtype, t) for t in numpy_types)


def _ensure_plottable(*args):
def _ensure_plottable(*args) -> None:
"""
Raise exception if there is anything in args that can't be plotted on an
axis by matplotlib.
"""
numpy_types = [
numpy_types: tuple[type[object], ...] = (
np.floating,
np.integer,
np.timedelta64,
np.datetime64,
np.bool_,
np.str_,
]
other_types = [datetime]
if cftime is not None:
cftime_datetime_types = [cftime.datetime]
other_types = other_types + cftime_datetime_types
else:
cftime_datetime_types = []
)
other_types: tuple[type[object], ...] = (datetime,)
cftime_datetime_types: tuple[type[object], ...] = (
() if cftime is None else (cftime.datetime,)
)
other_types += cftime_datetime_types

for x in args:
if not (
_valid_numpy_subdtype(np.array(x), numpy_types)
or _valid_other_type(np.array(x), other_types)
_valid_numpy_subdtype(np.asarray(x), numpy_types)
or _valid_other_type(np.asarray(x), other_types)
):
raise TypeError(
"Plotting requires coordinates to be numeric, boolean, "
"or dates of type numpy.datetime64, "
"datetime.datetime, cftime.datetime or "
f"pandas.Interval. Received data of type {np.array(x).dtype} instead."
)
if (
_valid_other_type(np.array(x), cftime_datetime_types)
and not nc_time_axis_available
):
raise ImportError(
"Plotting of arrays of cftime.datetime "
"objects or arrays indexed by "
"cftime.datetime objects requires the "
"optional `nc-time-axis` (v1.2.0 or later) "
"package."
f"pandas.Interval. Received data of type {np.asarray(x).dtype} instead."
)
if _valid_other_type(np.asarray(x), cftime_datetime_types):
if nc_time_axis_available:
# Register cftime datetypes to matplotlib.units.registry,
# otherwise matplotlib will raise an error:
import nc_time_axis # noqa: F401
else:
raise ImportError(
"Plotting of arrays of cftime.datetime "
"objects or arrays indexed by "
"cftime.datetime objects requires the "
"optional `nc-time-axis` (v1.2.0 or later) "
"package."
)


def _is_numeric(arr):
Expand Down
1 change: 0 additions & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def _importorskip(
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_nc_time_axis, requires_nc_time_axis = _importorskip("nc_time_axis")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
has_fsspec, requires_fsspec = _importorskip("fsspec")
Expand Down
10 changes: 6 additions & 4 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import xarray as xr
import xarray.plot as xplt
from xarray import DataArray, Dataset
from xarray.core.utils import module_available
from xarray.plot.dataarray_plot import _infer_interval_breaks
from xarray.plot.dataset_plot import _infer_meta_data
from xarray.plot.utils import (
Expand All @@ -29,14 +30,15 @@
from . import (
assert_array_equal,
assert_equal,
has_nc_time_axis,
requires_cartopy,
requires_cftime,
requires_matplotlib,
requires_nc_time_axis,
requires_seaborn,
)

# this should not be imported to test if the automatic lazy import works
has_nc_time_axis = module_available("nc_time_axis")

# import mpl and change the backend before other mpl imports
try:
import matplotlib as mpl
Expand Down Expand Up @@ -2823,8 +2825,8 @@ def test_datetime_plot2d(self) -> None:


@pytest.mark.filterwarnings("ignore:setting an array element with a sequence")
@requires_nc_time_axis
@requires_cftime
@pytest.mark.skipif(not has_nc_time_axis, reason="nc_time_axis is not installed")
class TestCFDatetimePlot(PlotTestCase):
@pytest.fixture(autouse=True)
def setUp(self) -> None:
Expand Down Expand Up @@ -3206,7 +3208,7 @@ def test_plot_empty_raises(val: list | float, method: str) -> None:


@requires_matplotlib
def test_facetgrid_axes_raises_deprecation_warning():
def test_facetgrid_axes_raises_deprecation_warning() -> None:
with pytest.warns(
DeprecationWarning,
match=(
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def test_lazy_import() -> None:
"scipy",
"zarr",
"matplotlib",
"nc_time_axis",
"flox",
# "dask", # TODO: backends.locks is not lazy yet :(
"dask.array",
Expand Down