Skip to content

Commit 925bd27

Browse files
jjunchofacebook-github-bot
authored andcommitted
Reducing pyre-fixme's in visualization.py
Summary: This diff helps address the number of pyre-fixme's and return type annotation pyre errors in the visualizations.py file Differential Revision: D64405790
1 parent 749f804 commit 925bd27

File tree

1 file changed

+34
-52
lines changed

1 file changed

+34
-52
lines changed

captum/attr/_utils/visualization.py

Lines changed: 34 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from matplotlib.image import AxesImage
1717
from mpl_toolkits.axes_grid1 import make_axes_locatable
1818
from numpy import ndarray
19+
from torch import Tensor
1920

2021
try:
2122
from IPython.display import display, HTML
@@ -46,13 +47,11 @@ class VisualizeSign(Enum):
4647
all = 4
4748

4849

49-
# pyre-fixme[3]: Return type must be annotated.
50-
def _prepare_image(attr_visual: ndarray):
50+
def _prepare_image(attr_visual: ndarray) -> ndarray:
5151
return np.clip(attr_visual.astype(int), 0, 255)
5252

5353

54-
# pyre-fixme[3]: Return type must be annotated.
55-
def _normalize_scale(attr: ndarray, scale_factor: float):
54+
def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
5655
assert scale_factor != 0, "Cannot normalize by scale factor = 0"
5756
if abs(scale_factor) < 1e-5:
5857
warnings.warn(
@@ -65,8 +64,7 @@ def _normalize_scale(attr: ndarray, scale_factor: float):
6564
return np.clip(attr_norm, -1, 1)
6665

6766

68-
# pyre-fixme[3]: Return type must be annotated.
69-
def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
67+
def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) -> float:
7068
# given values should be non-negative
7169
assert percentile >= 0 and percentile <= 100, (
7270
"Percentile for thresholding must be " "between 0 and 100 inclusive."
@@ -77,13 +75,12 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
7775
return sorted_vals[threshold_id]
7876

7977

80-
# pyre-fixme[3]: Return type must be annotated.
8178
def _normalize_attr(
8279
attr: ndarray,
8380
sign: str,
8481
outlier_perc: Union[int, float] = 2,
8582
reduction_axis: Optional[int] = None,
86-
):
83+
) -> ndarray:
8784
attr_combined = attr
8885
if reduction_axis is not None:
8986
attr_combined = np.sum(attr, axis=reduction_axis)
@@ -370,8 +367,7 @@ def visualize_image_attr(
370367

371368
heat_map: Optional[AxesImage] = None
372369

373-
# pyre-ignore[33]: prohibited Any
374-
visualization_methods: Dict[str, Callable[..., Any]] = {
370+
visualization_methods: Dict[str, Callable[..., Union[None, AxesImage]]] = {
375371
"heat_map": _visualize_heat_map,
376372
"blended_heat_map": _visualize_blended_heat_map,
377373
"masked_image": _visualize_masked_image,
@@ -420,7 +416,6 @@ def visualize_image_attr(
420416
return plt_fig, plt_axis
421417

422418

423-
# pyre-fixme[3]: Return type must be annotated.
424419
def visualize_image_attr_multiple(
425420
attr: ndarray,
426421
original_image: Union[None, ndarray],
@@ -430,7 +425,7 @@ def visualize_image_attr_multiple(
430425
fig_size: Tuple[int, int] = (8, 6),
431426
use_pyplot: bool = True,
432427
**kwargs: Any,
433-
):
428+
) -> Tuple[Figure, Axes]:
434429
r"""
435430
Visualizes attribution using multiple visualization methods displayed
436431
in a 1 x k grid, where k is the number of desired visualizations.
@@ -530,7 +525,6 @@ def visualize_image_attr_multiple(
530525
return plt_fig, plt_axis
531526

532527

533-
# pyre-fixme[3]: Return type must be annotated.
534528
def visualize_timeseries_attr(
535529
attr: ndarray,
536530
data: ndarray,
@@ -547,9 +541,8 @@ def visualize_timeseries_attr(
547541
title: Optional[str] = None,
548542
fig_size: Tuple[int, int] = (6, 6),
549543
use_pyplot: bool = True,
550-
# pyre-fixme[2]: Parameter must be annotated.
551-
**pyplot_kwargs,
552-
):
544+
**pyplot_kwargs: Any,
545+
) -> Tuple[Figure, Axes]:
553546
r"""
554547
Visualizes attribution for a given timeseries data by normalizing
555548
attribution values of the desired sign (positive, negative, absolute value,
@@ -667,11 +660,9 @@ def visualize_timeseries_attr(
667660

668661
# Check input dimensions
669662
assert len(attr.shape) == 2, "Expected attr of shape (N, C), got {}".format(
670-
# pyre-fixme[16]: Module `attr` has no attribute `shape`.
671663
attr.shape
672664
)
673665
assert len(data.shape) == 2, "Expected data of shape (N, C), got {}".format(
674-
# pyre-fixme[16]: Module `attr` has no attribute `shape`.
675666
attr.shape
676667
)
677668

@@ -747,9 +738,8 @@ def visualize_timeseries_attr(
747738
cm_norm = colors.Normalize(vmin, vmax)
748739

749740
# pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
750-
# pyre-fixme[3]: Return type must be annotated.
751741
# pyre-fixme[2]: Parameter must be annotated.
752-
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
742+
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax) -> None:
753743
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
754744
half_col_width = (x_values[1] - x_values[0]) / 2.0
755745
for icol, col_center in enumerate(x_vals):
@@ -863,39 +853,31 @@ class VisualizationDataRecord:
863853

864854
def __init__(
865855
self,
866-
# pyre-fixme[2]: Parameter must be annotated.
867-
word_attributions,
868-
# pyre-fixme[2]: Parameter must be annotated.
869-
pred_prob,
870-
# pyre-fixme[2]: Parameter must be annotated.
871-
pred_class,
872-
# pyre-fixme[2]: Parameter must be annotated.
873-
true_class,
874-
# pyre-fixme[2]: Parameter must be annotated.
875-
attr_class,
876-
# pyre-fixme[2]: Parameter must be annotated.
877-
attr_score,
878-
# pyre-fixme[2]: Parameter must be annotated.
879-
raw_input_ids,
880-
# pyre-fixme[2]: Parameter must be annotated.
881-
convergence_score,
856+
word_attributions: Tensor,
857+
pred_prob: float,
858+
pred_class: int,
859+
true_class: int,
860+
attr_class: int,
861+
attr_score: float,
862+
raw_input_ids: List[str],
863+
convergence_score: float,
882864
) -> None:
883-
# pyre-fixme[4]: Attribute must be annotated.
884-
self.word_attributions = word_attributions
885-
# pyre-fixme[4]: Attribute must be annotated.
886-
self.pred_prob = pred_prob
887-
# pyre-fixme[4]: Attribute must be annotated.
888-
self.pred_class = pred_class
889-
# pyre-fixme[4]: Attribute must be annotated.
890-
self.true_class = true_class
891-
# pyre-fixme[4]: Attribute must be annotated.
892-
self.attr_class = attr_class
893-
# pyre-fixme[4]: Attribute must be annotated.
894-
self.attr_score = attr_score
895-
# pyre-fixme[4]: Attribute must be annotated.
896-
self.raw_input_ids = raw_input_ids
897-
# pyre-fixme[4]: Attribute must be annotated.
898-
self.convergence_score = convergence_score
865+
866+
self.word_attributions: Tensor = word_attributions
867+
868+
self.pred_prob: float = pred_prob
869+
870+
self.pred_class: int = pred_class
871+
872+
self.true_class: int = true_class
873+
874+
self.attr_class: int = attr_class
875+
876+
self.attr_score: float = attr_score
877+
878+
self.raw_input_ids: List[str] = raw_input_ids
879+
880+
self.convergence_score: float = convergence_score
899881

900882

901883
def _get_color(attr: int) -> str:

0 commit comments

Comments
 (0)