@@ -70,10 +70,9 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
70
70
# determine our orientation
71
71
self ._affine = affine .copy ()
72
72
codes = axcodes2ornt (aff2axcodes (self ._affine ))
73
- order = np .argsort ([c [0 ] for c in codes ])
74
- flips = np .array ([c [1 ] < 0 for c in codes ])[order ]
75
- self ._order = dict (x = int (order [0 ]), y = int (order [1 ]), z = int (order [2 ]))
76
- self ._flips = dict (x = flips [0 ], y = flips [1 ], z = flips [2 ])
73
+ self ._order = np .argsort ([c [0 ] for c in codes ])
74
+ self ._flips = np .array ([c [1 ] < 0 for c in codes ])[self ._order ]
75
+ self ._flips = list (self ._flips ) + [False ] # add volume dim
77
76
self ._scalers = np .abs (self ._affine ).max (axis = 0 )[:3 ]
78
77
self ._inv_affine = np .linalg .inv (affine )
79
78
# current volume info
@@ -87,56 +86,54 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
87
86
# ^ +---------+ ^ +---------+
88
87
# | | | | | |
89
88
# | Sag | | Cor |
90
- # S | 1 | S | 2 |
89
+ # S | 0 | S | 1 |
91
90
# | | | |
92
91
# | | | |
93
92
# +---------+ +---------+
94
93
# A --> <-- R
95
94
# ^ +---------+ +---------+
96
95
# | | | | |
97
96
# | Axial | | Vol |
98
- # A | 3 | | 4 |
97
+ # A | 2 | | 3 |
99
98
# | | | |
100
99
# | | | |
101
100
# +---------+ +---------+
102
101
# <-- R <-- t -->
103
102
104
103
fig , axes = plt .subplots (2 , 2 )
105
104
fig .set_size_inches (figsize , forward = True )
106
- self ._axes = dict (x = axes [0 , 0 ], y = axes [0 , 1 ], z = axes [1 , 0 ],
107
- v = axes [1 , 1 ])
105
+ self ._axes = [axes [0 , 0 ], axes [0 , 1 ], axes [1 , 0 ], axes [1 , 1 ]]
108
106
plt .tight_layout (pad = 0.1 )
109
107
if self .n_volumes <= 1 :
110
- fig .delaxes (self ._axes ['v' ])
111
- del self ._axes [ 'v' ]
108
+ fig .delaxes (self ._axes [3 ])
109
+ self ._axes . pop ( - 1 )
112
110
else :
113
- self ._axes = dict ( z = axes [0 ], y = axes [1 ], x = axes [2 ])
111
+ self ._axes = [ axes [0 ], axes [1 ], axes [2 ]]
114
112
if len (axes ) > 3 :
115
- self ._axes [ 'v' ] = axes [3 ]
113
+ self ._axes . append ( axes [3 ])
116
114
117
115
# Start midway through each axis, idx is current slice number
118
- self ._ims , self ._sizes , self . _data_idx = dict (), dict (), dict ()
116
+ self ._ims , self ._data_idx = list (), list ()
119
117
120
118
# set up axis crosshairs
121
- self ._crosshairs = dict ()
122
- r = [self ._scalers [self ._order ['z' ]] / self ._scalers [self ._order ['y' ]],
123
- self ._scalers [self ._order ['z' ]] / self ._scalers [self ._order ['x' ]],
124
- self ._scalers [self ._order ['y' ]] / self ._scalers [self ._order ['x' ]]]
125
- for k in 'xyz' :
126
- self ._sizes [k ] = self ._data .shape [self ._order [k ]]
127
- for k , xax , yax , ratio , label in zip ('xyz' , 'yxx' , 'zzy' , r ,
128
- ('SAIP' , 'SLIR' , 'ALPR' )):
129
- ax = self ._axes [k ]
119
+ self ._crosshairs = [None ] * 3
120
+ r = [self ._scalers [self ._order [2 ]] / self ._scalers [self ._order [1 ]],
121
+ self ._scalers [self ._order [2 ]] / self ._scalers [self ._order [0 ]],
122
+ self ._scalers [self ._order [1 ]] / self ._scalers [self ._order [0 ]]]
123
+ self ._sizes = [self ._data .shape [o ] for o in self ._order ]
124
+ for ii , xax , yax , ratio , label in zip ([0 , 1 , 2 ], [1 , 0 , 0 ], [2 , 2 , 1 ],
125
+ r , ('SAIP' , 'SLIR' , 'ALPR' )):
126
+ ax = self ._axes [ii ]
130
127
d = np .zeros ((self ._sizes [yax ], self ._sizes [xax ]))
131
- self . _ims [ k ] = self ._axes [k ].imshow (d , vmin = vmin , vmax = vmax ,
132
- aspect = 1 , cmap = cmap ,
133
- interpolation = 'nearest' ,
134
- origin = 'lower' )
128
+ im = self ._axes [ii ].imshow (d , vmin = vmin , vmax = vmax , aspect = 1 ,
129
+ cmap = cmap , interpolation = 'nearest' ,
130
+ origin = 'lower' )
131
+ self . _ims . append ( im )
135
132
vert = ax .plot ([0 ] * 2 , [- 0.5 , self ._sizes [yax ] - 0.5 ],
136
133
color = (0 , 1 , 0 ), linestyle = '-' )[0 ]
137
134
horiz = ax .plot ([- 0.5 , self ._sizes [xax ] - 0.5 ], [0 ] * 2 ,
138
135
color = (0 , 1 , 0 ), linestyle = '-' )[0 ]
139
- self ._crosshairs [k ] = dict (vert = vert , horiz = horiz )
136
+ self ._crosshairs [ii ] = dict (vert = vert , horiz = horiz )
140
137
# add text labels (top, right, bottom, left)
141
138
lims = [0 , self ._sizes [xax ], 0 , self ._sizes [yax ]]
142
139
bump = 0.01
@@ -156,12 +153,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
156
153
ax .set_frame_on (False )
157
154
ax .axes .get_yaxis ().set_visible (False )
158
155
ax .axes .get_xaxis ().set_visible (False )
159
- self ._data_idx [ k ] = 0
160
- self ._data_idx [ 'v' ] = - 1
156
+ self ._data_idx . append ( 0 )
157
+ self ._data_idx . append ( - 1 ) # volume
161
158
162
159
# Set up volumes axis
163
- if self .n_volumes > 1 and 'v' in self ._axes :
164
- ax = self ._axes ['v' ]
160
+ if self .n_volumes > 1 and len ( self ._axes ) > 3 :
161
+ ax = self ._axes [3 ]
165
162
ax .set_axis_bgcolor ('k' )
166
163
ax .set_title ('Volumes' )
167
164
y = np .zeros (self .n_volumes + 1 )
@@ -179,7 +176,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
179
176
ax .set_ylim (yl )
180
177
self ._volume_ax_objs = dict (step = step , patch = patch )
181
178
182
- self ._figs = set ([a .figure for a in self ._axes . values () ])
179
+ self ._figs = set ([a .figure for a in self ._axes ])
183
180
for fig in self ._figs :
184
181
fig .canvas .mpl_connect ('scroll_event' , self ._on_scroll )
185
182
fig .canvas .mpl_connect ('motion_notify_event' , self ._on_mouse )
@@ -287,14 +284,14 @@ def set_volume_idx(self, v):
287
284
288
285
def _set_volume_index (self , v , update_slices = True ):
289
286
"""Set the plot data using a volume index"""
290
- v = self ._data_idx ['v' ] if v is None else int (round (v ))
291
- if v == self ._data_idx ['v' ]:
287
+ v = self ._data_idx [3 ] if v is None else int (round (v ))
288
+ if v == self ._data_idx [3 ]:
292
289
return
293
290
max_ = np .prod (self ._volume_dims )
294
- self ._data_idx ['v' ] = max (min (int (round (v )), max_ - 1 ), 0 )
291
+ self ._data_idx [3 ] = max (min (int (round (v )), max_ - 1 ), 0 )
295
292
idx = (slice (None ), slice (None ), slice (None ))
296
293
if self ._data .ndim > 3 :
297
- idx = idx + tuple (np .unravel_index (self ._data_idx ['v' ],
294
+ idx = idx + tuple (np .unravel_index (self ._data_idx [3 ],
298
295
self ._volume_dims ))
299
296
self ._current_vol_data = self ._data [idx ]
300
297
# update all of our slice plots
@@ -314,108 +311,104 @@ def _set_position(self, x, y, z, notify=True):
314
311
# deal with slicing appropriately
315
312
self ._position [:3 ] = [x , y , z ]
316
313
idxs = np .dot (self ._inv_affine , self ._position )[:3 ]
317
- for key , idx in zip ('xyz' , idxs ):
318
- self ._data_idx [key ] = max (min (int (round (idx )),
319
- self ._sizes [key ] - 1 ), 0 )
320
- for key in 'xyz' :
314
+ for ii , (size , idx ) in enumerate (zip (self ._sizes , idxs )):
315
+ self ._data_idx [ii ] = max (min (int (round (idx )), size - 1 ), 0 )
316
+ for ii in range (3 ):
321
317
# saggital: get to S/A
322
318
# coronal: get to S/L
323
319
# axial: get to A/L
324
- data = np .take (self ._current_vol_data , self ._data_idx [key ],
325
- axis = self ._order [key ])
326
- xax = dict ( x = 'y' , y = 'x' , z = 'x' )[ key ]
327
- yax = dict ( x = 'z' , y = 'z' , z = 'y' )[ key ]
320
+ data = np .take (self ._current_vol_data , self ._data_idx [ii ],
321
+ axis = self ._order [ii ])
322
+ xax = [ 1 , 0 , 0 ][ ii ]
323
+ yax = [ 2 , 2 , 1 ][ ii ]
328
324
if self ._order [xax ] < self ._order [yax ]:
329
325
data = data .T
330
326
if self ._flips [xax ]:
331
327
data = data [:, ::- 1 ]
332
328
if self ._flips [yax ]:
333
329
data = data [::- 1 ]
334
- self ._ims [key ].set_data (data )
330
+ self ._ims [ii ].set_data (data )
335
331
# deal with crosshairs
336
- loc = self ._data_idx [key ]
337
- if self ._flips [key ]:
338
- loc = self ._sizes [key ] - loc
332
+ loc = self ._data_idx [ii ]
333
+ if self ._flips [ii ]:
334
+ loc = self ._sizes [ii ] - loc
339
335
loc = [loc ] * 2
340
- if key == 'x' :
341
- self ._crosshairs ['z' ]['vert' ].set_xdata (loc )
342
- self ._crosshairs ['y' ]['vert' ].set_xdata (loc )
343
- elif key == 'y' :
344
- self ._crosshairs ['z' ]['horiz' ].set_ydata (loc )
345
- self ._crosshairs ['x' ]['vert' ].set_xdata (loc )
346
- else : # key == 'z'
347
- self ._crosshairs ['y' ]['horiz' ].set_ydata (loc )
348
- self ._crosshairs ['x' ]['horiz' ].set_ydata (loc )
336
+ if ii == 0 :
337
+ self ._crosshairs [2 ]['vert' ].set_xdata (loc )
338
+ self ._crosshairs [1 ]['vert' ].set_xdata (loc )
339
+ elif ii == 1 :
340
+ self ._crosshairs [2 ]['horiz' ].set_ydata (loc )
341
+ self ._crosshairs [0 ]['vert' ].set_xdata (loc )
342
+ else : # ii == 2
343
+ self ._crosshairs [1 ]['horiz' ].set_ydata (loc )
344
+ self ._crosshairs [0 ]['horiz' ].set_ydata (loc )
349
345
350
346
# Update volume trace
351
- if self .n_volumes > 1 and 'v' in self ._axes :
352
- idx = [0 ] * 3
353
- for key in 'xyz' :
354
- idx [self ._order [key ]] = self ._data_idx [key ]
355
- vdata = self ._data [idx [ 0 ], idx [ 1 ], idx [ 2 ], : ].ravel ()
347
+ if self .n_volumes > 1 and len ( self ._axes ) > 3 :
348
+ idx = [None , Ellipsis ] * 3
349
+ for ii in range ( 3 ) :
350
+ idx [self ._order [ii ]] = self ._data_idx [ii ]
351
+ vdata = self ._data [idx ].ravel ()
356
352
vdata = np .concatenate ((vdata , [vdata [- 1 ]]))
357
- self ._volume_ax_objs ['patch' ].set_x (self ._data_idx ['v' ] - 0.5 )
353
+ self ._volume_ax_objs ['patch' ].set_x (self ._data_idx [3 ] - 0.5 )
358
354
self ._volume_ax_objs ['step' ].set_ydata (vdata )
359
355
if notify :
360
356
self ._notify_links ()
361
357
self ._changing = False
362
358
363
359
# Matplotlib handlers ####################################################
364
360
def _in_axis (self , event ):
365
- """Return axis key if within one of our axes, else None"""
361
+ """Return axis index if within one of our axes, else None"""
366
362
if getattr (event , 'inaxes' ) is None :
367
363
return None
368
- for key , ax in self ._axes . items ( ):
364
+ for ii , ax in enumerate ( self ._axes ):
369
365
if event .inaxes is ax :
370
- return key
366
+ return ii
371
367
372
368
def _on_scroll (self , event ):
373
369
"""Handle mpl scroll wheel event"""
374
370
assert event .button in ('up' , 'down' )
375
- key = self ._in_axis (event )
376
- if key is None :
371
+ ii = self ._in_axis (event )
372
+ if ii is None :
377
373
return
378
374
if event .key is not None and 'shift' in event .key :
379
375
if self .n_volumes <= 1 :
380
376
return
381
- key = 'v' # shift: change volume in any axis
382
- assert key in [ 'x' , 'y' , 'z' , 'v' ]
377
+ ii = 3 # shift: change volume in any axis
378
+ assert ii in range ( 4 )
383
379
dv = 10. if event .key is not None and 'control' in event .key else 1.
384
380
dv *= 1. if event .button == 'up' else - 1.
385
- dv *= - 1 if self ._flips . get ( key , False ) else 1
386
- val = self ._data_idx [key ] + dv
387
- if key == 'v' :
381
+ dv *= - 1 if self ._flips [ ii ] else 1
382
+ val = self ._data_idx [ii ] + dv
383
+ if ii == 3 :
388
384
self ._set_volume_index (val )
389
385
else :
390
- coords = {key : val }
391
- for k in 'xyz' :
392
- if k not in coords :
393
- coords [k ] = self ._data_idx [k ]
394
- coords = np .array ([coords ['x' ], coords ['y' ], coords ['z' ], 1. ])
395
- coords = np .dot (self ._affine , coords )[:3 ]
396
- self ._set_position (coords [0 ], coords [1 ], coords [2 ])
386
+ coords = [self ._data_idx [k ] for k in range (3 )] + [1. ]
387
+ coords [ii ] = val
388
+ self ._set_position (* np .dot (self ._affine , coords )[:3 ])
397
389
self ._draw ()
398
390
399
391
def _on_mouse (self , event ):
400
392
"""Handle mpl mouse move and button press events"""
401
393
if event .button != 1 : # only enabled while dragging
402
394
return
403
- key = self ._in_axis (event )
404
- if key is None :
395
+ ii = self ._in_axis (event )
396
+ if ii is None :
405
397
return
406
- if key == 'v' :
398
+ if ii == 3 :
407
399
# volume plot directly translates
408
400
self ._set_volume_index (event .xdata )
409
401
else :
410
402
# translate click xdata/ydata to physical position
411
- xax , yax = dict ( x = 'yz' , y = 'xz' , z = 'xy' )[ key ]
403
+ xax , yax = [[ 1 , 2 ], [ 0 , 2 ], [ 0 , 1 ]][ ii ]
412
404
x , y = event .xdata , event .ydata
413
405
x = self ._sizes [xax ] - x if self ._flips [xax ] else x
414
406
y = self ._sizes [yax ] - y if self ._flips [yax ] else y
415
- idxs = {xax : x , yax : y , key : self ._data_idx [key ]}
416
- idxs = np .array ([idxs ['x' ], idxs ['y' ], idxs ['z' ], 1. ])
417
- pos = np .dot (self ._affine , idxs )[:3 ]
418
- self ._set_position (* pos )
407
+ idxs = [None , None , None , 1. ]
408
+ idxs [xax ] = x
409
+ idxs [yax ] = y
410
+ idxs [ii ] = self ._data_idx [ii ]
411
+ self ._set_position (* np .dot (self ._affine , idxs )[:3 ])
419
412
self ._draw ()
420
413
421
414
def _on_keypress (self , event ):
@@ -425,14 +418,14 @@ def _on_keypress(self, event):
425
418
426
419
def _draw (self ):
427
420
"""Update all four (or three) plots"""
428
- for key in 'xyz' :
429
- ax , im = self ._axes [key ], self . _ims [ key ]
430
- ax .draw_artist (im )
431
- for line in self ._crosshairs [key ].values ():
421
+ for ii in range ( 3 ) :
422
+ ax = self ._axes [ii ]
423
+ ax .draw_artist (self . _ims [ ii ] )
424
+ for line in self ._crosshairs [ii ].values ():
432
425
ax .draw_artist (line )
433
426
ax .figure .canvas .blit (ax .bbox )
434
- if self .n_volumes > 1 and 'v' in self ._axes : # user might only pass 3
435
- ax = self ._axes ['v' ]
427
+ if self .n_volumes > 1 and len ( self ._axes ) > 3 :
428
+ ax = self ._axes [3 ]
436
429
ax .draw_artist (ax .patch ) # axis bgcolor to erase old lines
437
430
for key in ('step' , 'patch' ):
438
431
ax .draw_artist (self ._volume_ax_objs [key ])
0 commit comments