1- from typing import Any , Dict , Optional , Tuple
1+ from typing import Any , Dict , List , Optional , Tuple
22
33import matplotlib .ticker as mticker
44import napari
1212__all__ = ["SliceWidget" ]
1313
1414_dims_sel = ["x" , "y" ]
15- _dims = ["x" , "y" , "z" ]
1615
1716
1817class SliceWidget (SingleAxesWidget ):
@@ -37,7 +36,7 @@ def __init__(
3736 self .dim_selector = QComboBox ()
3837 button_layout .addWidget (QLabel ("Slice axis:" ))
3938 button_layout .addWidget (self .dim_selector )
40- self .dim_selector .addItems (_dims )
39+ self .dim_selector .addItems ([ "x" , "y" , "z" ] )
4140
4241 self .slice_selectors = {}
4342 for d in _dims_sel :
@@ -61,7 +60,7 @@ def _layer(self) -> napari.layers.Layer:
6160 return self .layers [0 ]
6261
6362 @property
64- def current_dim (self ) -> str :
63+ def current_dim_name (self ) -> str :
6564 """
6665 Currently selected slice dimension.
6766 """
@@ -74,32 +73,50 @@ def current_dim_index(self) -> int:
7473 """
7574 # Note the reversed list because in napari the z-axis is the first
7675 # numpy axis
77- return _dims [::- 1 ].index (self .current_dim )
76+ return self ._dim_names [::- 1 ].index (self .current_dim_name )
77+
78+ @property
79+ def _dim_names (self ) -> List [str ]:
80+ """
81+ List of dimension names. This is a property as it varies depending on the
82+ dimensionality of the currently selected data.
83+ """
84+ if self ._layer .data .ndim == 2 :
85+ return ["x" , "y" ]
86+ elif self ._layer .data .ndim == 3 :
87+ return ["x" , "y" , "z" ]
88+ else :
89+ raise RuntimeError ("Don't know how to handle ndim != 2 or 3" )
7890
7991 @property
8092 def _selector_values (self ) -> Dict [str , int ]:
8193 """
8294 Values of the slice selectors.
95+
96+ Mapping from dimension name to value.
8397 """
8498 return {d : self .slice_selectors [d ].value () for d in _dims_sel }
8599
86100 def _get_xy (self ) -> Tuple [npt .NDArray [Any ], npt .NDArray [Any ]]:
87101 """
88102 Get data for plotting.
89103 """
90- x = np .arange (self ._layer .data .shape [self .current_dim_index ])
104+ dim_index = self .current_dim_index
105+ if self ._layer .data .ndim == 2 :
106+ dim_index -= 1
107+ x = np .arange (self ._layer .data .shape [dim_index ])
91108
92109 vals = self ._selector_values
93110 vals .update ({"z" : self .current_z })
94111
95112 slices = []
96- for d in _dims :
97- if d == self .current_dim :
113+ for dim_name in self . _dim_names :
114+ if dim_name == self .current_dim_name :
98115 # Select all data along this axis
99116 slices .append (slice (None ))
100117 else :
101118 # Select specific index
102- val = vals [d ]
119+ val = vals [dim_name ]
103120 slices .append (slice (val , val + 1 ))
104121
105122 # Reverse since z is the first axis in napari
@@ -115,7 +132,7 @@ def draw(self) -> None:
115132 x , y = self ._get_xy ()
116133
117134 self .axes .plot (x , y )
118- self .axes .set_xlabel (self .current_dim )
135+ self .axes .set_xlabel (self .current_dim_name )
119136 self .axes .set_title (self ._layer .name )
120137 # Make sure all ticks lie on integer values
121138 self .axes .xaxis .set_major_locator (
0 commit comments