Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 324 additions & 6 deletions captum/attr/_utils/visualization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python3
import warnings
from enum import Enum
from typing import Any, Iterable, List, Tuple, Union
from typing import Any, Iterable, List, Optional, Tuple, Union

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm, colors, pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure
from matplotlib.pyplot import axis, figure
Expand All @@ -27,6 +28,12 @@ class ImageVisualizationMethod(Enum):
alpha_scaling = 5


class TimeseriesVisualizationMethod(Enum):
    """Plot layouts selectable by name through the `method` argument of
    `visualize_timeseries_attr` (looked up as
    `TimeseriesVisualizationMethod[method]`)."""

    # One panel per channel; attributions drawn as a heat-map overlay.
    overlay_individual = 1
    # All channels in a single panel; one overlay of the attributions
    # summed across channels.
    overlay_combined = 2
    # One panel per channel; the line itself is colored by attribution.
    colored_graph = 3


class VisualizeSign(Enum):
positive = 1
absolute_value = 2
Expand Down Expand Up @@ -61,10 +68,16 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
return sorted_vals[threshold_id]


def _normalize_image_attr(
attr: ndarray, sign: str, outlier_perc: Union[int, float] = 2
def _normalize_attr(
attr: ndarray,
sign: str,
outlier_perc: Union[int, float] = 2,
reduction_axis: Optional[int] = None,
):
attr_combined = np.sum(attr, axis=2)
attr_combined = attr
if reduction_axis is not None:
attr_combined = np.sum(attr, axis=reduction_axis)

# Choose appropriate signed values and rescale, removing given outlier percentage.
if VisualizeSign[sign] == VisualizeSign.all:
threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc)
Expand Down Expand Up @@ -241,7 +254,7 @@ def visualize_image_attr(
plt_axis.imshow(original_image)
else:
# Choose appropriate signed attributions and normalize.
norm_attr = _normalize_image_attr(attr, sign, outlier_perc)
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=2)

