diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 324e9126..97ebd3e3 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Optional, Tuple import matplotlib.colors as mcolor import napari @@ -22,21 +22,21 @@ class ScatterBaseWidget(NapariMPLWidget): # the scatter is plotted as a 2dhist _threshold_to_switch_to_histogram = 500 - def __init__( - self, - napari_viewer: napari.viewer.Viewer, - ): + def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) self.axes = self.canvas.figure.subplots() self.update_layers(None) def clear(self) -> None: + """ + Clear the axes. + """ self.axes.clear() def draw(self) -> None: """ - Clear the axes and scatter the currently selected layers. + Scatter the currently selected layers. """ data, x_axis_name, y_axis_name = self._get_data() @@ -86,14 +86,6 @@ class ScatterWidget(ScatterBaseWidget): n_layers_input = 2 - def __init__( - self, - napari_viewer: napari.viewer.Viewer, - ): - super().__init__( - napari_viewer, - ) - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: """Get the plot data. @@ -116,42 +108,34 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: class FeaturesScatterWidget(ScatterBaseWidget): n_layers_input = 1 - def __init__( - self, - napari_viewer: napari.viewer.Viewer, - key_selection_gui: bool = True, - ): - self._key_selection_widget = None - super().__init__( - napari_viewer, + def __init__(self, napari_viewer: napari.viewer.Viewer): + super().__init__(napari_viewer) + self._key_selection_widget = magicgui( + self._set_axis_keys, + x_axis_key={"choices": self._get_valid_axis_keys}, + y_axis_key={"choices": self._get_valid_axis_keys}, + call_button="plot", ) - if key_selection_gui is True: - self._key_selection_widget = magicgui( - self._set_axis_keys, - x_axis_key={"choices": self._get_valid_axis_keys}, - y_axis_key={"choices": self._get_valid_axis_keys}, - call_button="plot", - ) - self.layout().addWidget(self._key_selection_widget.native) + self.layout().addWidget(self._key_selection_widget.native) @property - def x_axis_key(self) -> Union[None, str]: + def x_axis_key(self) -> Optional[str]: """Key to access x axis data from the FeaturesTable""" return self._x_axis_key @x_axis_key.setter - def x_axis_key(self, key: Union[None, str]): + def x_axis_key(self, key: Optional[str]): self._x_axis_key = key self._draw() @property - def y_axis_key(self) -> Union[None, str]: + def y_axis_key(self) -> Optional[str]: """Key to access y axis data from the FeaturesTable""" return self._y_axis_key @y_axis_key.setter - def y_axis_key(self, key: Union[None, str]): + def y_axis_key(self, key: Optional[str]): self._y_axis_key = key self._draw() @@ -214,10 +198,11 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: return data, x_axis_name, y_axis_name def _on_update_layers(self) -> None: - """This is called when the layer selection changes - by self.update_layers(). """ - if self._key_selection_widget is not None: + This is called when the layer selection changes by + ``self.update_layers()``. + """ + if hasattr(self, "_key_selection_widget"): self._key_selection_widget.reset_choices() # reset the axis keys