Skip to content

Tidy up scatter code #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 19, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 22 additions & 37 deletions src/napari_matplotlib/scatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import List, Optional, Tuple

import matplotlib.colors as mcolor
import napari
Expand All @@ -22,21 +22,21 @@ class ScatterBaseWidget(NapariMPLWidget):
# the scatter is plotted as a 2dhist
_threshold_to_switch_to_histogram = 500

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
):
def __init__(self, napari_viewer: napari.viewer.Viewer):
super().__init__(napari_viewer)

self.axes = self.canvas.figure.subplots()
self.update_layers(None)

def clear(self) -> None:
"""
Clear the axes.
"""
self.axes.clear()

def draw(self) -> None:
"""
Clear the axes and scatter the currently selected layers.
Scatter the currently selected layers.
"""
data, x_axis_name, y_axis_name = self._get_data()

Expand Down Expand Up @@ -86,14 +86,6 @@ class ScatterWidget(ScatterBaseWidget):

n_layers_input = 2

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
):
super().__init__(
napari_viewer,
)

def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
"""Get the plot data.

Expand All @@ -116,42 +108,34 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
class FeaturesScatterWidget(ScatterBaseWidget):
n_layers_input = 1

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
key_selection_gui: bool = True,
):
self._key_selection_widget = None
super().__init__(
napari_viewer,
)
def __init__(self, napari_viewer: napari.viewer.Viewer):
super().__init__(napari_viewer)

if key_selection_gui is True:
self._key_selection_widget = magicgui(
self._set_axis_keys,
x_axis_key={"choices": self._get_valid_axis_keys},
y_axis_key={"choices": self._get_valid_axis_keys},
call_button="plot",
)
self.layout().addWidget(self._key_selection_widget.native)
self._key_selection_widget = magicgui(
self._set_axis_keys,
x_axis_key={"choices": self._get_valid_axis_keys},
y_axis_key={"choices": self._get_valid_axis_keys},
call_button="plot",
)
self.layout().addWidget(self._key_selection_widget.native)

@property
def x_axis_key(self) -> Union[None, str]:
def x_axis_key(self) -> Optional[str]:
"""Key to access x axis data from the FeaturesTable"""
return self._x_axis_key

@x_axis_key.setter
def x_axis_key(self, key: Union[None, str]):
def x_axis_key(self, key: Optional[str]):
self._x_axis_key = key
self._draw()

@property
def y_axis_key(self) -> Union[None, str]:
def y_axis_key(self) -> Optional[str]:
"""Key to access y axis data from the FeaturesTable"""
return self._y_axis_key

@y_axis_key.setter
def y_axis_key(self, key: Union[None, str]):
def y_axis_key(self, key: Optional[str]):
self._y_axis_key = key
self._draw()

Expand Down Expand Up @@ -214,8 +198,9 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
return data, x_axis_name, y_axis_name

def _on_update_layers(self) -> None:
"""This is called when the layer selection changes
by self.update_layers().
"""
This is called when the layer selection changes by
``self.update_layers()``.
"""
if self._key_selection_widget is not None:
self._key_selection_widget.reset_choices()
Expand Down