Skip to content

Commit 2efc105

Browse files
aobo-y authored and facebook-github-bot committed
Implement viz utils in LLMAttributionResult (#1213)
Summary: as title Pull Request resolved: #1213 Reviewed By: vivekmig Differential Revision: D51628169 Pulled By: aobo-y fbshipit-source-id: 36d20c58113f99ed915aed4019b915fb6a344368
1 parent eaf1f4e commit 2efc105

File tree

1 file changed

+104
-4
lines changed

1 file changed

+104
-4
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
from typing import Callable, cast, Dict, List, Optional, Union
44

5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
58
import torch
69
from captum.attr._core.feature_ablation import FeatureAblation
710
from captum.attr._core.kernel_shap import KernelShap
@@ -43,11 +46,108 @@ def __init__(
4346
def seq_attr_dict(self):
    """Map each input token to its sequence-level attribution score.

    Returns:
        dict: keys are the input tokens, values are the corresponding
        attribution values from ``self.seq_attr`` (moved to CPU).
    """
    attr_values = self.seq_attr.cpu().tolist()
    return dict(zip(self.input_tokens, attr_values))
4548

46-
def plot_token_attr(self):
47-
pass
49+
def plot_token_attr(self, show=False):
    """
    Generate a matplotlib plot for visualizing the attribution
    of the output tokens: a heatmap of output tokens (rows) against
    input tokens (columns), annotated with the attribution values.

    Args:
        show (bool): whether to show the plot directly or return
                the figure and axis
                Default: False

    Returns:
        None if ``show`` is True; otherwise a tuple of the matplotlib
        ``(figure, axis)`` so the caller can further customize or save it.
    """

    token_attr = self.token_attr.cpu()

    # maximum absolute attribution value
    # used as the boundary of normalization
    # always keep 0 as the mid point to differentiate pos/neg attr
    max_abs_attr_val = token_attr.abs().max().item()

    fig, ax = plt.subplots()

    # Plot the heatmap
    data = token_attr.numpy()

    # Scale the figure with the token counts, but never below the
    # matplotlib default size (6.4 x 4.8 inches).
    fig.set_size_inches(
        max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8)
    )
    im = ax.imshow(
        data,
        vmax=max_abs_attr_val,
        vmin=-max_abs_attr_val,
        cmap="RdYlGn",
        aspect="auto",
    )

    # Create colorbar
    cbar = fig.colorbar(im, ax=ax)
    # NOTE: fixed typo in the label ("Attribuiton" -> "Attribution")
    cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=self.input_tokens)
    ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data:
    # white text on saturated (dark) cells, black on mid-range cells.
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            val = data[i, j]
            color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
            im.axes.text(
                j,
                i,
                "%.4f" % val,
                horizontalalignment="center",
                verticalalignment="center",
                color=color,
            )

    if show:
        plt.show()
    else:
        return fig, ax
115+
116+
def plot_seq_attr(self, show=False):
    """
    Generate a matplotlib plot for visualizing the attribution
    of the output sequence: a bar chart of per-input-token attribution,
    with positive values in green and negative values in red.

    Args:
        show (bool): whether to show the plot directly or return
                the figure and axis
                Default: False

    Returns:
        None if ``show`` is True; otherwise a tuple of the matplotlib
        ``(figure, axis)`` so the caller can further customize or save it.
    """

    fig, ax = plt.subplots()

    data = self.seq_attr.cpu().numpy()

    ax.set_xticks(range(data.shape[0]), labels=self.input_tokens)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    # pos bar
    ax.bar(
        range(data.shape[0]), [max(v, 0) for v in data], align="center", color="g"
    )
    # neg bar
    ax.bar(
        range(data.shape[0]), [min(v, 0) for v in data], align="center", color="r"
    )

    # NOTE: fixed typo in the label ("Attribuiton" -> "Attribution")
    ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom")

    if show:
        plt.show()
    else:
        return fig, ax
51151

52152

53153
class LLMAttribution(Attribution):

0 commit comments

Comments
 (0)