Skip to content

Commit d1ee730

Browse files
committed
STY: Remove dicts in favor of lists
1 parent 5991576 commit d1ee730

File tree

2 files changed

+90
-97
lines changed

2 files changed

+90
-97
lines changed

nibabel/tests/test_viewers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ def test_viewer():
3838

3939
# fake some events, inside and outside axes
4040
v._on_scroll(nt('event', 'button inaxes key')('up', None, None))
41-
for ax in (v._axes['x'], v._axes['v']):
41+
for ax in (v._axes[0], v._axes[3]):
4242
v._on_scroll(nt('event', 'button inaxes key')('up', ax, None))
4343
v._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift'))
4444
# "click" outside axes, then once in each axis, then move without click
4545
v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1))
46-
for ax in v._axes.values():
46+
for ax in v._axes:
4747
v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1))
4848
v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None))
4949
v.set_volume_idx(1)
@@ -52,7 +52,7 @@ def test_viewer():
5252

5353
# non-multi-volume
5454
v = OrthoSlicer3D(data[:, :, :, 0])
55-
v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift'))
55+
v._on_scroll(nt('event', 'button inaxes key')('up', v._axes[0], 'shift'))
5656
v._on_keypress(nt('event', 'key')('escape'))
5757

5858
# other cases

nibabel/viewers.py

Lines changed: 87 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
7070
# determine our orientation
7171
self._affine = affine.copy()
7272
codes = axcodes2ornt(aff2axcodes(self._affine))
73-
order = np.argsort([c[0] for c in codes])
74-
flips = np.array([c[1] < 0 for c in codes])[order]
75-
self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2]))
76-
self._flips = dict(x=flips[0], y=flips[1], z=flips[2])
73+
self._order = np.argsort([c[0] for c in codes])
74+
self._flips = np.array([c[1] < 0 for c in codes])[self._order]
75+
self._flips = list(self._flips) + [False] # add volume dim
7776
self._scalers = np.abs(self._affine).max(axis=0)[:3]
7877
self._inv_affine = np.linalg.inv(affine)
7978
# current volume info
@@ -87,56 +86,54 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
8786
# ^ +---------+ ^ +---------+
8887
# | | | | | |
8988
# | Sag | | Cor |
90-
# S | 1 | S | 2 |
89+
# S | 0 | S | 1 |
9190
# | | | |
9291
# | | | |
9392
# +---------+ +---------+
9493
# A --> <-- R
9594
# ^ +---------+ +---------+
9695
# | | | | |
9796
# | Axial | | Vol |
98-
# A | 3 | | 4 |
97+
# A | 2 | | 3 |
9998
# | | | |
10099
# | | | |
101100
# +---------+ +---------+
102101
# <-- R <-- t -->
103102

104103
fig, axes = plt.subplots(2, 2)
105104
fig.set_size_inches(figsize, forward=True)
106-
self._axes = dict(x=axes[0, 0], y=axes[0, 1], z=axes[1, 0],
107-
v=axes[1, 1])
105+
self._axes = [axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]]
108106
plt.tight_layout(pad=0.1)
109107
if self.n_volumes <= 1:
110-
fig.delaxes(self._axes['v'])
111-
del self._axes['v']
108+
fig.delaxes(self._axes[3])
109+
self._axes.pop(-1)
112110
else:
113-
self._axes = dict(z=axes[0], y=axes[1], x=axes[2])
111+
self._axes = [axes[0], axes[1], axes[2]]
114112
if len(axes) > 3:
115-
self._axes['v'] = axes[3]
113+
self._axes.append(axes[3])
116114

117115
# Start midway through each axis, idx is current slice number
118-
self._ims, self._sizes, self._data_idx = dict(), dict(), dict()
116+
self._ims, self._data_idx = list(), list()
119117

