diff --git a/pyproject.toml b/pyproject.toml index aec6e9af..1944249c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ fix = true convention = "numpy" [tool.mypy] +python_version = "3.8" # Block below are checks that form part of mypy 'strict' mode warn_unused_configs = true warn_redundant_casts = true @@ -51,6 +52,7 @@ disallow_incomplete_defs = true disallow_untyped_defs = true no_implicit_reexport = true warn_return_any = false # TODO: fix +ignore_missing_imports = true [[tool.mypy.overrides]] module = [ diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index b69d0310..0368e7ef 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import List, Tuple +from typing import List, Optional, Tuple import napari from matplotlib.axes import Axes @@ -43,8 +43,12 @@ class NapariMPLWidget(QWidget): List of currently selected napari layers. """ - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__() + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(parent=parent) self.viewer = napari_viewer self.canvas = FigureCanvas() @@ -52,7 +56,7 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): self.canvas.figure.patch.set_facecolor("none") self.canvas.figure.set_layout_engine("constrained") self.toolbar = NapariNavigationToolbar( - self.canvas, self + self.canvas, parent=self ) # type: ignore[no-untyped-call] self._replace_toolbar_icons() diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index 7e863826..04d91ed1 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,13 +1,14 @@ +from typing import Optional + +import napari import numpy as np +from qtpy.QtWidgets import QWidget from .base import NapariMPLWidget +from .util import Interval __all__ = ["HistogramWidget"] -import napari - -from .util import Interval - _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} @@ -19,8 +20,12 @@ class HistogramWidget(NapariMPLWidget): n_layers_input = Interval(1, 1) input_layer_types = (napari.layers.Image,) - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) self.add_single_axes() self.update_layers(None) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 405b7b09..3cc7b169 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,9 +1,8 @@ -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import napari import numpy.typing as npt -from magicgui import magicgui -from magicgui.widgets import ComboBox +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import NapariMPLWidget from .util import Interval @@ -20,11 +19,13 @@ class ScatterBaseWidget(NapariMPLWidget): # the scatter is plotted as a 2D histogram _threshold_to_switch_to_histogram = 500 - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) - + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) self.add_single_axes() - self.update_layers(None) def clear(self) -> None: """ @@ -113,55 +114,57 @@ class FeaturesScatterWidget(ScatterBaseWidget): napari.layers.Vectors, ) - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) - self._key_selection_widget = magicgui( - self._set_axis_keys, - x_axis_key={"choices": self._get_valid_axis_keys}, - y_axis_key={"choices": self._get_valid_axis_keys}, - call_button="plot", - ) + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) + + self.layout().addLayout(QVBoxLayout()) - self.layout().addWidget(self._key_selection_widget.native) + self._selectors: Dict[str, QComboBox] = {} + for dim in ["x", "y"]: + self._selectors[dim] = QComboBox() + # Re-draw when combo boxes are updated + self._selectors[dim].currentTextChanged.connect(self._draw) + + self.layout().addWidget(QLabel(f"{dim}-axis:")) + self.layout().addWidget(self._selectors[dim]) + + self.update_layers(None) @property - def x_axis_key(self) -> Optional[str]: + def x_axis_key(self) -> Union[str, None]: """ Key to access x axis data from the FeaturesTable. """ - return self._x_axis_key + if self._selectors["x"].count() == 0: + return None + else: + return self._selectors["x"].currentText() @x_axis_key.setter - def x_axis_key(self, key: Optional[str]) -> None: - self._x_axis_key = key + def x_axis_key(self, key: str) -> None: + self._selectors["x"].setCurrentText(key) self._draw() @property - def y_axis_key(self) -> Optional[str]: + def y_axis_key(self) -> Union[str, None]: """ Key to access y axis data from the FeaturesTable. """ - return self._y_axis_key + if self._selectors["y"].count() == 0: + return None + else: + return self._selectors["y"].currentText() @y_axis_key.setter - def y_axis_key(self, key: Optional[str]) -> None: - """ - Set the y-axis key. - """ - self._y_axis_key = key - self._draw() - - def _set_axis_keys(self, x_axis_key: str, y_axis_key: str) -> None: - """ - Set both axis keys and then redraw the plot. - """ - self._x_axis_key = x_axis_key - self._y_axis_key = y_axis_key + def y_axis_key(self, key: str) -> None: + self._selectors["y"].setCurrentText(key) self._draw() - def _get_valid_axis_keys( - self, combo_widget: Optional[ComboBox] = None - ) -> List[str]: + def _get_valid_axis_keys(self) -> List[str]: """ Get the valid axis keys from the layer FeatureTable. @@ -185,11 +188,12 @@ def _ready_to_scatter(self) -> bool: return False feature_table = self.layers[0].features + valid_keys = self._get_valid_axis_keys() return ( feature_table is not None and len(feature_table) > 0 - and self.x_axis_key is not None - and self.y_axis_key is not None + and self.x_axis_key in valid_keys + and self.y_axis_key in valid_keys ) def draw(self) -> None: @@ -230,9 +234,9 @@ def _on_update_layers(self) -> None: """ Called when the layer selection changes by ``self.update_layers()``. """ - if hasattr(self, "_key_selection_widget"): - self._key_selection_widget.reset_choices() - - # reset the axis keys - self._x_axis_key = None - self._y_axis_key = None + # Clear combobox + for dim in ["x", "y"]: + while self._selectors[dim].count() > 0: + self._selectors[dim].removeItem(0) + # Add keys for newly selected layer + self._selectors[dim].addItems(self._get_valid_axis_keys()) diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index 4e22bad8..78e7841c 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import napari import numpy as np import numpy.typing as npt -from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox +from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget from .base import NapariMPLWidget from .util import Interval @@ -22,9 +22,13 @@ class SliceWidget(NapariMPLWidget): n_layers_input = Interval(1, 1) input_layer_types = (napari.layers.Image,) - def __init__(self, napari_viewer: napari.viewer.Viewer): + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): # Setup figure/axes - super().__init__(napari_viewer) + super().__init__(napari_viewer, parent=parent) self.add_single_axes() button_layout = QHBoxLayout() diff --git a/src/napari_matplotlib/tests/conftest.py b/src/napari_matplotlib/tests/conftest.py index 06ed51f6..0788292d 100644 --- a/src/napari_matplotlib/tests/conftest.py +++ b/src/napari_matplotlib/tests/conftest.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest from skimage import data @@ -22,3 +24,17 @@ def astronaut_data(): @pytest.fixture def brain_data(): return data.brain(), {"rgb": False} + + +@pytest.fixture(autouse=True, scope="session") +def set_strict_qt(): + env_var = "NAPARI_STRICT_QT" + old_val = os.environ.get(env_var) + os.environ[env_var] = "1" + # Run tests + yield + # Reset to original value + if old_val is not None: + os.environ[env_var] = old_val + else: + del os.environ[env_var] diff --git a/tox.ini b/tox.ini index 5a8cf188..6f7a4823 100644 --- a/tox.ini +++ b/tox.ini @@ -12,5 +12,5 @@ python = [testenv] extras = testing commands = - - python -c 'from skimage import data; data.brain()' - - python -m pytest --mpl -v --color=yes --cov=napari_matplotlib --cov-report=xml + python -c 'from skimage import data; data.brain()' + python -m pytest --mpl -v --color=yes --cov=napari_matplotlib --cov-report=xml