1616from matplotlib .image import AxesImage
1717from mpl_toolkits .axes_grid1 import make_axes_locatable
1818from numpy import ndarray
19+ from torch import Tensor
1920
2021try :
2122 from IPython .display import display , HTML
@@ -46,13 +47,11 @@ class VisualizeSign(Enum):
4647 all = 4
4748
4849
49- # pyre-fixme[3]: Return type must be annotated.
50- def _prepare_image (attr_visual : ndarray ):
50+ def _prepare_image (attr_visual : ndarray ) -> ndarray :
5151 return np .clip (attr_visual .astype (int ), 0 , 255 )
5252
5353
54- # pyre-fixme[3]: Return type must be annotated.
55- def _normalize_scale (attr : ndarray , scale_factor : float ):
54+ def _normalize_scale (attr : ndarray , scale_factor : float ) -> ndarray :
5655 assert scale_factor != 0 , "Cannot normalize by scale factor = 0"
5756 if abs (scale_factor ) < 1e-5 :
5857 warnings .warn (
@@ -65,8 +64,7 @@ def _normalize_scale(attr: ndarray, scale_factor: float):
6564 return np .clip (attr_norm , - 1 , 1 )
6665
6766
68- # pyre-fixme[3]: Return type must be annotated.
69- def _cumulative_sum_threshold (values : ndarray , percentile : Union [int , float ]):
67+ def _cumulative_sum_threshold (values : ndarray , percentile : Union [int , float ]) -> float :
7068 # given values should be non-negative
7169 assert percentile >= 0 and percentile <= 100 , (
7270 "Percentile for thresholding must be " "between 0 and 100 inclusive."
@@ -77,13 +75,12 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
7775 return sorted_vals [threshold_id ]
7876
7977
80- # pyre-fixme[3]: Return type must be annotated.
8178def _normalize_attr (
8279 attr : ndarray ,
8380 sign : str ,
8481 outlier_perc : Union [int , float ] = 2 ,
8582 reduction_axis : Optional [int ] = None ,
86- ):
83+ ) -> ndarray :
8784 attr_combined = attr
8885 if reduction_axis is not None :
8986 attr_combined = np .sum (attr , axis = reduction_axis )
@@ -370,8 +367,7 @@ def visualize_image_attr(
370367
371368 heat_map : Optional [AxesImage ] = None
372369
373- # pyre-ignore[33]: prohibited Any
374- visualization_methods : Dict [str , Callable [..., Any ]] = {
370+ visualization_methods : Dict [str , Callable [..., Union [None , AxesImage ]]] = {
375371 "heat_map" : _visualize_heat_map ,
376372 "blended_heat_map" : _visualize_blended_heat_map ,
377373 "masked_image" : _visualize_masked_image ,
@@ -420,7 +416,6 @@ def visualize_image_attr(
420416 return plt_fig , plt_axis
421417
422418
423- # pyre-fixme[3]: Return type must be annotated.
424419def visualize_image_attr_multiple (
425420 attr : ndarray ,
426421 original_image : Union [None , ndarray ],
@@ -430,7 +425,7 @@ def visualize_image_attr_multiple(
430425 fig_size : Tuple [int , int ] = (8 , 6 ),
431426 use_pyplot : bool = True ,
432427 ** kwargs : Any ,
433- ):
428+ ) -> Tuple [ Figure , Axes ] :
434429 r"""
435430 Visualizes attribution using multiple visualization methods displayed
436431 in a 1 x k grid, where k is the number of desired visualizations.
@@ -530,7 +525,6 @@ def visualize_image_attr_multiple(
530525 return plt_fig , plt_axis
531526
532527
533- # pyre-fixme[3]: Return type must be annotated.
534528def visualize_timeseries_attr (
535529 attr : ndarray ,
536530 data : ndarray ,
@@ -547,9 +541,8 @@ def visualize_timeseries_attr(
547541 title : Optional [str ] = None ,
548542 fig_size : Tuple [int , int ] = (6 , 6 ),
549543 use_pyplot : bool = True ,
550- # pyre-fixme[2]: Parameter must be annotated.
551- ** pyplot_kwargs ,
552- ):
544+ ** pyplot_kwargs : Any ,
545+ ) -> Tuple [Figure , Axes ]:
553546 r"""
554547 Visualizes attribution for a given timeseries data by normalizing
555548 attribution values of the desired sign (positive, negative, absolute value,
@@ -667,11 +660,9 @@ def visualize_timeseries_attr(
667660
668661 # Check input dimensions
669662 assert len (attr .shape ) == 2 , "Expected attr of shape (N, C), got {}" .format (
670- # pyre-fixme[16]: Module `attr` has no attribute `shape`.
671663 attr .shape
672664 )
673665 assert len (data .shape ) == 2 , "Expected data of shape (N, C), got {}" .format (
674- # pyre-fixme[16]: Module `attr` has no attribute `shape`.
675666 attr .shape
676667 )
677668
@@ -747,9 +738,8 @@ def visualize_timeseries_attr(
747738 cm_norm = colors .Normalize (vmin , vmax )
748739
749740 # pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
750- # pyre-fixme[3]: Return type must be annotated.
751741 # pyre-fixme[2]: Parameter must be annotated.
752- def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ):
742+ def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ) -> None :
753743 # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
754744 half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
755745 for icol , col_center in enumerate (x_vals ):
@@ -863,39 +853,31 @@ class VisualizationDataRecord:
863853
864854 def __init__ (
865855 self ,
866- # pyre-fixme[2]: Parameter must be annotated.
867- word_attributions ,
868- # pyre-fixme[2]: Parameter must be annotated.
869- pred_prob ,
870- # pyre-fixme[2]: Parameter must be annotated.
871- pred_class ,
872- # pyre-fixme[2]: Parameter must be annotated.
873- true_class ,
874- # pyre-fixme[2]: Parameter must be annotated.
875- attr_class ,
876- # pyre-fixme[2]: Parameter must be annotated.
877- attr_score ,
878- # pyre-fixme[2]: Parameter must be annotated.
879- raw_input_ids ,
880- # pyre-fixme[2]: Parameter must be annotated.
881- convergence_score ,
856+ word_attributions : Tensor ,
857+ pred_prob : float ,
858+ pred_class : int ,
859+ true_class : int ,
860+ attr_class : int ,
861+ attr_score : float ,
862+ raw_input_ids : List [str ],
863+ convergence_score : float ,
882864 ) -> None :
883- # pyre-fixme[4]: Attribute must be annotated.
884- self .word_attributions = word_attributions
885- # pyre-fixme[4]: Attribute must be annotated.
886- self .pred_prob = pred_prob
887- # pyre-fixme[4]: Attribute must be annotated.
888- self .pred_class = pred_class
889- # pyre-fixme[4]: Attribute must be annotated.
890- self .true_class = true_class
891- # pyre-fixme[4]: Attribute must be annotated.
892- self .attr_class = attr_class
893- # pyre-fixme[4]: Attribute must be annotated.
894- self .attr_score = attr_score
895- # pyre-fixme[4]: Attribute must be annotated.
896- self .raw_input_ids = raw_input_ids
897- # pyre-fixme[4]: Attribute must be annotated.
898- self .convergence_score = convergence_score
865+
866+ self .word_attributions : Tensor = word_attributions
867+
868+ self .pred_prob : float = pred_prob
869+
870+ self .pred_class : int = pred_class
871+
872+ self .true_class : int = true_class
873+
874+ self .attr_class : int = attr_class
875+
876+ self .attr_score : float = attr_score
877+
878+ self .raw_input_ids : List [ str ] = raw_input_ids
879+
880+ self .convergence_score : float = convergence_score
899881
900882
901883def _get_color (attr : int ) -> str :
0 commit comments