120118
# set up axis crosshairs
121-
self._crosshairs = dict()
122-
r = [self._scalers[self._order['z']] / self._scalers[self._order['y']],
123-
self._scalers[self._order['z']] / self._scalers[self._order['x']],
124-
self._scalers[self._order['y']] / self._scalers[self._order['x']]]
125-
for k in 'xyz':
126-
self._sizes[k] = self._data.shape[self._order[k]]
127-
for k, xax, yax, ratio, label in zip('xyz', 'yxx', 'zzy', r,
128-
('SAIP', 'SLIR', 'ALPR')):
129-
ax = self._axes[k]
119+
self._crosshairs = [None] * 3
120+
r = [self._scalers[self._order[2]] / self._scalers[self._order[1]],
121+
self._scalers[self._order[2]] / self._scalers[self._order[0]],
122+
self._scalers[self._order[1]] / self._scalers[self._order[0]]]
123+
self._sizes = [self._data.shape[o] for o in self._order]
124+
for ii, xax, yax, ratio, label in zip([0, 1, 2], [1, 0, 0], [2, 2, 1],
125+
r, ('SAIP', 'SLIR', 'ALPR')):
126+
ax = self._axes[ii]
130127
d = np.zeros((self._sizes[yax], self._sizes[xax]))
131-
self._ims[k] = self._axes[k].imshow(d, vmin=vmin, vmax=vmax,
132-
aspect=1, cmap=cmap,
133-
interpolation='nearest',
134-
origin='lower')
128+
im = self._axes[ii].imshow(d, vmin=vmin, vmax=vmax, aspect=1,
129+
cmap=cmap, interpolation='nearest',
130+
origin='lower')
131+
self._ims.append(im)
135132
vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5],
136133
color=(0, 1, 0), linestyle='-')[0]
137134
horiz = ax.plot([-0.5, self._sizes[xax] - 0.5], [0] * 2,
138135
color=(0, 1, 0), linestyle='-')[0]
139-
self._crosshairs[k] = dict(vert=vert, horiz=horiz)
136+
self._crosshairs[ii] = dict(vert=vert, horiz=horiz)
140137
# add text labels (top, right, bottom, left)
141138
lims = [0, self._sizes[xax], 0, self._sizes[yax]]
142139
bump = 0.01
@@ -156,12 +153,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
156153
ax.set_frame_on(False)
157154
ax.axes.get_yaxis().set_visible(False)
158155
ax.axes.get_xaxis().set_visible(False)
159-
self._data_idx[k] = 0
160-
self._data_idx['v'] = -1
156+
self._data_idx.append(0)
157+
self._data_idx.append(-1) # volume
161158

162159
# Set up volumes axis
163-
if self.n_volumes > 1 and 'v' in self._axes:
164-
ax = self._axes['v']
160+
if self.n_volumes > 1 and len(self._axes) > 3:
161+
ax = self._axes[3]
165162
ax.set_axis_bgcolor('k')
166163
ax.set_title('Volumes')
167164
y = np.zeros(self.n_volumes + 1)
@@ -179,7 +176,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
179176
ax.set_ylim(yl)
180177
self._volume_ax_objs = dict(step=step, patch=patch)
181178

182-
self._figs = set([a.figure for a in self._axes.values()])
179+
self._figs = set([a.figure for a in self._axes])
183180
for fig in self._figs:
184181
fig.canvas.mpl_connect('scroll_event', self._on_scroll)
185182
fig.canvas.mpl_connect('motion_notify_event', self._on_mouse)
@@ -287,14 +284,14 @@ def set_volume_idx(self, v):
287284

