55
66import numpy as np
77from matplotlib import cm , colors , pyplot as plt
8+ from matplotlib .axes import Axes , SubplotBase
89from matplotlib .collections import LineCollection
9- from matplotlib .colors import LinearSegmentedColormap
10+ from matplotlib .colors import Colormap , LinearSegmentedColormap
1011from matplotlib .figure import Figure
11- from matplotlib .pyplot import axis , figure
1212from mpl_toolkits .axes_grid1 import make_axes_locatable
1313from numpy import ndarray
1414
@@ -51,7 +51,8 @@ def _normalize_scale(attr: ndarray, scale_factor: float):
5151 warnings .warn (
5252 "Attempting to normalize by value approximately 0, visualized results"
5353 "may be misleading. This likely means that attribution values are all"
54- "close to 0."
54+ "close to 0." ,
55+ stacklevel = 2 ,
5556 )
5657 attr_norm = attr / scale_factor
5758 return np .clip (attr_norm , - 1 , 1 )
@@ -80,37 +81,39 @@ def _normalize_attr(
8081
8182 # Choose appropriate signed values and rescale, removing given outlier percentage.
8283 if VisualizeSign [sign ] == VisualizeSign .all :
83- threshold = _cumulative_sum_threshold (np .abs (attr_combined ), 100 - outlier_perc )
84+ threshold = _cumulative_sum_threshold (
85+ np .abs (attr_combined ), 100.0 - outlier_perc
86+ )
8487 elif VisualizeSign [sign ] == VisualizeSign .positive :
8588 attr_combined = (attr_combined > 0 ) * attr_combined
86- threshold = _cumulative_sum_threshold (attr_combined , 100 - outlier_perc )
89+ threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
8790 elif VisualizeSign [sign ] == VisualizeSign .negative :
8891 attr_combined = (attr_combined < 0 ) * attr_combined
8992 threshold = - 1 * _cumulative_sum_threshold (
90- np .abs (attr_combined ), 100 - outlier_perc
93+ np .abs (attr_combined ), 100.0 - outlier_perc
9194 )
9295 elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
9396 attr_combined = np .abs (attr_combined )
94- threshold = _cumulative_sum_threshold (attr_combined , 100 - outlier_perc )
97+ threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
9598 else :
9699 raise AssertionError ("Visualize Sign type is not valid." )
97100 return _normalize_scale (attr_combined , threshold )
98101
99102
100103def visualize_image_attr (
101104 attr : ndarray ,
102- original_image : Union [ None , ndarray ] = None ,
105+ original_image : Optional [ ndarray ] = None ,
103106 method : str = "heat_map" ,
104107 sign : str = "absolute_value" ,
105- plt_fig_axis : Union [ None , Tuple [ figure , axis ]] = None ,
108+ plt_fig_axis : Optional [ Tuple [ Figure , Union [ ndarray , SubplotBase , Axes ] ]] = None ,
106109 outlier_perc : Union [int , float ] = 2 ,
107- cmap : Union [None , str ] = None ,
110+ cmap : Optional [ Union [str , Colormap ] ] = None ,
108111 alpha_overlay : float = 0.5 ,
109112 show_colorbar : bool = False ,
110- title : Union [ None , str ] = None ,
113+ title : Optional [ str ] = None ,
111114 fig_size : Tuple [int , int ] = (6 , 6 ),
112115 use_pyplot : bool = True ,
113- ):
116+ ) -> Tuple [ Figure , Axes ] :
114117 r"""
115118 Visualizes attribution for a given image by normalizing attribution values
116119 of the desired sign (positive, negative, absolute value, or all) and displaying
@@ -264,8 +267,8 @@ def visualize_image_attr(
264267
265268 # Set default colormap and bounds based on sign.
266269 if VisualizeSign [sign ] == VisualizeSign .all :
267- default_cmap = LinearSegmentedColormap . from_list (
268- "RdWhGn" , ["red" , "white" , "green" ]
270+ default_cmap : Union [ str , LinearSegmentedColormap ] = (
271+ LinearSegmentedColormap . from_list ( "RdWhGn" , ["red" , "white" , "green" ])
269272 )
270273 vmin , vmax = - 1 , 1
271274 elif VisualizeSign [sign ] == VisualizeSign .positive :
@@ -345,7 +348,7 @@ def visualize_image_attr_multiple(
345348 original_image : Union [None , ndarray ],
346349 methods : List [str ],
347350 signs : List [str ],
348- titles : Union [ None , List [str ]] = None ,
351+ titles : Optional [ List [str ]] = None ,
349352 fig_size : Tuple [int , int ] = (8 , 6 ),
350353 use_pyplot : bool = True ,
351354 ** kwargs : Any ,
@@ -418,22 +421,25 @@ def visualize_image_attr_multiple(
418421 "If titles list is given, length must " "match that of methods list."
419422 )
420423 if use_pyplot :
421- plt_fig = plt .figure (figsize = fig_size )
424+ plt_fig , plt_axis = plt .subplots (figsize = fig_size )
422425 else :
423- plt_fig = Figure (figsize = fig_size )
424- plt_axis = plt_fig .subplots (1 , len (methods ))
426+ plt_fig : Figure = Figure (figsize = fig_size )
427+ plt_axis : Union [ Axes , List [ Axes ]] = plt_fig .subplots (1 , len (methods ))
425428
429+ plt_axis_list : List [Axes ] = []
426430 # When visualizing one
427431 if len (methods ) == 1 :
428- plt_axis = [plt_axis ]
432+ plt_axis_list = [plt_axis ]
433+ else :
434+ plt_axis_list = plt_axis
429435
430436 for i in range (len (methods )):
431437 visualize_image_attr (
432438 attr ,
433439 original_image = original_image ,
434440 method = methods [i ],
435441 sign = signs [i ],
436- plt_fig_axis = (plt_fig , plt_axis [i ]),
442+ plt_fig_axis = (plt_fig , plt_axis_list [i ]),
437443 use_pyplot = False ,
438444 title = titles [i ] if titles else None ,
439445 ** kwargs ,
@@ -452,12 +458,12 @@ def visualize_timeseries_attr(
452458 sign : str = "absolute_value" ,
453459 channel_labels : Optional [List [str ]] = None ,
454460 channels_last : bool = True ,
455- plt_fig_axis : Union [ None , Tuple [ figure , axis ]] = None ,
461+ plt_fig_axis : Optional [ Tuple [ Figure , Union [ ndarray , SubplotBase , Axes ] ]] = None ,
456462 outlier_perc : Union [int , float ] = 2 ,
457- cmap : Union [None , str ] = None ,
463+ cmap : Optional [ Union [str , Colormap ] ] = None ,
458464 alpha_overlay : float = 0.7 ,
459465 show_colorbar : bool = False ,
460- title : Union [ None , str ] = None ,
466+ title : Optional [ str ] = None ,
461467 fig_size : Tuple [int , int ] = (6 , 6 ),
462468 use_pyplot : bool = True ,
463469 ** pyplot_kwargs ,
@@ -596,7 +602,8 @@ def visualize_timeseries_attr(
596602 if num_channels > timeseries_length :
597603 warnings .warn (
598604 "Number of channels ({}) greater than time series length ({}), "
599- "please verify input format" .format (num_channels , timeseries_length )
605+ "please verify input format" .format (num_channels , timeseries_length ),
606+ stacklevel = 2 ,
600607 )
601608
602609 num_subplots = num_channels
@@ -627,14 +634,16 @@ def visualize_timeseries_attr(
627634 plt_axis = plt_fig .subplots (nrows = num_subplots , sharex = True )
628635
629636 if not isinstance (plt_axis , ndarray ):
630- plt_axis = np .array ([plt_axis ])
637+ plt_axis_list = np .array ([plt_axis ])
638+ else :
639+ plt_axis_list = plt_axis
631640
632641 norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = None )
633642
634643 # Set default colormap and bounds based on sign.
635644 if VisualizeSign [sign ] == VisualizeSign .all :
636- default_cmap = LinearSegmentedColormap . from_list (
637- "RdWhGn" , ["red" , "white" , "green" ]
645+ default_cmap : Union [ str , LinearSegmentedColormap ] = (
646+ LinearSegmentedColormap . from_list ( "RdWhGn" , ["red" , "white" , "green" ])
638647 )
639648 vmin , vmax = - 1 , 1
640649 elif VisualizeSign [sign ] == VisualizeSign .positive :
@@ -649,7 +658,7 @@ def visualize_timeseries_attr(
649658 else :
650659 raise AssertionError ("Visualize Sign type is not valid." )
651660 cmap = cmap if cmap is not None else default_cmap
652- cmap = cm .get_cmap (cmap )
661+ cmap = plt .get_cmap (cmap )
653662 cm_norm = colors .Normalize (vmin , vmax )
654663
655664 def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ):
@@ -673,11 +682,11 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
673682
674683 for chan in range (num_channels ):
675684
676- plt_axis [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
685+ plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
677686 if channel_labels is not None :
678- plt_axis [chan ].set_ylabel (channel_labels [chan ])
687+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
679688
680- _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis [chan ])
689+ _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis_list [chan ])
681690
682691 plt .subplots_adjust (hspace = 0 )
683692
@@ -687,16 +696,16 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
687696 ):
688697
689698 # Dark colors are better in this case
690- cycler = plt .cycler ("color" , cm . Dark2 .colors )
691- plt_axis [0 ].set_prop_cycle (cycler )
699+ cycler = plt .cycler ("color" , plt . get_cmap ( " Dark2" ) .colors )
700+ plt_axis_list [0 ].set_prop_cycle (cycler )
692701
693702 for chan in range (num_channels ):
694703 label = channel_labels [chan ] if channel_labels else None
695- plt_axis [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
704+ plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
696705
697- _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis [0 ])
706+ _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis_list [0 ])
698707
699- plt_axis [0 ].legend (loc = "best" )
708+ plt_axis_list [0 ].legend (loc = "best" )
700709
701710 elif (
702711 TimeseriesVisualizationMethod [method ]
@@ -710,12 +719,12 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
710719
711720 lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
712721 lc .set_array (norm_attr [chan , :])
713- plt_axis [chan ].add_collection (lc )
714- plt_axis [chan ].set_ylim (
722+ plt_axis_list [chan ].add_collection (lc )
723+ plt_axis_list [chan ].set_ylim (
715724 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
716725 )
717726 if channel_labels is not None :
718- plt_axis [chan ].set_ylabel (channel_labels [chan ])
727+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
719728
720729 plt .subplots_adjust (hspace = 0 )
721730
@@ -725,7 +734,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
725734 plt .xlim ([x_values [0 ], x_values [- 1 ]])
726735
727736 if show_colorbar :
728- axis_separator = make_axes_locatable (plt_axis [- 1 ])
737+ axis_separator = make_axes_locatable (plt_axis_list [- 1 ])
729738 colorbar_axis = axis_separator .append_axes ("bottom" , size = "5%" , pad = 0.4 )
730739 colorbar_alpha = alpha_overlay
731740 if (
@@ -740,7 +749,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
740749 alpha = colorbar_alpha ,
741750 )
742751 if title :
743- plt_axis [0 ].set_title (title )
752+ plt_axis_list [0 ].set_title (title )
744753
745754 if use_pyplot :
746755 plt .show ()
0 commit comments