Skip to content

Commit 6965781

Browse files
csauperfacebook-github-bot
authored andcommitted
llm plot visual improvements
Summary: 1) size of seq plot adapts to data size 2) axes background is white for visibility 3) use red/white/blue scale for visibility + colorblind safety Differential Revision: D63039687
1 parent 7b80c5b commit 6965781

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
77

8+
import matplotlib.colors as mcolors
9+
810
import matplotlib.pyplot as plt
911
import numpy as np
1012

@@ -92,13 +94,28 @@ def plot_token_attr(
9294
fig.set_size_inches(
9395
max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8)
9496
)
97+
colors = [
98+
"#00429d",
99+
"#4772b3",
100+
"#73a3ca",
101+
"#a4d6e1",
102+
"#ffffff",
103+
"#ffbdc3",
104+
"#f57789",
105+
"#d0365b",
106+
"#93003a",
107+
]
108+
95109
im = ax.imshow(
96110
data,
97111
vmax=max_abs_attr_val,
98112
vmin=-max_abs_attr_val,
99-
cmap="RdYlGn",
113+
cmap=mcolors.LinearSegmentedColormap.from_list(
114+
name="colors", colors=colors
115+
),
100116
aspect="auto",
101117
)
118+
fig.set_facecolor("white")
102119

103120
# Create colorbar
104121
cbar = fig.colorbar(im, ax=ax) # type: ignore
@@ -154,22 +171,37 @@ def plot_seq_attr(
154171

155172
data = self.seq_attr.cpu().numpy()
156173

174+
fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))
175+
157176
shortened_tokens = [
158177
shorten(t, width=50, placeholder="...") for t in self.input_tokens
159178
]
160179
ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
161180

162181
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
163182

164-
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
183+
plt.setp(
184+
ax.get_xticklabels(),
185+
rotation=-30,
186+
ha="right",
187+
rotation_mode="anchor",
188+
)
189+
190+
fig.set_facecolor("white")
165191

166192
# pos bar
167193
ax.bar(
168-
range(data.shape[0]), [max(v, 0) for v in data], align="center", color="g"
194+
range(data.shape[0]),
195+
[max(v, 0) for v in data],
196+
align="center",
197+
color="#4772b3",
169198
)
170199
# neg bar
171200
ax.bar(
172-
range(data.shape[0]), [min(v, 0) for v in data], align="center", color="r"
201+
range(data.shape[0]),
202+
[min(v, 0) for v in data],
203+
align="center",
204+
color="#d0365b",
173205
)
174206

175207
ax.set_ylabel("Sequence Attribuiton", rotation=90, va="bottom")

0 commit comments

Comments
 (0)