288285
def _set_volume_index(self, v, update_slices=True):
289286
"""Set the plot data using a volume index"""
290-
v = self._data_idx['v'] if v is None else int(round(v))
291-
if v == self._data_idx['v']:
287+
v = self._data_idx[3] if v is None else int(round(v))
288+
if v == self._data_idx[3]:
292289
return
293290
max_ = np.prod(self._volume_dims)
294-
self._data_idx['v'] = max(min(int(round(v)), max_ - 1), 0)
291+
self._data_idx[3] = max(min(int(round(v)), max_ - 1), 0)
295292
idx = (slice(None), slice(None), slice(None))
296293
if self._data.ndim > 3:
297-
idx = idx + tuple(np.unravel_index(self._data_idx['v'],
294+
idx = idx + tuple(np.unravel_index(self._data_idx[3],
298295
self._volume_dims))
299296
self._current_vol_data = self._data[idx]
300297
# update all of our slice plots
@@ -314,108 +311,104 @@ def _set_position(self, x, y, z, notify=True):
314311
# deal with slicing appropriately
315312
self._position[:3] = [x, y, z]
316313
idxs = np.dot(self._inv_affine, self._position)[:3]
317-
for key, idx in zip('xyz', idxs):
318-
self._data_idx[key] = max(min(int(round(idx)),
319-
self._sizes[key] - 1), 0)
320-
for key in 'xyz':
314+
for ii, (size, idx) in enumerate(zip(self._sizes, idxs)):
315+
self._data_idx[ii] = max(min(int(round(idx)), size - 1), 0)
316+
for ii in range(3):
321317
# saggital: get to S/A
322318
# coronal: get to S/L
323319
# axial: get to A/L
324-
data = np.take(self._current_vol_data, self._data_idx[key],
325-
axis=self._order[key])
326-
xax = dict(x='y', y='x', z='x')[key]
327-
yax = dict(x='z', y='z', z='y')[key]
320+
data = np.take(self._current_vol_data, self._data_idx[ii],
321+
axis=self._order[ii])
322+
xax = [1, 0, 0][ii]
323+
yax = [2, 2, 1][ii]
328324
if self._order[xax] < self._order[yax]:
329325
data = data.T
330326
if self._flips[xax]:
331327
data = data[:, ::-1]
332328
if self._flips[yax]:
333329
data = data[::-1]
334-
self._ims[key].set_data(data)
330+
self._ims[ii].set_data(data)
335331
# deal with crosshairs
336-
loc = self._data_idx[key]
337-
if self._flips[key]:
338-
loc = self._sizes[key] - loc
332+
loc = self._data_idx[ii]
333+
if self._flips[ii]:
334+
loc = self._sizes[ii] - loc
339335
loc = [loc] * 2
340-
if key == 'x':
341-
self._crosshairs['z']['vert'].set_xdata(loc)
342-
self._crosshairs['y']['vert'].set_xdata(loc)
343-
elif key == 'y':
344-
self._crosshairs['z']['horiz'].set_ydata(loc)
345-
self._crosshairs['x']['vert'].set_xdata(loc)
346-
else: # key == 'z'
347-
self._crosshairs['y']['horiz'].set_ydata(loc)
348-
self._crosshairs['x']['horiz'].set_ydata(loc)
336+
if ii == 0:
337+
self._crosshairs[2]['vert'].set_xdata(loc)
338+
self._crosshairs[1]['vert'].set_xdata(loc)
339+
elif ii == 1:
340+
self._crosshairs[2]['horiz'].set_ydata(loc)
341+
self._crosshairs[0]['vert'].set_xdata(loc)
342+
else: # ii == 2
343+
self._crosshairs[1]['horiz'].set_ydata(loc)
344+
self._crosshairs[0]['horiz'].set_ydata(loc)
349345

350346
# Update volume trace
351-
if self.n_volumes > 1 and 'v' in self._axes:
352-
idx = [0] * 3
353-
for key in 'xyz':
354-
idx[self._order[key]] = self._data_idx[key]
355-
vdata = self._data[idx[0], idx[1], idx[2], :].ravel()
347+
if self.n_volumes > 1 and len(self._axes) > 3:
348+
idx = [None, Ellipsis] * 3
349+
for ii in range(3):
350+
idx[self._order[ii]] = self._data_idx[ii]
351+
vdata = self._data[idx].ravel()
356352
vdata = np.concatenate((vdata, [vdata[-1]]))
357-
self._volume_ax_objs['patch'].set_x(self._data_idx['v'] - 0.5)
353+
self._volume_ax_objs['patch'].set_x(self._data_idx[3] - 0.5)
358354
self._volume_ax_objs['step'].set_ydata(vdata)
359355
if notify:
360356
self._notify_links()
361357
self._changing = False
362358

363359
# Matplotlib handlers ####################################################
364360
def _in_axis(self, event):
365-
"""Return axis key if within one of our axes, else None"""
361+
"""Return axis index if within one of our axes, else None"""
366362
if getattr(event, 'inaxes') is None:
367363
return None
368-
for key, ax in self._axes.items():
364+
for ii, ax in enumerate(self._axes):
369365
if event.inaxes is ax:
370-
return key
366+
return ii
371367

372368
def _on_scroll(self, event):
373369
"""Handle mpl scroll wheel event"""
374370
assert event.button in ('up', 'down')
375-
key = self._in_axis(event)
376-
if key is None:
371+
ii = self._in_axis(event)
372+
if ii is None:
377373
return
378374
if event.key is not None and 'shift' in event.key:
379375
if self.n_volumes <= 1:
380376
return
381-
key = 'v' # shift: change volume in any axis
382-
assert key in ['x', 'y', 'z', 'v']
377+
ii = 3 # shift: change volume in any axis
378+
assert ii in range(4)
383379
dv = 10. if event.key is not None and 'control' in event.key else 1.
384380
dv *= 1. if event.button == 'up' else -1.
385-
dv *= -1 if self._flips.get(key, False) else 1
386-
val = self._data_idx[key] + dv
387-
if key == 'v':
381+
dv *= -1 if self._flips[ii] else 1
382+
val = self._data_idx[ii] + dv
383+
if ii == 3:
388384
self._set_volume_index(val)
389385
else:
390-
coords = {key: val}
391-
for k in 'xyz':
392-
if k not in coords:
393-
coords[k] = self._data_idx[k]
394-
coords = np.array([coords['x'], coords['y'], coords['z'], 1.])
395-
coords = np.dot(self._affine, coords)[:3]
396-
self._set_position(coords[0], coords[1], coords[2])
386+
coords = [self._data_idx[k] for k in range(3)] + [1.]
387+
coords[ii] = val
388+
self._set_position(*np.dot(self._affine, coords)[:3])
397389
self._draw()
398390

399391
def _on_mouse(self, event):
400392
"""Handle mpl mouse move and button press events"""
401393
if event.button != 1: # only enabled while dragging
402394
return
403-
key = self._in_axis(event)
404-
if key is None:
395+
ii = self._in_axis(event)
396+
if ii is None:
405397
return
406-
if key == 'v':
398+
if ii == 3:
407399
# volume plot directly translates
408400
self._set_volume_index(event.xdata)
409401
else:
410402
# translate click xdata/ydata to physical position
411-
xax, yax = dict(x='yz', y='xz', z='xy')[key]
403+
xax, yax = [[1, 2], [0, 2], [0, 1]][ii]
412404
x, y = event.xdata, event.ydata
413405
x = self._sizes[xax] - x if self._flips[xax] else x
414406
y = self._sizes[yax] - y if self._flips[yax] else y
415-
idxs = {xax: x, yax: y, key: self._data_idx[key]}
416-
idxs = np.array([idxs['x'], idxs['y'], idxs['z'], 1.])
417-
pos = np.dot(self._affine, idxs)[:3]
418-
self._set_position(*pos)
407+
idxs = [None, None, None, 1.]
408+
idxs[xax] = x
409+
idxs[yax] = y
410+
idxs[ii] = self._data_idx[ii]
411+
self._set_position(*np.dot(self._affine, idxs)[:3])
419412
self._draw()
420413

421414
def _on_keypress(self, event):
@@ -425,14 +418,14 @@ def _on_keypress(self, event):
425418

426419
def _draw(self):
427420
"""Update all four (or three) plots"""
428-
for key in 'xyz':
429-
ax, im = self._axes[key], self._ims[key]
430-
ax.draw_artist(im)
431-
for line in self._crosshairs[key].values():
421+
for ii in range(3):
422+
ax = self._axes[ii]
423+
ax.draw_artist(self._ims[ii])
424+
for line in self._crosshairs[ii].values():
432425
ax.draw_artist(line)
433426
ax.figure.canvas.blit(ax.bbox)
434-
if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3
435-
ax = self._axes['v']
427+
if self.n_volumes > 1 and len(self._axes) > 3:
428+
ax = self._axes[3]
436429
ax.draw_artist(ax.patch) # axis bgcolor to erase old lines
437430
for key in ('step', 'patch'):
438431
ax.draw_artist(self._volume_ax_objs[key])

0 commit comments

Comments
 (0)