class TimeseriesVisualizationMethod(Enum):
    """Rendering modes for time-series attribution plots.

    ``visualize_timeseries_attr`` resolves its ``method`` string by member
    name, i.e. ``TimeseriesVisualizationMethod[method]``.
    """

    # One panel per channel, attributions drawn as a heat map behind each.
    overlay_individual = 1
    # All channels in a single panel with one combined heat map.
    overlay_combined = 2
    # One panel per channel, the curve itself coloured by attribution.
    colored_graph = 3
def visualize_timeseries_attr(
    attr: ndarray,
    data: ndarray,
    x_values: Optional[ndarray] = None,
    method: str = "overlay_individual",
    sign: str = "absolute_value",
    channel_labels: Optional[List[str]] = None,
    channels_last: bool = True,
    plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
    outlier_perc: Union[int, float] = 2,
    cmap: Union[None, str] = None,
    alpha_overlay: float = 0.7,
    show_colorbar: bool = False,
    title: Union[None, str] = None,
    fig_size: Tuple[int, int] = (6, 6),
    use_pyplot: bool = True,
    **pyplot_kwargs,
):
    r"""
    Visualizes attribution for a given timeseries data by normalizing
    attribution values of the desired sign (positive, negative, absolute value,
    or all) and displaying them using the desired mode in a matplotlib figure.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (N, C) with channels
                    as last dimension, unless `channels_last` is set to False,
                    in which case the shape must be (C, N).
                    Shape must also match that of the timeseries data.
        data (numpy.array): Numpy array corresponding to the original,
                    equidistant timeseries data. Shape must be in the form
                    (N, C) with channels as last dimension, unless
                    `channels_last` is set to False, in which case the shape
                    must be (C, N).
        x_values (numpy.array, optional): Numpy array corresponding to the
                    points on the x-axis. Shape must be in the form (N, ). If
                    not provided, integers from 0 to N-1 are used. The width
                    of each heat map column is derived from the spacing of
                    the first two values.
                    Default: None
        method (string, optional): Chosen method for visualizing attributions
                    overlaid onto data. Supported options are:

                    1. `overlay_individual` - Plot each channel individually in
                        a separate panel, and overlay the attributions for each
                        channel as a heat map. The `alpha_overlay` parameter
                        controls the alpha of the heat map.

                    2. `overlay_combined` - Plot all channels in the same panel,
                        and overlay the attributions summed across channels
                        as a heat map.

                    3. `colored_graph` - Plot each channel in a separate panel,
                        and color the graphs according to the attribution
                        values. Works best with color maps that do not contain
                        white or very bright colors.
                    Default: `overlay_individual`
        sign (string, optional): Chosen sign of attributions to visualize.
                    Supported options are:

                    1. `positive` - Displays only positive attributions.

                    2. `absolute_value` - Displays absolute value of
                        attributions.

                    3. `negative` - Displays only negative attributions.

                    4. `all` - Displays both positive and negative attribution
                        values.
                    Default: `absolute_value`
        channel_labels (list of strings, optional): List of labels
                    corresponding to each channel in data. Used as y-axis
                    labels, or as legend entries for `overlay_combined`.
                    Default: None
        channels_last (bool, optional): If True, data is expected to have
                    channels as the last dimension, i.e. (N, C). If False, data
                    is expected to have channels first, i.e. (C, N).
                    Default: True
        plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
                    on which to visualize. If None is provided, then a new figure
                    and axis are created.
                    Default: None
        outlier_perc (float or int, optional): Top attribution values which
                    correspond to a total of outlier_perc percentage of the
                    total attribution are set to 1 and scaling is performed
                    using the minimum of these values. For sign=`all`, outliers
                    and scale value are computed using absolute value of
                    attributions.
                    Default: 2
        cmap (string, optional): String corresponding to desired colormap for
                    the heat map or, for `colored_graph`, for the line
                    coloring. This defaults to "Reds" for negative
                    sign, "Blues" for absolute value, "Greens" for positive sign,
                    and a spectrum from red to green for all.
                    Default: None
        alpha_overlay (float, optional): Alpha of the heat map columns drawn
                    by the `overlay_individual` and `overlay_combined`
                    methods. Not applied to the lines of `colored_graph`.
                    Default: 0.7
        show_colorbar (boolean): Displays colorbar for heat map below
                    the visualization.
                    Default: False
        title (string, optional): Title string for plot. If None, no title is
                    set.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (6,6)
        use_pyplot (boolean): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.
        pyplot_kwargs: Keyword arguments forwarded to plt.plot, for example
                    `linewidth=3`, `color='black'`, etc. For `colored_graph`
                    they are forwarded to the LineCollection instead.

    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same figure provided.
        - **axis** (*numpy.array of matplotlib.pyplot.axis*):
                    Array of axis objects on which visualization
                    is created. A single axis is wrapped in a one-element
                    array.

    Raises:
        AssertionError: If `attr` or `data` is not 2-dimensional, if their
                    shapes differ, or if `x_values` does not have length N.
        KeyError: If `method` or `sign` is not a supported option.

    Examples::

        >>> # Classifier takes input of shape (batch, length, channels)
        >>> model = Classifier()
        >>> dl = DeepLift(model)
        >>> attribution = dl.attribute(data, target=0)
        >>> # Pick the first sample and plot each channel in data in a separate
        >>> # panel, with attributions overlaid
        >>> visualize_timeseries_attr(
        ...     attribution[0], data[0], method="overlay_individual"
        ... )
    """

    # Check input dimensions
    assert len(attr.shape) == 2, "Expected attr of shape (N, C), got {}".format(
        attr.shape
    )
    assert len(data.shape) == 2, "Expected data of shape (N, C), got {}".format(
        data.shape
    )
    assert (
        attr.shape == data.shape
    ), "attr and data must have the same shape, got {} and {}".format(
        attr.shape, data.shape
    )

    # Convert to channels-first so that row `chan` holds one full series
    if channels_last:
        attr = np.transpose(attr)
        data = np.transpose(data)

    num_channels = attr.shape[0]
    timeseries_length = attr.shape[1]

    if num_channels > timeseries_length:
        warnings.warn(
            "Number of channels ({}) greater than time series length ({}), "
            "please verify input format".format(num_channels, timeseries_length)
        )

    # Resolve the method name once; an unsupported name raises KeyError here.
    vis_method = TimeseriesVisualizationMethod[method]

    num_subplots = num_channels
    if vis_method == TimeseriesVisualizationMethod.overlay_combined:
        num_subplots = 1
        attr = np.sum(attr, axis=0)  # Merge attributions across channels

    if x_values is not None:
        assert (
            x_values.shape[0] == timeseries_length
        ), "x_values must have same length as data"
    else:
        x_values = np.arange(timeseries_length)

    # Create plot if figure, axis not provided
    if plt_fig_axis is not None:
        plt_fig, plt_axis = plt_fig_axis
    else:
        if use_pyplot:
            plt_fig, plt_axis = plt.subplots(
                figsize=fig_size, nrows=num_subplots, sharex=True
            )
        else:
            plt_fig = Figure(figsize=fig_size)
            plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True)

    # A single axis is wrapped so the code below can always index plt_axis.
    if not isinstance(plt_axis, ndarray):
        plt_axis = np.array([plt_axis])

    norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)

    # Set default colormap and bounds based on sign.
    vis_sign = VisualizeSign[sign]
    if vis_sign == VisualizeSign.all:
        default_cmap = LinearSegmentedColormap.from_list(
            "RdWhGn", ["red", "white", "green"]
        )
        vmin, vmax = -1, 1
    elif vis_sign == VisualizeSign.positive:
        default_cmap = "Greens"
        vmin, vmax = 0, 1
    elif vis_sign == VisualizeSign.negative:
        default_cmap = "Reds"
        vmin, vmax = 0, 1
    elif vis_sign == VisualizeSign.absolute_value:
        default_cmap = "Blues"
        vmin, vmax = 0, 1
    else:
        raise AssertionError("Visualize Sign type is not valid.")
    cmap = cmap if cmap is not None else default_cmap
    # plt.get_cmap accepts both names and Colormap instances; cm.get_cmap is
    # no longer available in recent Matplotlib releases.
    cmap = plt.get_cmap(cmap)
    cm_norm = colors.Normalize(vmin, vmax)

    def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
        # Draw one vertical band per time step, centred on its x position.
        # The band width comes from the first spacing (equidistant x assumed);
        # a single-point series has no spacing, so fall back to unit width.
        if len(x_vals) > 1:
            half_col_width = (x_vals[1] - x_vals[0]) / 2.0
        else:
            half_col_width = 0.5
        for col_center, attr_val in zip(x_vals, attr_vals):
            ax.axvspan(
                xmin=col_center - half_col_width,
                xmax=col_center + half_col_width,
                facecolor=cmap(cm_norm(attr_val)),
                edgecolor=None,
                alpha=alpha_overlay,
            )

    if vis_method == TimeseriesVisualizationMethod.overlay_individual:

        for chan in range(num_channels):
            plt_axis[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
            if channel_labels is not None:
                plt_axis[chan].set_ylabel(channel_labels[chan])

            _plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis[chan])

        # Adjust the figure being drawn on, not pyplot's "current" figure.
        plt_fig.subplots_adjust(hspace=0)

    elif vis_method == TimeseriesVisualizationMethod.overlay_combined:

        # Dark colors are better in this case
        cycler = plt.cycler("color", cm.Dark2.colors)
        plt_axis[0].set_prop_cycle(cycler)

        for chan in range(num_channels):
            label = channel_labels[chan] if channel_labels is not None else None
            plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)

        _plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0])

        # Without labels there is nothing to show and legend() would only warn.
        if channel_labels is not None:
            plt_axis[0].legend(loc="best")

    elif vis_method == TimeseriesVisualizationMethod.colored_graph:

        for chan in range(num_channels):
            # Split the curve into per-step segments so each can be coloured.
            points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
            lc.set_array(norm_attr[chan, :])
            plt_axis[chan].add_collection(lc)

            # Pad both limits away from the data by 20% of their magnitude.
            # Scaling by 1.2 directly would clip the curve whenever the
            # minimum is positive or the maximum is negative.
            chan_min = np.min(data[chan, :])
            chan_max = np.max(data[chan, :])
            plt_axis[chan].set_ylim(
                chan_min - 0.2 * abs(chan_min), chan_max + 0.2 * abs(chan_max)
            )
            if channel_labels is not None:
                plt_axis[chan].set_ylabel(channel_labels[chan])

        plt_fig.subplots_adjust(hspace=0)

    else:
        raise AssertionError("Invalid visualization method: {}".format(method))

    # Set the x-range on every panel that was drawn on, so the result does not
    # depend on pyplot's current axes or on the panels sharing an x-axis.
    for ax in plt_axis[:num_subplots]:
        ax.set_xlim(x_values[0], x_values[-1])

    if show_colorbar:
        axis_separator = make_axes_locatable(plt_axis[-1])
        colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.4)
        # Coloured lines are drawn opaque, so the colorbar must be opaque too.
        colorbar_alpha = alpha_overlay
        if vis_method == TimeseriesVisualizationMethod.colored_graph:
            colorbar_alpha = 1.0
        plt_fig.colorbar(
            cm.ScalarMappable(cm_norm, cmap),
            orientation="horizontal",
            cax=colorbar_axis,
            alpha=colorbar_alpha,
        )
    if title:
        plt_axis[0].set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis