33from enum import Enum
44from typing import Any , Iterable , List , Optional , Tuple , Union
55
6+ import matplotlib
7+
68import numpy as np
79from matplotlib import cm , colors , pyplot as plt
10+ from matplotlib .axes import Axes
811from matplotlib .collections import LineCollection
9- from matplotlib .colors import LinearSegmentedColormap
12+ from matplotlib .colors import Colormap , LinearSegmentedColormap
1013from matplotlib .figure import Figure
11- from matplotlib .pyplot import axis , figure
1214from mpl_toolkits .axes_grid1 import make_axes_locatable
1315from numpy import ndarray
1416
@@ -51,7 +53,8 @@ def _normalize_scale(attr: ndarray, scale_factor: float):
5153 warnings .warn (
5254 "Attempting to normalize by value approximately 0, visualized results"
5355 "may be misleading. This likely means that attribution values are all"
54- "close to 0."
56+ "close to 0." ,
57+ stacklevel = 2 ,
5558 )
5659 attr_norm = attr / scale_factor
5760 return np .clip (attr_norm , - 1 , 1 )
@@ -80,37 +83,39 @@ def _normalize_attr(
8083
8184 # Choose appropriate signed values and rescale, removing given outlier percentage.
8285 if VisualizeSign [sign ] == VisualizeSign .all :
83- threshold = _cumulative_sum_threshold (np .abs (attr_combined ), 100 - outlier_perc )
86+ threshold = _cumulative_sum_threshold (
87+ np .abs (attr_combined ), 100.0 - outlier_perc
88+ )
8489 elif VisualizeSign [sign ] == VisualizeSign .positive :
8590 attr_combined = (attr_combined > 0 ) * attr_combined
86- threshold = _cumulative_sum_threshold (attr_combined , 100 - outlier_perc )
91+ threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
8792 elif VisualizeSign [sign ] == VisualizeSign .negative :
8893 attr_combined = (attr_combined < 0 ) * attr_combined
8994 threshold = - 1 * _cumulative_sum_threshold (
90- np .abs (attr_combined ), 100 - outlier_perc
95+ np .abs (attr_combined ), 100.0 - outlier_perc
9196 )
9297 elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
9398 attr_combined = np .abs (attr_combined )
94- threshold = _cumulative_sum_threshold (attr_combined , 100 - outlier_perc )
99+ threshold = _cumulative_sum_threshold (attr_combined , 100.0 - outlier_perc )
95100 else :
96101 raise AssertionError ("Visualize Sign type is not valid." )
97102 return _normalize_scale (attr_combined , threshold )
98103
99104
100105def visualize_image_attr (
101106 attr : ndarray ,
102- original_image : Union [ None , ndarray ] = None ,
107+ original_image : Optional [ ndarray ] = None ,
103108 method : str = "heat_map" ,
104109 sign : str = "absolute_value" ,
105- plt_fig_axis : Union [ None , Tuple [figure , axis ]] = None ,
110+ plt_fig_axis : Optional [ Tuple [Figure , Axes ]] = None ,
106111 outlier_perc : Union [int , float ] = 2 ,
107- cmap : Union [None , str ] = None ,
112+ cmap : Optional [ Union [str , Colormap ] ] = None ,
108113 alpha_overlay : float = 0.5 ,
109114 show_colorbar : bool = False ,
110- title : Union [ None , str ] = None ,
115+ title : Optional [ str ] = None ,
111116 fig_size : Tuple [int , int ] = (6 , 6 ),
112117 use_pyplot : bool = True ,
113- ):
118+ ) -> Tuple [ Figure , Axes ] :
114119 r"""
115120 Visualizes attribution for a given image by normalizing attribution values
116121 of the desired sign (positive, negative, absolute value, or all) and displaying
@@ -231,7 +236,8 @@ def visualize_image_attr(
231236 plt_fig , plt_axis = plt .subplots (figsize = fig_size )
232237 else :
233238 plt_fig = Figure (figsize = fig_size )
234- plt_axis = plt_fig .subplots ()
239+ plt_axis = plt_fig .subplots () # type: ignore
240+ # Figure.subplots returns Axes or array of Axes
235241
236242 if original_image is not None :
237243 if np .max (original_image ) <= 1.0 :
@@ -264,8 +270,8 @@ def visualize_image_attr(
264270
265271 # Set default colormap and bounds based on sign.
266272 if VisualizeSign [sign ] == VisualizeSign .all :
267- default_cmap = LinearSegmentedColormap . from_list (
268- "RdWhGn" , ["red" , "white" , "green" ]
273+ default_cmap : Union [ str , LinearSegmentedColormap ] = (
274+ LinearSegmentedColormap . from_list ( "RdWhGn" , ["red" , "white" , "green" ])
269275 )
270276 vmin , vmax = - 1 , 1
271277 elif VisualizeSign [sign ] == VisualizeSign .positive :
@@ -345,7 +351,7 @@ def visualize_image_attr_multiple(
345351 original_image : Union [None , ndarray ],
346352 methods : List [str ],
347353 signs : List [str ],
348- titles : Union [ None , List [str ]] = None ,
354+ titles : Optional [ List [str ]] = None ,
349355 fig_size : Tuple [int , int ] = (8 , 6 ),
350356 use_pyplot : bool = True ,
351357 ** kwargs : Any ,
@@ -423,17 +429,22 @@ def visualize_image_attr_multiple(
423429 plt_fig = Figure (figsize = fig_size )
424430 plt_axis = plt_fig .subplots (1 , len (methods ))
425431
432+ plt_axis_list : List [Axes ] = []
426433 # When visualizing one
427434 if len (methods ) == 1 :
428- plt_axis = [plt_axis ]
435+ plt_axis_list = [plt_axis ] # type: ignore
436+ # Figure.subplots returns Axes or array of Axes
437+ else :
438+ plt_axis_list = plt_axis # type: ignore
439+ # Figure.subplots returns Axes or array of Axes
429440
430441 for i in range (len (methods )):
431442 visualize_image_attr (
432443 attr ,
433444 original_image = original_image ,
434445 method = methods [i ],
435446 sign = signs [i ],
436- plt_fig_axis = (plt_fig , plt_axis [i ]),
447+ plt_fig_axis = (plt_fig , plt_axis_list [i ]),
437448 use_pyplot = False ,
438449 title = titles [i ] if titles else None ,
439450 ** kwargs ,
@@ -452,12 +463,12 @@ def visualize_timeseries_attr(
452463 sign : str = "absolute_value" ,
453464 channel_labels : Optional [List [str ]] = None ,
454465 channels_last : bool = True ,
455- plt_fig_axis : Union [ None , Tuple [ figure , axis ]] = None ,
466+ plt_fig_axis : Optional [ Tuple [ Figure , Union [ Axes , List [ Axes ]] ]] = None ,
456467 outlier_perc : Union [int , float ] = 2 ,
457- cmap : Union [None , str ] = None ,
468+ cmap : Optional [ Union [str , Colormap ] ] = None ,
458469 alpha_overlay : float = 0.7 ,
459470 show_colorbar : bool = False ,
460- title : Union [ None , str ] = None ,
471+ title : Optional [ str ] = None ,
461472 fig_size : Tuple [int , int ] = (6 , 6 ),
462473 use_pyplot : bool = True ,
463474 ** pyplot_kwargs ,
@@ -596,7 +607,8 @@ def visualize_timeseries_attr(
596607 if num_channels > timeseries_length :
597608 warnings .warn (
598609 "Number of channels ({}) greater than time series length ({}), "
599- "please verify input format" .format (num_channels , timeseries_length )
610+ "please verify input format" .format (num_channels , timeseries_length ),
611+ stacklevel = 2 ,
600612 )
601613
602614 num_subplots = num_channels
@@ -624,17 +636,20 @@ def visualize_timeseries_attr(
624636 )
625637 else :
626638 plt_fig = Figure (figsize = fig_size )
627- plt_axis = plt_fig .subplots (nrows = num_subplots , sharex = True )
639+ plt_axis = plt_fig .subplots (nrows = num_subplots , sharex = True ) # type: ignore
640+ # Figure.subplots returns Axes or array of Axes
628641
629642 if not isinstance (plt_axis , ndarray ):
630- plt_axis = np .array ([plt_axis ])
643+ plt_axis_list = np .array ([plt_axis ])
644+ else :
645+ plt_axis_list = plt_axis
631646
632647 norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = None )
633648
634649 # Set default colormap and bounds based on sign.
635650 if VisualizeSign [sign ] == VisualizeSign .all :
636- default_cmap = LinearSegmentedColormap . from_list (
637- "RdWhGn" , ["red" , "white" , "green" ]
651+ default_cmap : Union [ str , LinearSegmentedColormap ] = (
652+ LinearSegmentedColormap . from_list ( "RdWhGn" , ["red" , "white" , "green" ])
638653 )
639654 vmin , vmax = - 1 , 1
640655 elif VisualizeSign [sign ] == VisualizeSign .positive :
@@ -653,7 +668,6 @@ def visualize_timeseries_attr(
653668 cm_norm = colors .Normalize (vmin , vmax )
654669
655670 def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ):
656-
657671 half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
658672 for icol , col_center in enumerate (x_vals ):
659673 left = col_center - half_col_width
@@ -670,52 +684,47 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
670684 TimeseriesVisualizationMethod [method ]
671685 == TimeseriesVisualizationMethod .overlay_individual
672686 ):
673-
674687 for chan in range (num_channels ):
675-
676- plt_axis [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
688+ plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
677689 if channel_labels is not None :
678- plt_axis [chan ].set_ylabel (channel_labels [chan ])
690+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
679691
680- _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis [chan ])
692+ _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis_list [chan ])
681693
682694 plt .subplots_adjust (hspace = 0 )
683695
684696 elif (
685697 TimeseriesVisualizationMethod [method ]
686698 == TimeseriesVisualizationMethod .overlay_combined
687699 ):
688-
689700 # Dark colors are better in this case
690- cycler = plt .cycler ("color" , cm . Dark2 . colors )
691- plt_axis [0 ].set_prop_cycle (cycler )
701+ cycler = plt .cycler ("color" , matplotlib . colormaps [ " Dark2" ]) # type: ignore
702+ plt_axis_list [0 ].set_prop_cycle (cycler )
692703
693704 for chan in range (num_channels ):
694705 label = channel_labels [chan ] if channel_labels else None
695- plt_axis [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
706+ plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
696707
697- _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis [0 ])
708+ _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis_list [0 ])
698709
699- plt_axis [0 ].legend (loc = "best" )
710+ plt_axis_list [0 ].legend (loc = "best" )
700711
701712 elif (
702713 TimeseriesVisualizationMethod [method ]
703714 == TimeseriesVisualizationMethod .colored_graph
704715 ):
705-
706716 for chan in range (num_channels ):
707-
708717 points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
709718 segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
710719
711720 lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
712721 lc .set_array (norm_attr [chan , :])
713- plt_axis [chan ].add_collection (lc )
714- plt_axis [chan ].set_ylim (
722+ plt_axis_list [chan ].add_collection (lc )
723+ plt_axis_list [chan ].set_ylim (
715724 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
716725 )
717726 if channel_labels is not None :
718- plt_axis [chan ].set_ylabel (channel_labels [chan ])
727+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
719728
720729 plt .subplots_adjust (hspace = 0 )
721730
@@ -725,7 +734,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
725734 plt .xlim ([x_values [0 ], x_values [- 1 ]])
726735
727736 if show_colorbar :
728- axis_separator = make_axes_locatable (plt_axis [- 1 ])
737+ axis_separator = make_axes_locatable (plt_axis_list [- 1 ])
729738 colorbar_axis = axis_separator .append_axes ("bottom" , size = "5%" , pad = 0.4 )
730739 colorbar_alpha = alpha_overlay
731740 if (
@@ -740,7 +749,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
740749 alpha = colorbar_alpha ,
741750 )
742751 if title :
743- plt_axis [0 ].set_title (title )
752+ plt_axis_list [0 ].set_title (title )
744753
745754 if use_pyplot :
746755 plt .show ()
0 commit comments