diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 73aca0d440..482ce03fb6 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -5,6 +5,8 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +import matplotlib.colors as mcolors + import matplotlib.pyplot as plt import numpy as np @@ -92,13 +94,28 @@ def plot_token_attr( fig.set_size_inches( max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8) ) + colors = [ + "#93003a", + "#d0365b", + "#f57789", + "#ffbdc3", + "#ffffff", + "#a4d6e1", + "#73a3ca", + "#4772b3", + "#00429d", + ] + im = ax.imshow( data, vmax=max_abs_attr_val, vmin=-max_abs_attr_val, - cmap="RdYlGn", + cmap=mcolors.LinearSegmentedColormap.from_list( + name="colors", colors=colors + ), aspect="auto", ) + fig.set_facecolor("white") # Create colorbar cbar = fig.colorbar(im, ax=ax) # type: ignore @@ -154,6 +171,8 @@ def plot_seq_attr( data = self.seq_attr.cpu().numpy() + fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8)) + shortened_tokens = [ shorten(t, width=50, placeholder="...") for t in self.input_tokens ] @@ -161,15 +180,28 @@ def plot_seq_attr( ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) - plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") + plt.setp( + ax.get_xticklabels(), + rotation=-30, + ha="right", + rotation_mode="anchor", + ) + + fig.set_facecolor("white") # pos bar ax.bar( - range(data.shape[0]), [max(v, 0) for v in data], align="center", color="g" + range(data.shape[0]), + [max(v, 0) for v in data], + align="center", + color="#4772b3", ) # neg bar ax.bar( - range(data.shape[0]), [min(v, 0) for v in data], align="center", color="r" + range(data.shape[0]), + [min(v, 0) for v in data], + align="center", + color="#d0365b", ) ax.set_ylabel("Sequence Attribuiton", rotation=90, va="bottom")