Skip to content

Commit 01ae0df

Browse files
Typing fix
Differential Revision: D58703398
1 parent 3f0cd93 commit 01ae0df

File tree

5 files changed

+68
-52
lines changed

5 files changed

+68
-52
lines changed

captum/attr/_core/noise_tunnel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
from enum import Enum
3-
from typing import Any, cast, List, Tuple, Union
3+
from typing import Any, cast, List, Optional, Tuple, Union
44

55
import torch
66
from captum._utils.common import (
@@ -80,7 +80,7 @@ def attribute(
8080
inputs: Union[Tensor, Tuple[Tensor, ...]],
8181
nt_type: str = "smoothgrad",
8282
nt_samples: int = 5,
83-
nt_samples_batch_size: int = None,
83+
nt_samples_batch_size: Optional[int] = None,
8484
stdevs: Union[float, Tuple[float, ...]] = 1.0,
8585
draw_baseline_from_distrib: bool = False,
8686
**kwargs: Any,

captum/attr/_utils/interpretable_input.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from abc import ABC, abstractmethod
12
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
23

34
import torch
@@ -45,7 +46,7 @@ def _scatter_itp_attr_by_mask(
4546
return attr
4647

4748

48-
class InterpretableInput:
49+
class InterpretableInput(ABC):
4950
"""
5051
InterpretableInput is an adapter for different kinds of model inputs to
5152
work in Captum's attribution methods. Generally, attribution methods of Captum
@@ -94,6 +95,7 @@ class to create other types of customized input.
9495
is only allowed in certain attribution classes like LLMAttribution for now.)
9596
"""
9697

98+
@abstractmethod
9799
def to_tensor(self) -> Tensor:
98100
"""
99101
Return the interpretable representation of this input as a tensor
@@ -104,6 +106,7 @@ def to_tensor(self) -> Tensor:
104106
"""
105107
pass
106108

109+
@abstractmethod
107110
def to_model_input(self, itp_tensor: Optional[Tensor] = None) -> Any:
108111
"""
109112
Get the (perturbed) input in the format required by the model

captum/attr/_utils/visualization.py

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from enum import Enum
44
from typing import Any, Iterable, List, Optional, Tuple, Union
55

6+
import matplotlib
7+
68
import numpy as np
79
from matplotlib import cm, colors, pyplot as plt
10+
from matplotlib.axes import Axes
811
from matplotlib.collections import LineCollection
9-
from matplotlib.colors import LinearSegmentedColormap
12+
from matplotlib.colors import Colormap, LinearSegmentedColormap
1013
from matplotlib.figure import Figure
11-
from matplotlib.pyplot import axis, figure
1214
from mpl_toolkits.axes_grid1 import make_axes_locatable
1315
from numpy import ndarray
1416

@@ -51,7 +53,8 @@ def _normalize_scale(attr: ndarray, scale_factor: float):
5153
warnings.warn(
5254
"Attempting to normalize by value approximately 0, visualized results"
5355
"may be misleading. This likely means that attribution values are all"
54-
"close to 0."
56+
"close to 0.",
57+
stacklevel=2,
5558
)
5659
attr_norm = attr / scale_factor
5760
return np.clip(attr_norm, -1, 1)
@@ -80,37 +83,39 @@ def _normalize_attr(
8083

8184
# Choose appropriate signed values and rescale, removing given outlier percentage.
8285
if VisualizeSign[sign] == VisualizeSign.all:
83-
threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc)
86+
threshold = _cumulative_sum_threshold(
87+
np.abs(attr_combined), 100.0 - outlier_perc
88+
)
8489
elif VisualizeSign[sign] == VisualizeSign.positive:
8590
attr_combined = (attr_combined > 0) * attr_combined
86-
threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc)
91+
threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc)
8792
elif VisualizeSign[sign] == VisualizeSign.negative:
8893
attr_combined = (attr_combined < 0) * attr_combined
8994
threshold = -1 * _cumulative_sum_threshold(
90-
np.abs(attr_combined), 100 - outlier_perc
95+
np.abs(attr_combined), 100.0 - outlier_perc
9196
)
9297
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
9398
attr_combined = np.abs(attr_combined)
94-
threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc)
99+
threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc)
95100
else:
96101
raise AssertionError("Visualize Sign type is not valid.")
97102
return _normalize_scale(attr_combined, threshold)
98103

99104

100105
def visualize_image_attr(
101106
attr: ndarray,
102-
original_image: Union[None, ndarray] = None,
107+
original_image: Optional[ndarray] = None,
103108
method: str = "heat_map",
104109
sign: str = "absolute_value",
105-
plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
110+
plt_fig_axis: Optional[Tuple[Figure, Axes]] = None,
106111
outlier_perc: Union[int, float] = 2,
107-
cmap: Union[None, str] = None,
112+
cmap: Optional[Union[str, Colormap]] = None,
108113
alpha_overlay: float = 0.5,
109114
show_colorbar: bool = False,
110-
title: Union[None, str] = None,
115+
title: Optional[str] = None,
111116
fig_size: Tuple[int, int] = (6, 6),
112117
use_pyplot: bool = True,
113-
):
118+
) -> Tuple[Figure, Axes]:
114119
r"""
115120
Visualizes attribution for a given image by normalizing attribution values
116121
of the desired sign (positive, negative, absolute value, or all) and displaying
@@ -231,7 +236,8 @@ def visualize_image_attr(
231236
plt_fig, plt_axis = plt.subplots(figsize=fig_size)
232237
else:
233238
plt_fig = Figure(figsize=fig_size)
234-
plt_axis = plt_fig.subplots()
239+
plt_axis = plt_fig.subplots() # type: ignore
240+
# Figure.subplots returns Axes or array of Axes
235241

236242
if original_image is not None:
237243
if np.max(original_image) <= 1.0:
@@ -264,8 +270,8 @@ def visualize_image_attr(
264270

265271
# Set default colormap and bounds based on sign.
266272
if VisualizeSign[sign] == VisualizeSign.all:
267-
default_cmap = LinearSegmentedColormap.from_list(
268-
"RdWhGn", ["red", "white", "green"]
273+
default_cmap: Union[str, LinearSegmentedColormap] = (
274+
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
269275
)
270276
vmin, vmax = -1, 1
271277
elif VisualizeSign[sign] == VisualizeSign.positive:
@@ -345,7 +351,7 @@ def visualize_image_attr_multiple(
345351
original_image: Union[None, ndarray],
346352
methods: List[str],
347353
signs: List[str],
348-
titles: Union[None, List[str]] = None,
354+
titles: Optional[List[str]] = None,
349355
fig_size: Tuple[int, int] = (8, 6),
350356
use_pyplot: bool = True,
351357
**kwargs: Any,
@@ -423,17 +429,22 @@ def visualize_image_attr_multiple(
423429
plt_fig = Figure(figsize=fig_size)
424430
plt_axis = plt_fig.subplots(1, len(methods))
425431

432+
plt_axis_list: List[Axes] = []
426433
# When visualizing one
427434
if len(methods) == 1:
428-
plt_axis = [plt_axis]
435+
plt_axis_list = [plt_axis] # type: ignore
436+
# Figure.subplots returns Axes or array of Axes
437+
else:
438+
plt_axis_list = plt_axis # type: ignore
439+
# Figure.subplots returns Axes or array of Axes
429440

430441
for i in range(len(methods)):
431442
visualize_image_attr(
432443
attr,
433444
original_image=original_image,
434445
method=methods[i],
435446
sign=signs[i],
436-
plt_fig_axis=(plt_fig, plt_axis[i]),
447+
plt_fig_axis=(plt_fig, plt_axis_list[i]),
437448
use_pyplot=False,
438449
title=titles[i] if titles else None,
439450
**kwargs,
@@ -452,12 +463,12 @@ def visualize_timeseries_attr(
452463
sign: str = "absolute_value",
453464
channel_labels: Optional[List[str]] = None,
454465
channels_last: bool = True,
455-
plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
466+
plt_fig_axis: Optional[Tuple[Figure, Union[Axes, List[Axes]]]] = None,
456467
outlier_perc: Union[int, float] = 2,
457-
cmap: Union[None, str] = None,
468+
cmap: Optional[Union[str, Colormap]] = None,
458469
alpha_overlay: float = 0.7,
459470
show_colorbar: bool = False,
460-
title: Union[None, str] = None,
471+
title: Optional[str] = None,
461472
fig_size: Tuple[int, int] = (6, 6),
462473
use_pyplot: bool = True,
463474
**pyplot_kwargs,
@@ -596,7 +607,8 @@ def visualize_timeseries_attr(
596607
if num_channels > timeseries_length:
597608
warnings.warn(
598609
"Number of channels ({}) greater than time series length ({}), "
599-
"please verify input format".format(num_channels, timeseries_length)
610+
"please verify input format".format(num_channels, timeseries_length),
611+
stacklevel=2,
600612
)
601613

602614
num_subplots = num_channels
@@ -624,17 +636,20 @@ def visualize_timeseries_attr(
624636
)
625637
else:
626638
plt_fig = Figure(figsize=fig_size)
627-
plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True)
639+
plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) # type: ignore
640+
# Figure.subplots returns Axes or array of Axes
628641

629642
if not isinstance(plt_axis, ndarray):
630-
plt_axis = np.array([plt_axis])
643+
plt_axis_list = np.array([plt_axis])
644+
else:
645+
plt_axis_list = plt_axis
631646

632647
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)
633648

634649
# Set default colormap and bounds based on sign.
635650
if VisualizeSign[sign] == VisualizeSign.all:
636-
default_cmap = LinearSegmentedColormap.from_list(
637-
"RdWhGn", ["red", "white", "green"]
651+
default_cmap: Union[str, LinearSegmentedColormap] = (
652+
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
638653
)
639654
vmin, vmax = -1, 1
640655
elif VisualizeSign[sign] == VisualizeSign.positive:
@@ -653,7 +668,6 @@ def visualize_timeseries_attr(
653668
cm_norm = colors.Normalize(vmin, vmax)
654669

655670
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
656-
657671
half_col_width = (x_values[1] - x_values[0]) / 2.0
658672
for icol, col_center in enumerate(x_vals):
659673
left = col_center - half_col_width
@@ -670,52 +684,47 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
670684
TimeseriesVisualizationMethod[method]
671685
== TimeseriesVisualizationMethod.overlay_individual
672686
):
673-
674687
for chan in range(num_channels):
675-
676-
plt_axis[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
688+
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
677689
if channel_labels is not None:
678-
plt_axis[chan].set_ylabel(channel_labels[chan])
690+
plt_axis_list[chan].set_ylabel(channel_labels[chan])
679691

680-
_plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis[chan])
692+
_plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis_list[chan])
681693

682694
plt.subplots_adjust(hspace=0)
683695

684696
elif (
685697
TimeseriesVisualizationMethod[method]
686698
== TimeseriesVisualizationMethod.overlay_combined
687699
):
688-
689700
# Dark colors are better in this case
690-
cycler = plt.cycler("color", cm.Dark2.colors)
691-
plt_axis[0].set_prop_cycle(cycler)
701+
cycler = plt.cycler("color", matplotlib.colormaps["Dark2"]) # type: ignore
702+
plt_axis_list[0].set_prop_cycle(cycler)
692703

693704
for chan in range(num_channels):
694705
label = channel_labels[chan] if channel_labels else None
695-
plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
706+
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
696707

697-
_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0])
708+
_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis_list[0])
698709

699-
plt_axis[0].legend(loc="best")
710+
plt_axis_list[0].legend(loc="best")
700711

701712
elif (
702713
TimeseriesVisualizationMethod[method]
703714
== TimeseriesVisualizationMethod.colored_graph
704715
):
705-
706716
for chan in range(num_channels):
707-
708717
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
709718
segments = np.concatenate([points[:-1], points[1:]], axis=1)
710719

711720
lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
712721
lc.set_array(norm_attr[chan, :])
713-
plt_axis[chan].add_collection(lc)
714-
plt_axis[chan].set_ylim(
722+
plt_axis_list[chan].add_collection(lc)
723+
plt_axis_list[chan].set_ylim(
715724
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
716725
)
717726
if channel_labels is not None:
718-
plt_axis[chan].set_ylabel(channel_labels[chan])
727+
plt_axis_list[chan].set_ylabel(channel_labels[chan])
719728

720729
plt.subplots_adjust(hspace=0)
721730

@@ -725,7 +734,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
725734
plt.xlim([x_values[0], x_values[-1]])
726735

727736
if show_colorbar:
728-
axis_separator = make_axes_locatable(plt_axis[-1])
737+
axis_separator = make_axes_locatable(plt_axis_list[-1])
729738
colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.4)
730739
colorbar_alpha = alpha_overlay
731740
if (
@@ -740,7 +749,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
740749
alpha=colorbar_alpha,
741750
)
742751
if title:
743-
plt_axis[0].set_title(title)
752+
plt_axis_list[0].set_title(title)
744753

745754
if use_pyplot:
746755
plt.show()

captum/insights/attr_vis/features.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from captum._utils.common import safe_div
99
from captum.attr._utils import visualization as viz
1010
from captum.insights.attr_vis._utils.transforms import format_transforms
11+
from matplotlib.figure import Figure
1112
from torch import Tensor
1213

14+
1315
FeatureOutput = namedtuple("FeatureOutput", "name base modified type contribution")
1416

1517

16-
def _convert_figure_base64(fig) -> str:
18+
def _convert_figure_base64(fig: Figure) -> str:
1719
buff = BytesIO()
1820
with warnings.catch_warnings():
1921
warnings.simplefilter("ignore")

tests/insights/test_contribution.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python3
22

33
import unittest
4-
from typing import Callable, List, Union
4+
from typing import Any, Callable, Generator, List, Tuple, Union
55

66
import torch
77
import torch.nn as nn
88
from captum.insights import AttributionVisualizer, Batch
99
from captum.insights.attr_vis.app import FilterConfig
1010
from captum.insights.attr_vis.features import BaseFeature, FeatureOutput, ImageFeature
1111
from tests.helpers import BaseTest
12+
from torch import Tensor
1213

1314

1415
class RealFeature(BaseFeature):
@@ -26,7 +27,8 @@ def __init__(
2627
visualization_transform=None,
2728
)
2829

29-
def visualization_type(self) -> str:
30+
@staticmethod
31+
def visualization_type() -> str:
3032
return "real"
3133

3234
def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
@@ -103,7 +105,7 @@ def _labelled_img_data(
103105
height: int = 8,
104106
depth: int = 3,
105107
num_labels: int = 10,
106-
):
108+
) -> Generator[Tuple[Tensor, Tensor], Any, Any]:
107109
for _ in range(num_samples):
108110
yield torch.empty(depth, height, width).uniform_(0, 1), torch.randint(
109111
num_labels, (1,)

0 commit comments

Comments
 (0)