Skip to content

TYP: plotting._matplotlib #47311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,6 @@ def closed(self) -> bool:

# quantile interpolation
QuantileInterpolation = Literal["linear", "lower", "higher", "midpoint", "nearest"]

# plotting
PlottingOrientation = Literal["horizontal", "vertical"]
5 changes: 3 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5727,7 +5727,7 @@ def _check_inplace_setting(self, value) -> bool_t:
return True

@final
def _get_numeric_data(self):
def _get_numeric_data(self: NDFrameT) -> NDFrameT:
return self._constructor(self._mgr.get_numeric_data()).__finalize__(self)

@final
Expand Down Expand Up @@ -10954,7 +10954,8 @@ def mad(

data = self._get_numeric_data()
if axis == 0:
demeaned = data - data.mean(axis=0)
# error: Unsupported operand types for - ("NDFrame" and "float")
demeaned = data - data.mean(axis=0) # type: ignore[operator]
else:
demeaned = data.sub(data.mean(axis=1), axis=0)
return np.abs(demeaned).mean(axis=axis, skipna=skipna)
Expand Down
10 changes: 5 additions & 5 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,11 +1879,11 @@ def _get_plot_backend(backend: str | None = None):
-----
Modifies `_backends` with imported backend as a side effect.
"""
backend = backend or get_option("plotting.backend")
backend_str: str = backend or get_option("plotting.backend")

if backend in _backends:
return _backends[backend]
if backend_str in _backends:
return _backends[backend_str]

module = _load_backend(backend)
_backends[backend] = module
module = _load_backend(backend_str)
_backends[backend_str] = module
return module
6 changes: 5 additions & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import (
TYPE_CHECKING,
Literal,
NamedTuple,
)
import warnings
Expand Down Expand Up @@ -34,7 +35,10 @@


class BoxPlot(LinePlot):
_kind = "box"
@property
def _kind(self) -> Literal["box"]:
return "box"

_layout_type = "horizontal"

_valid_return_types = (None, "axes", "dict", "both")
Expand Down
6 changes: 5 additions & 1 deletion pandas/plotting/_matplotlib/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ def _daily_finder(vmin, vmax, freq: BaseOffset):
Period(ordinal=int(vmin), freq=freq),
Period(ordinal=int(vmax), freq=freq),
)
assert isinstance(vmin, Period)
assert isinstance(vmax, Period)
span = vmax.ordinal - vmin.ordinal + 1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can safely assert as NaT has no attribute called ordinal

dates_ = period_range(start=vmin, end=vmax, freq=freq)
# Initialize the output
Expand Down Expand Up @@ -1073,7 +1075,9 @@ def __call__(self, x, pos=0) -> str:
fmt = self.formatdict.pop(x, "")
if isinstance(fmt, np.bytes_):
fmt = fmt.decode("utf-8")
return Period(ordinal=int(x), freq=self.freq).strftime(fmt)
period = Period(ordinal=int(x), freq=self.freq)
assert isinstance(period, Period)
return period.strftime(fmt)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here: NaT has no strftime



class TimeSeries_TimedeltaFormatter(Formatter):
Expand Down
80 changes: 60 additions & 20 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from __future__ import annotations

from abc import (
ABC,
abstractmethod,
)
from typing import (
TYPE_CHECKING,
Hashable,
Iterable,
Literal,
Sequence,
)
import warnings

from matplotlib.artist import Artist
import numpy as np

from pandas._typing import IndexLabel
from pandas._typing import (
IndexLabel,
PlottingOrientation,
)
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

Expand Down Expand Up @@ -78,7 +86,7 @@ def _color_in_style(style: str) -> bool:
return not set(BASE_COLORS).isdisjoint(style)


class MPLPlot:
class MPLPlot(ABC):
"""
Base class for assembling a pandas plot using matplotlib

Expand All @@ -89,13 +97,17 @@ class MPLPlot:
"""

@property
def _kind(self):
@abstractmethod
def _kind(self) -> str:
"""Specify kind str. Must be overridden in child class"""
raise NotImplementedError

_layout_type = "vertical"
_default_rot = 0
orientation: str | None = None

@property
def orientation(self) -> str | None:
return None

axes: np.ndarray # of Axes objects

Expand Down Expand Up @@ -843,7 +855,9 @@ def _get_xticks(self, convert_period: bool = False):

@classmethod
@register_pandas_matplotlib_converters
def _plot(cls, ax: Axes, x, y, style=None, is_errorbar: bool = False, **kwds):
def _plot(
cls, ax: Axes, x, y: np.ndarray, style=None, is_errorbar: bool = False, **kwds
):
mask = isna(y)
if mask.any():
y = np.ma.array(y)
Expand Down Expand Up @@ -1101,7 +1115,7 @@ def _get_axes_layout(self) -> tuple[int, int]:
return (len(y_set), len(x_set))


class PlanePlot(MPLPlot):
class PlanePlot(MPLPlot, ABC):
"""
Abstract class for plotting on plane, currently scatter and hexbin.
"""
Expand Down Expand Up @@ -1159,7 +1173,9 @@ def _plot_colorbar(self, ax: Axes, **kwds):


class ScatterPlot(PlanePlot):
_kind = "scatter"
@property
def _kind(self) -> Literal["scatter"]:
return "scatter"

def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
if s is None:
Expand Down Expand Up @@ -1247,7 +1263,9 @@ def _make_plot(self):


class HexBinPlot(PlanePlot):
_kind = "hexbin"
@property
def _kind(self) -> Literal["hexbin"]:
return "hexbin"

def __init__(self, data, x, y, C=None, **kwargs) -> None:
super().__init__(data, x, y, **kwargs)
Expand Down Expand Up @@ -1277,9 +1295,15 @@ def _make_legend(self):


class LinePlot(MPLPlot):
_kind = "line"
_default_rot = 0
orientation = "vertical"

@property
def orientation(self) -> PlottingOrientation:
return "vertical"

@property
def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed so that the sub-classes are compatible with LinePlot

return "line"

def __init__(self, data, **kwargs) -> None:
from pandas.plotting import plot_params
Expand Down Expand Up @@ -1363,8 +1387,7 @@ def _plot( # type: ignore[override]
cls._update_stacker(ax, stacking_id, y)
return lines

@classmethod
def _ts_plot(cls, ax: Axes, x, data, style=None, **kwds):
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
# accept x to be consistent with normal plot func,
# x is not passed to tsplot as it uses data.index as x coordinate
# column_num must be in kwds for stacking purpose
Expand All @@ -1377,9 +1400,9 @@ def _ts_plot(cls, ax: Axes, x, data, style=None, **kwds):
decorate_axes(ax.left_ax, freq, kwds)
if hasattr(ax, "right_ax"):
decorate_axes(ax.right_ax, freq, kwds)
ax._plot_data.append((data, cls._kind, kwds))
ax._plot_data.append((data, self._kind, kwds))

lines = cls._plot(ax, data.index, data.values, style=style, **kwds)
lines = self._plot(ax, data.index, data.values, style=style, **kwds)
# set date formatter, locators and rescale limits
format_dateaxis(ax, ax.freq, data.index)
return lines
Expand Down Expand Up @@ -1471,7 +1494,9 @@ def get_label(i):


class AreaPlot(LinePlot):
_kind = "area"
@property
def _kind(self) -> Literal["area"]:
return "area"

def __init__(self, data, **kwargs) -> None:
kwargs.setdefault("stacked", True)
Expand Down Expand Up @@ -1544,9 +1569,15 @@ def _post_plot_logic(self, ax: Axes, data):


class BarPlot(MPLPlot):
_kind = "bar"
@property
def _kind(self) -> Literal["bar", "barh"]:
return "bar"

_default_rot = 90
orientation = "vertical"

@property
def orientation(self) -> PlottingOrientation:
return "vertical"

def __init__(self, data, **kwargs) -> None:
# we have to treat a series differently than a
Expand Down Expand Up @@ -1698,9 +1729,15 @@ def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge):


class BarhPlot(BarPlot):
_kind = "barh"
@property
def _kind(self) -> Literal["barh"]:
return "barh"

_default_rot = 0
orientation = "horizontal"

@property
def orientation(self) -> Literal["horizontal"]:
return "horizontal"

@property
def _start_base(self):
Expand All @@ -1727,7 +1764,10 @@ def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge):


class PiePlot(MPLPlot):
_kind = "pie"
@property
def _kind(self) -> Literal["pie"]:
return "pie"

_layout_type = "horizontal"

def __init__(self, data, kind=None, **kwargs) -> None:
Expand Down
26 changes: 19 additions & 7 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Literal,
)

import numpy as np

from pandas._typing import PlottingOrientation

from pandas.core.dtypes.common import (
is_integer,
is_list_like,
Expand Down Expand Up @@ -40,7 +45,9 @@


class HistPlot(LinePlot):
_kind = "hist"
@property
def _kind(self) -> Literal["hist", "kde"]:
return "hist"

def __init__(self, data, bins=10, bottom=0, **kwargs) -> None:
self.bins = bins # use mpl default
Expand All @@ -64,8 +71,8 @@ def _args_adjust(self):

def _calculate_bins(self, data: DataFrame) -> np.ndarray:
"""Calculate bins given data"""
values = data._convert(datetime=True)._get_numeric_data()
values = np.ravel(values)
nd_values = data._convert(datetime=True)._get_numeric_data()
values = np.ravel(nd_values)
values = values[~isna(values)]

hist, bins = np.histogram(
Expand Down Expand Up @@ -159,16 +166,21 @@ def _post_plot_logic(self, ax: Axes, data):
ax.set_ylabel("Frequency")

@property
def orientation(self):
def orientation(self) -> PlottingOrientation:
if self.kwds.get("orientation", None) == "horizontal":
return "horizontal"
else:
return "vertical"


class KdePlot(HistPlot):
_kind = "kde"
orientation = "vertical"
@property
def _kind(self) -> Literal["kde"]:
return "kde"

@property
def orientation(self) -> Literal["vertical"]:
return "vertical"

def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
MPLPlot.__init__(self, data, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def _get_colors_from_colormap(
num_colors: int,
) -> list[Color]:
"""Get colors from colormap."""
colormap = _get_cmap_instance(colormap)
return [colormap(num) for num in np.linspace(0, 1, num=num_colors)]
cmap = _get_cmap_instance(colormap)
return [cmap(num) for num in np.linspace(0, 1, num=num_colors)]


def _get_cmap_instance(colormap: str | Colormap) -> Colormap:
Expand Down
10 changes: 6 additions & 4 deletions pandas/plotting/_matplotlib/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from datetime import timedelta
import functools
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -185,11 +186,10 @@ def _get_ax_freq(ax: Axes):
return ax_freq


def _get_period_alias(freq) -> str | None:
def _get_period_alias(freq: timedelta | BaseOffset | str) -> str | None:
freqstr = to_offset(freq).rule_code

freq = get_period_alias(freqstr)
return freq
return get_period_alias(freqstr)


def _get_freq(ax: Axes, series: Series):
Expand Down Expand Up @@ -235,7 +235,9 @@ def use_dynamic_x(ax: Axes, data: DataFrame | Series) -> bool:
x = data.index
if base <= FreqGroup.FR_DAY.value:
return x[:1].is_normalized
return Period(x[0], freq_str).to_timestamp().tz_localize(x.tz) == x[0]
period = Period(x[0], freq_str)
assert isinstance(period, Period)
return period.to_timestamp().tz_localize(x.tz) == x[0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NaT has no to_timestamp

return True


Expand Down
6 changes: 5 additions & 1 deletion pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ def table(
return table


def _get_layout(nplots: int, layout=None, layout_type: str = "box") -> tuple[int, int]:
def _get_layout(
nplots: int,
layout: tuple[int, int] | None = None,
layout_type: str = "box",
) -> tuple[int, int]:
if layout is not None:
if not isinstance(layout, (tuple, list)) or len(layout) != 2:
raise ValueError("Layout must be a tuple of (rows, columns)")
Expand Down
Loading