diff --git a/examples/slice.py b/examples/slice.py new file mode 100644 index 00000000..3e43443e --- /dev/null +++ b/examples/slice.py @@ -0,0 +1,15 @@ +""" +1D slices +========= +""" +import napari + +viewer = napari.Viewer() +viewer.open_sample("napari", "kidney") + +viewer.window.add_plugin_dock_widget( + plugin_name="napari-matplotlib", widget_name="1D slice" +) + +if __name__ == "__main__": + napari.run() diff --git a/src/napari_matplotlib/__init__.py b/src/napari_matplotlib/__init__.py index 7e8ccf69..2d0e7e71 100644 --- a/src/napari_matplotlib/__init__.py +++ b/src/napari_matplotlib/__init__.py @@ -6,3 +6,4 @@ from .histogram import * # NoQA from .scatter import * # NoQA +from .slice import * # NoQA diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 6bfbd093..8ef507d1 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -83,7 +83,7 @@ def setup_callbacks(self) -> None: def update_layers(self, event: napari.utils.events.Event) -> None: """ - Update the currently selected layers and re-draw. + Update the layers attribute with currently selected layers and re-draw. """ self.layers = list(self.viewer.layers.selection) self._draw() diff --git a/src/napari_matplotlib/napari.yaml b/src/napari_matplotlib/napari.yaml index 3ff66090..cd585879 100644 --- a/src/napari_matplotlib/napari.yaml +++ b/src/napari_matplotlib/napari.yaml @@ -10,9 +10,16 @@ contributions: python_name: napari_matplotlib:ScatterWidget title: Make a scatter plot + - id: napari-matplotlib.slice + python_name: napari_matplotlib:SliceWidget + title: Plot a 1D slice + widgets: - command: napari-matplotlib.histogram display_name: Histogram - command: napari-matplotlib.scatter display_name: Scatter + + - command: napari-matplotlib.slice + display_name: 1D slice diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py new file mode 100644 index 00000000..1d6407f5 --- /dev/null +++ b/src/napari_matplotlib/slice.py @@ -0,0 +1,118 @@ +from typing import Dict, Tuple + +import napari +import numpy as np +from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox + +from napari_matplotlib.base import NapariMPLWidget + +__all__ = ["SliceWidget"] + +_dims_sel = ["x", "y"] +_dims = ["x", "y", "z"] + + +class SliceWidget(NapariMPLWidget): + """ + Plot a 1D slice along a given dimension. + """ + + n_layers_input = 1 + + def __init__(self, napari_viewer: napari.viewer.Viewer): + # Setup figure/axes + super().__init__(napari_viewer) + self.axes = self.canvas.figure.subplots() + + button_layout = QHBoxLayout() + self.layout().addLayout(button_layout) + + self.dim_selector = QComboBox() + button_layout.addWidget(QLabel("Slice axis:")) + button_layout.addWidget(self.dim_selector) + self.dim_selector.addItems(_dims) + + self.slice_selectors = {} + for d in _dims_sel: + self.slice_selectors[d] = QSpinBox() + button_layout.addWidget(QLabel(f"{d}:")) + button_layout.addWidget(self.slice_selectors[d]) + + # Setup callbacks + # Re-draw when any of the combon/spin boxes are updated + self.dim_selector.currentTextChanged.connect(self._draw) + for d in _dims_sel: + self.slice_selectors[d].textChanged.connect(self._draw) + + self.update_layers(None) + + @property + def layer(self): + return self.layers[0] + + @property + def current_dim(self) -> str: + """ + Currently selected slice dimension. + """ + return self.dim_selector.currentText() + + @property + def current_dim_index(self) -> int: + """ + Currently selected slice dimension index. + """ + # Note the reversed list because in napari the z-axis is the first + # numpy axis + return _dims[::-1].index(self.current_dim) + + @property + def selector_values(self) -> Dict[str, int]: + return {d: self.slice_selectors[d].value() for d in _dims_sel} + + def update_slice_selectors(self) -> None: + """ + Update range and enabled status of the slice selectors, and the value + of the z slice selector. + """ + # Update min/max + for i, dim in enumerate(_dims_sel): + self.slice_selectors[dim].setRange(0, self.layer.data.shape[i]) + + def get_xy(self) -> Tuple[np.ndarray, np.ndarray]: + """ + Get data for plotting. + """ + x = np.arange(self.layer.data.shape[self.current_dim_index]) + + vals = self.selector_values + vals.update({"z": self.current_z}) + + slices = [] + for d in _dims: + if d == self.current_dim: + # Select all data along this axis + slices.append(slice(None)) + else: + # Select specific index + val = vals[d] + slices.append(slice(val, val + 1)) + + # Reverse since z is the first axis in napari + slices = slices[::-1] + y = self.layer.data[tuple(slices)].ravel() + + return x, y + + def clear(self) -> None: + self.axes.cla() + + def draw(self) -> None: + """ + Clear axes and draw a 1D plot. + """ + x, y = self.get_xy() + + self.axes.plot(x, y) + self.axes.set_xlabel(self.current_dim) + self.axes.set_title(self.layer.name) diff --git a/src/napari_matplotlib/tests/test_slice.py b/src/napari_matplotlib/tests/test_slice.py new file mode 100644 index 00000000..d0be3cc1 --- /dev/null +++ b/src/napari_matplotlib/tests/test_slice.py @@ -0,0 +1,10 @@ +import numpy as np + +from napari_matplotlib import SliceWidget + + +def test_scatter(make_napari_viewer): + # Smoke test adding a histogram widget + viewer = make_napari_viewer() + viewer.add_image(np.random.random((100, 100, 100))) + SliceWidget(viewer)