33# pyre-strict
44import warnings
55from enum import Enum
6- from typing import Any , Iterable , List , Optional , Tuple , Union
6+ from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple , Union
77
88import matplotlib
99
1313from matplotlib .collections import LineCollection
1414from matplotlib .colors import Colormap , LinearSegmentedColormap
1515from matplotlib .figure import Figure
16+ from matplotlib .image import AxesImage
1617from mpl_toolkits .axes_grid1 import make_axes_locatable
1718from numpy import ndarray
1819
@@ -88,26 +89,129 @@ def _normalize_attr(
8889 attr_combined = np .sum (attr , axis = reduction_axis )
8990
9091 # Choose appropriate signed values and rescale, removing given outlier percentage.
91- if VisualizeSign [sign ] == VisualizeSign .all :
92+ if VisualizeSign [sign ]. value == VisualizeSign .all . value :
9293 threshold = _cumulative_sum_threshold (
9394 np .abs (attr_combined ), 100.0 - outlier_perc
9495 )
95- elif VisualizeSign [sign ] == VisualizeSign .positive :
96+ elif VisualizeSign [sign ]. value == VisualizeSign .positive . value :
9697 attr_combined = (attr_combined > 0 ) * attr_combined
9798 threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
98- elif VisualizeSign [sign ] == VisualizeSign .negative :
99+ elif VisualizeSign [sign ]. value == VisualizeSign .negative . value :
99100 attr_combined = (attr_combined < 0 ) * attr_combined
100101 threshold = - 1 * _cumulative_sum_threshold (
101102 np .abs (attr_combined ), 100.0 - outlier_perc
102103 )
103- elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
104+ elif VisualizeSign [sign ]. value == VisualizeSign .absolute_value . value :
104105 attr_combined = np .abs (attr_combined )
105106 threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
106107 else :
107108 raise AssertionError ("Visualize Sign type is not valid." )
108109 return _normalize_scale (attr_combined , threshold )
109110
110111
112+ def _initialize_cmap_and_vmin_vmax (
113+ sign : str ,
114+ ) -> Tuple [Union [str , Colormap ], float , float ]:
115+ if VisualizeSign [sign ].value == VisualizeSign .all .value :
116+ default_cmap : Union [str , LinearSegmentedColormap ] = (
117+ LinearSegmentedColormap .from_list ("RdWhGn" , ["red" , "white" , "green" ])
118+ )
119+ vmin , vmax = - 1 , 1
120+ elif VisualizeSign [sign ].value == VisualizeSign .positive .value :
121+ default_cmap = "Greens"
122+ vmin , vmax = 0 , 1
123+ elif VisualizeSign [sign ].value == VisualizeSign .negative .value :
124+ default_cmap = "Reds"
125+ vmin , vmax = 0 , 1
126+ elif VisualizeSign [sign ].value == VisualizeSign .absolute_value .value :
127+ default_cmap = "Blues"
128+ vmin , vmax = 0 , 1
129+ else :
130+ raise AssertionError ("Visualize Sign type is not valid." )
131+ return default_cmap , vmin , vmax
132+
133+
134+ def _visualize_original_image (
135+ plt_axis : Axes ,
136+ original_image : Optional [ndarray ],
137+ ** kwargs : Any ,
138+ ) -> None :
139+ assert (
140+ original_image is not None
141+ ), "Original image expected for original_image method."
142+ if len (original_image .shape ) > 2 and original_image .shape [2 ] == 1 :
143+ original_image = np .squeeze (original_image , axis = 2 )
144+ plt_axis .imshow (original_image )
145+
146+
147+ def _visualize_heat_map (
148+ plt_axis : Axes ,
149+ norm_attr : ndarray ,
150+ cmap : Union [str , Colormap ],
151+ vmin : float ,
152+ vmax : float ,
153+ ** kwargs : Any ,
154+ ) -> AxesImage :
155+ heat_map = plt_axis .imshow (norm_attr , cmap = cmap , vmin = vmin , vmax = vmax )
156+ return heat_map
157+
158+
159+ def _visualize_blended_heat_map (
160+ plt_axis : Axes ,
161+ original_image : ndarray ,
162+ norm_attr : ndarray ,
163+ cmap : Union [str , Colormap ],
164+ vmin : float ,
165+ vmax : float ,
166+ alpha_overlay : float ,
167+ ** kwargs : Any ,
168+ ) -> AxesImage :
169+ assert (
170+ original_image is not None
171+ ), "Original Image expected for blended_heat_map method."
172+ plt_axis .imshow (np .mean (original_image , axis = 2 ), cmap = "gray" )
173+ heat_map = plt_axis .imshow (
174+ norm_attr , cmap = cmap , vmin = vmin , vmax = vmax , alpha = alpha_overlay
175+ )
176+ return heat_map
177+
178+
179+ def _visualize_masked_image (
180+ plt_axis : Axes ,
181+ sign : str ,
182+ original_image : ndarray ,
183+ norm_attr : ndarray ,
184+ ** kwargs : Any ,
185+ ) -> None :
186+ assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
187+ "Cannot display masked image with both positive and negative "
188+ "attributions, choose a different sign option."
189+ )
190+ plt_axis .imshow (_prepare_image (original_image * np .expand_dims (norm_attr , 2 )))
191+
192+
193+ def _visualize_alpha_scaling (
194+ plt_axis : Axes ,
195+ sign : str ,
196+ original_image : ndarray ,
197+ norm_attr : ndarray ,
198+ ** kwargs : Any ,
199+ ) -> None :
200+ assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
201+ "Cannot display alpha scaling with both positive and negative "
202+ "attributions, choose a different sign option."
203+ )
204+ plt_axis .imshow (
205+ np .concatenate (
206+ [
207+ original_image ,
208+ _prepare_image (np .expand_dims (norm_attr , 2 ) * 255 ),
209+ ],
210+ axis = 2 ,
211+ )
212+ )
213+
214+
111215def visualize_image_attr (
112216 attr : ndarray ,
113217 original_image : Optional [ndarray ] = None ,
@@ -242,15 +346,18 @@ def visualize_image_attr(
242346 plt_fig , plt_axis = plt .subplots (figsize = fig_size )
243347 else :
244348 plt_fig = Figure (figsize = fig_size )
245- plt_axis = plt_fig .subplots () # type: ignore
349+ plt_axis = plt_fig .subplots ()
246350 # Figure.subplots returns Axes or array of Axes
247351
248352 if original_image is not None :
249353 if np .max (original_image ) <= 1.0 :
250354 original_image = _prepare_image (original_image * 255 )
251- elif ImageVisualizationMethod [method ] != ImageVisualizationMethod .heat_map :
355+ elif (
356+ ImageVisualizationMethod [method ].value
357+ != ImageVisualizationMethod .heat_map .value
358+ ):
252359 raise ValueError (
253- "Original Image must be provided for"
360+ "Original Image must be provided for "
254361 "any visualization other than heatmap."
255362 )
256363
@@ -261,76 +368,37 @@ def visualize_image_attr(
261368 plt_axis .set_xticklabels ([])
262369 plt_axis .grid (visible = False )
263370
264- heat_map = None
265- # Show original image
266- if ImageVisualizationMethod [method ] == ImageVisualizationMethod .original_image :
267- assert (
268- original_image is not None
269- ), "Original image expected for original_image method."
270- if len (original_image .shape ) > 2 and original_image .shape [2 ] == 1 :
271- original_image = np .squeeze (original_image , axis = 2 )
272- plt_axis .imshow (original_image )
273- else :
274- # Choose appropriate signed attributions and normalize.
275- norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = 2 )
371+ heat_map : Optional [AxesImage ] = None
276372
277- # Set default colormap and bounds based on sign.
278- if VisualizeSign [sign ] == VisualizeSign .all :
279- default_cmap : Union [str , LinearSegmentedColormap ] = (
280- LinearSegmentedColormap .from_list ("RdWhGn" , ["red" , "white" , "green" ])
281- )
282- vmin , vmax = - 1 , 1
283- elif VisualizeSign [sign ] == VisualizeSign .positive :
284- default_cmap = "Greens"
285- vmin , vmax = 0 , 1
286- elif VisualizeSign [sign ] == VisualizeSign .negative :
287- default_cmap = "Reds"
288- vmin , vmax = 0 , 1
289- elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
290- default_cmap = "Blues"
291- vmin , vmax = 0 , 1
292- else :
293- raise AssertionError ("Visualize Sign type is not valid." )
294- cmap = cmap if cmap is not None else default_cmap
295-
296- # Show appropriate image visualization.
297- if ImageVisualizationMethod [method ] == ImageVisualizationMethod .heat_map :
298- heat_map = plt_axis .imshow (norm_attr , cmap = cmap , vmin = vmin , vmax = vmax )
299- elif (
300- ImageVisualizationMethod [method ]
301- == ImageVisualizationMethod .blended_heat_map
302- ):
303- assert (
304- original_image is not None
305- ), "Original Image expected for blended_heat_map method."
306- plt_axis .imshow (np .mean (original_image , axis = 2 ), cmap = "gray" )
307- heat_map = plt_axis .imshow (
308- norm_attr , cmap = cmap , vmin = vmin , vmax = vmax , alpha = alpha_overlay
309- )
310- elif ImageVisualizationMethod [method ] == ImageVisualizationMethod .masked_image :
311- assert VisualizeSign [sign ] != VisualizeSign .all , (
312- "Cannot display masked image with both positive and negative "
313- "attributions, choose a different sign option."
314- )
315- plt_axis .imshow (
316- _prepare_image (original_image * np .expand_dims (norm_attr , 2 ))
317- )
318- elif ImageVisualizationMethod [method ] == ImageVisualizationMethod .alpha_scaling :
319- assert VisualizeSign [sign ] != VisualizeSign .all , (
320- "Cannot display alpha scaling with both positive and negative "
321- "attributions, choose a different sign option."
322- )
323- plt_axis .imshow (
324- np .concatenate (
325- [
326- original_image ,
327- _prepare_image (np .expand_dims (norm_attr , 2 ) * 255 ),
328- ],
329- axis = 2 ,
330- )
331- )
332- else :
333- raise AssertionError ("Visualize Method type is not valid." )
373+ # pyre-ignore[33]: prohibited Any
374+ visualization_methods : Dict [str , Callable [..., Any ]] = {
375+ "heat_map" : _visualize_heat_map ,
376+ "blended_heat_map" : _visualize_blended_heat_map ,
377+ "masked_image" : _visualize_masked_image ,
378+ "alpha_scaling" : _visualize_alpha_scaling ,
379+ "original_image" : _visualize_original_image ,
380+ }
381+ # Choose appropriate signed attributions and normalize.
382+ norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = 2 )
383+
384+ # Set default colormap and bounds based on sign.
385+ default_cmap , vmin , vmax = _initialize_cmap_and_vmin_vmax (sign )
386+ cmap = cmap if cmap is not None else default_cmap
387+
388+ kwargs = {
389+ "plt_axis" : plt_axis ,
390+ "original_image" : original_image ,
391+ "sign" : sign ,
392+ "cmap" : cmap ,
393+ "alpha_overlay" : alpha_overlay ,
394+ "vmin" : vmin ,
395+ "vmax" : vmax ,
396+ "norm_attr" : norm_attr ,
397+ }
398+ if method in visualization_methods :
399+ heat_map = visualization_methods [method ](** kwargs )
400+ else :
401+ raise AssertionError ("Visualize Method type is not valid." )
334402
335403 # Add colorbar. If given method is not a heatmap and no colormap is relevant,
336404 # then a colormap axis is created and hidden. This is necessary for appropriate
0 commit comments