# Set default colormap and bounds based on sign.
if VisualizeSign[sign] == VisualizeSign.all:
Expand Down Expand Up @@ -422,6 +435,311 @@ def visualize_image_attr_multiple(
return plt_fig, plt_axis


def visualize_timeseries_attr(
    attr: ndarray,
    data: ndarray,
    x_values: Optional[ndarray] = None,
    method: str = "overlay_individual",
    sign: str = "absolute_value",
    channel_labels: Optional[List[str]] = None,
    channels_last: bool = True,
    plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
    outlier_perc: Union[int, float] = 2,
    cmap: Union[None, str] = None,
    alpha_overlay: float = 0.7,
    show_colorbar: bool = False,
    title: Union[None, str] = None,
    fig_size: Tuple[int, int] = (6, 6),
    use_pyplot: bool = True,
    **pyplot_kwargs,
):
    r"""
    Visualizes attribution for a given timeseries data by normalizing
    attribution values of the desired sign (positive, negative, absolute value,
    or all) and displaying them using the desired mode in a matplotlib figure.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (N, C) with channels
                    as last dimension, unless `channels_last` is set to False,
                    in which case the shape must be (C, N).
                    Shape must also match that of the timeseries data.
        data (numpy.array): Numpy array corresponding to the original,
                    equidistant timeseries data. Shape must be in the form
                    (N, C) with channels as last dimension, unless
                    `channels_last` is set to False, in which case the shape
                    must be (C, N).
        x_values (numpy.array, optional): Numpy array corresponding to the
                    points on the x-axis. Shape must be in the form (N, ). If
                    not provided, integers from 0 to N-1 are used.
                    Default: None
        method (string, optional): Chosen method for visualizing attributions
                    overlaid onto data. Supported options are:

                    1. `overlay_individual` - Plot each channel individually in
                        a separate panel, and overlay the attributions for each
                        channel as a heat map. The `alpha_overlay` parameter
                        controls the alpha of the heat map.

                    2. `overlay_combined` - Plot all channels in the same panel,
                        and overlay the attributions summed across channels
                        as a heat map.

                    3. `colored_graph` - Plot each channel in a separate panel,
                        and color the graphs according to the attribution
                        values. Works best with color maps that do not contain
                        white or very bright colors.
                    Default: `overlay_individual`
        sign (string, optional): Chosen sign of attributions to visualize.
                    Supported options are:

                    1. `positive` - Displays only positive attributions.

                    2. `absolute_value` - Displays absolute value of
                        attributions.

                    3. `negative` - Displays only negative attributions.

                    4. `all` - Displays both positive and negative attribution
                        values.
                    Default: `absolute_value`
        channel_labels (list of strings, optional): List of labels
                    corresponding to each channel in data. Used as y-axis
                    labels, or as legend entries for `overlay_combined`.
                    Default: None
        channels_last (bool, optional): If True, data is expected to have
                    channels as the last dimension, i.e. (N, C). If False, data
                    is expected to have channels first, i.e. (C, N).
                    Default: True
        plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
                    on which to visualize. If None is provided, then a new figure
                    and axis are created.
                    Default: None
        outlier_perc (float or int, optional): Top attribution values which
                    correspond to a total of outlier_perc percentage of the
                    total attribution are set to 1 and scaling is performed
                    using the minimum of these values. For sign=`all`, outliers
                    and scale value are computed using absolute value of
                    attributions.
                    Default: 2
        cmap (string, optional): String corresponding to desired colormap for
                    heatmap visualization. This defaults to "Reds" for negative
                    sign, "Blues" for absolute value, "Greens" for positive sign,
                    and a spectrum from red to green for all.
                    Default: None
        alpha_overlay (float, optional): Alpha of the attribution heat map
                    drawn over the data in the `overlay_individual` and
                    `overlay_combined` modes.
                    Default: 0.7
        show_colorbar (boolean): Displays colorbar for heat map below
                    the visualization.
                    Default: False
        title (string, optional): Title string for plot. If None, no title is
                    set.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (6,6)
        use_pyplot (boolean): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.
        pyplot_kwargs: Keyword arguments forwarded to the line plotting call
                    (`Axes.plot`, or `LineCollection` for `colored_graph`),
                    for example `linewidth=3`, `color='black'`, etc

    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same figure provided.
        - **axis** (*numpy.array* of *matplotlib.pyplot.axis*):
                    Array of axis objects on which visualization
                    is created. A single axis (created here or given through
                    plt_fig_axis) is wrapped in a one-element array.

    Raises:
        KeyError: If `method` or `sign` is not the name of a supported option.
        AssertionError: If the input shapes are inconsistent.

    Examples::

        >>> # Classifier takes input of shape (batch, length, channels)
        >>> model = Classifier()
        >>> dl = DeepLift(model)
        >>> attribution = dl.attribute(data, target=0)
        >>> # Pick the first sample and plot each channel in data in a separate
        >>> # panel, with attributions overlaid
        >>> visualize_timeseries_attr(
        ...     attribution[0], data[0], method="overlay_individual"
        ... )
    """

    # Check input dimensions
    assert len(attr.shape) == 2, "Expected attr of shape (N, C), got {}".format(
        attr.shape
    )
    assert len(data.shape) == 2, "Expected data of shape (N, C), got {}".format(
        data.shape
    )
    assert (
        attr.shape == data.shape
    ), "attr and data must have the same shape, got {} and {}".format(
        attr.shape, data.shape
    )

    # Convert to channels-first so that row `chan` is one full series.
    if channels_last:
        attr = np.transpose(attr)
        data = np.transpose(data)

    num_channels = attr.shape[0]
    timeseries_length = attr.shape[1]

    if num_channels > timeseries_length:
        warnings.warn(
            "Number of channels ({}) greater than time series length ({}), "
            "please verify input format".format(num_channels, timeseries_length),
            stacklevel=2,
        )

    # Resolve the method name once; an unknown name raises KeyError here.
    vis_method = TimeseriesVisualizationMethod[method]

    num_subplots = num_channels
    if vis_method == TimeseriesVisualizationMethod.overlay_combined:
        num_subplots = 1
        attr = np.sum(attr, axis=0)  # Merge attributions across channels

    if x_values is not None:
        assert (
            x_values.shape[0] == timeseries_length
        ), "x_values must have same length as data"
    else:
        x_values = np.arange(timeseries_length)

    # Create plot if figure, axis not provided
    if plt_fig_axis is not None:
        plt_fig, plt_axis = plt_fig_axis
    else:
        if use_pyplot:
            plt_fig, plt_axis = plt.subplots(
                figsize=fig_size, nrows=num_subplots, sharex=True
            )
        else:
            plt_fig = Figure(figsize=fig_size)
            plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True)

    # A single subplot comes back as a bare axis; always index an array below.
    if not isinstance(plt_axis, ndarray):
        plt_axis = np.array([plt_axis])

    norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)

    # Set default colormap and bounds based on sign.
    if VisualizeSign[sign] == VisualizeSign.all:
        default_cmap = LinearSegmentedColormap.from_list(
            "RdWhGn", ["red", "white", "green"]
        )
        vmin, vmax = -1, 1
    elif VisualizeSign[sign] == VisualizeSign.positive:
        default_cmap = "Greens"
        vmin, vmax = 0, 1
    elif VisualizeSign[sign] == VisualizeSign.negative:
        default_cmap = "Reds"
        vmin, vmax = 0, 1
    elif VisualizeSign[sign] == VisualizeSign.absolute_value:
        default_cmap = "Blues"
        vmin, vmax = 0, 1
    else:
        raise AssertionError("Visualize Sign type is not valid.")
    cmap = cmap if cmap is not None else default_cmap
    # pyplot.get_cmap accepts both names and Colormap instances and, unlike
    # cm.get_cmap, is still available in current Matplotlib releases.
    cmap = plt.get_cmap(cmap)
    cm_norm = colors.Normalize(vmin, vmax)

    def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
        # Draw one vertical band per sample, centered on its x position and
        # colored by its normalized attribution.
        if len(x_vals) > 1:
            half_col_width = (x_vals[1] - x_vals[0]) / 2.0
        else:
            # A single sample has no neighbour to measure spacing from.
            half_col_width = 0.5
        for col_center, attr_val in zip(x_vals, attr_vals):
            ax.axvspan(
                xmin=col_center - half_col_width,
                xmax=col_center + half_col_width,
                facecolor=(cmap(cm_norm(attr_val))),
                edgecolor=None,
                alpha=alpha_overlay,
            )

    if vis_method == TimeseriesVisualizationMethod.overlay_individual:

        for chan in range(num_channels):
            plt_axis[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
            if channel_labels is not None:
                plt_axis[chan].set_ylabel(channel_labels[chan])

            _plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis[chan])

        # Adjust the figure we actually drew on, not pyplot's current figure.
        plt_fig.subplots_adjust(hspace=0)

    elif vis_method == TimeseriesVisualizationMethod.overlay_combined:

        # Dark colors are better in this case
        cycler = plt.cycler("color", cm.Dark2.colors)
        plt_axis[0].set_prop_cycle(cycler)

        for chan in range(num_channels):
            label = channel_labels[chan] if channel_labels is not None else None
            plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)

        _plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0])

        # Without labels there is nothing to list and legend() only warns.
        if channel_labels is not None:
            plt_axis[0].legend(loc="best")

    elif vis_method == TimeseriesVisualizationMethod.colored_graph:

        for chan in range(num_channels):
            series = data[chan, :]

            # Build the (N-1, 2, 2) array of consecutive point pairs that
            # LineCollection expects.
            points = np.array([x_values, series]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
            lc.set_array(norm_attr[chan, :])
            plt_axis[chan].add_collection(lc)

            # add_collection does not autoscale, so set the y-limits by hand.
            # Pad relative to the data range: scaling min/max by a constant
            # factor would cut off the curve when min > 0 or max < 0.
            y_low, y_high = np.min(series), np.max(series)
            y_pad = 0.1 * (y_high - y_low) if y_high > y_low else 1.0
            plt_axis[chan].set_ylim(y_low - y_pad, y_high + y_pad)

            if channel_labels is not None:
                plt_axis[chan].set_ylabel(channel_labels[chan])

        plt_fig.subplots_adjust(hspace=0)

    else:
        raise AssertionError("Invalid visualization method: {}".format(method))

    # Set the limits on the axes that were drawn on; plt.xlim would target
    # pyplot's current axes, which need not belong to this figure.
    for drawn_axis in plt_axis[:num_subplots]:
        drawn_axis.set_xlim([x_values[0], x_values[-1]])

    if show_colorbar:
        axis_separator = make_axes_locatable(plt_axis[-1])
        colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.4)
        # The colored graph is drawn opaque, so its colorbar should be too.
        colorbar_alpha = alpha_overlay
        if vis_method == TimeseriesVisualizationMethod.colored_graph:
            colorbar_alpha = 1.0
        plt_fig.colorbar(
            cm.ScalarMappable(cm_norm, cmap),
            orientation="horizontal",
            cax=colorbar_axis,
            alpha=colorbar_alpha,
        )
    if title:
        plt_axis[0].set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis


# These visualization methods are for text and are partially copied from
# experiments conducted by Davide Testuggine at Facebook.

Expand Down