|
2 | 2 |
|
3 | 3 | from typing import Callable, cast, Dict, List, Optional, Union
|
4 | 4 |
|
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import numpy as np |
| 7 | + |
5 | 8 | import torch
|
6 | 9 | from captum.attr._core.feature_ablation import FeatureAblation
|
7 | 10 | from captum.attr._core.kernel_shap import KernelShap
|
@@ -43,11 +46,108 @@ def __init__(
|
def seq_attr_dict(self):
    """
    Build a dictionary pairing every entry of ``self.input_tokens`` with the
    matching value of ``self.seq_attr``.

    The attribution tensor is moved to CPU and converted to a plain Python
    list before pairing, so the values in the result are Python numbers.
    If a token occurs more than once, the value of its last occurrence wins.
    """
    scores = self.seq_attr.cpu().tolist()
    return dict(zip(self.input_tokens, scores))
|
45 | 48 |
|
46 |
| - def plot_token_attr(self): |
47 |
| - pass |
def plot_token_attr(self, show=False):
    """
    Generate a matplotlib heatmap visualising the token attribution matrix.

    Rows are labelled with ``self.output_tokens`` and columns with
    ``self.input_tokens``. Every cell is annotated with its value rounded
    to four decimals. The color scale is symmetric around 0 so positive
    and negative attributions are visually distinguishable.

    Args:
        show (bool): whether to show the plot directly or return the
            figure and axis
            Default: False

    Returns:
        None when ``show`` is True; otherwise the tuple ``(fig, ax)`` with
        the matplotlib figure and axes that hold the heatmap.

    Raises:
        ValueError: if ``self.token_attr`` is None.
    """

    # NOTE(review): assumes token_attr may be left unset on this object;
    # confirm against __init__. Failing here gives a clearer error than an
    # AttributeError on ``None.cpu()``.
    if self.token_attr is None:
        raise ValueError("token_attr is None, there is no token attribution to plot")

    # detach first: numpy conversion is rejected for tensors requiring grad
    token_attr = self.token_attr.detach().cpu()

    # maximum absolute attribution value
    # used as the boundary of normalization
    # always keep 0 as the mid point to differentiate pos/neg attr
    max_abs_attr_val = token_attr.abs().max().item()
    if max_abs_attr_val == 0:
        # an all-zero matrix would give vmin == vmax and leave the colormap
        # nothing to span; a unit range keeps 0 at the mid point
        max_abs_attr_val = 1.0

    data = token_attr.numpy()
    n_rows, n_cols = data.shape

    fig, ax = plt.subplots()

    # Grow the figure with the matrix, never below matplotlib's default size
    fig.set_size_inches(max(n_cols * 1.3, 6.4), max(n_rows / 2.5, 4.8))

    # Plot the heatmap
    im = ax.imshow(
        data,
        vmax=max_abs_attr_val,
        vmin=-max_abs_attr_val,
        cmap="RdYlGn",
        aspect="auto",
    )

    # Create colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(n_cols), labels=self.input_tokens)
    ax.set_yticks(np.arange(n_rows), labels=self.output_tokens)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    # Create a `Text` for each "pixel".
    # Cells near either end of the color scale get white text,
    # cells in the middle of the scale get black text.
    for (i, j), val in np.ndenumerate(data):
        color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
        ax.text(
            j,
            i,
            f"{val:.4f}",
            horizontalalignment="center",
            verticalalignment="center",
            color=color,
        )

    if show:
        plt.show()
        return None
    return fig, ax
| 115 | + |
def plot_seq_attr(self, show=False):
    """
    Generate a matplotlib bar chart visualising the sequence attribution.

    One bar is drawn per entry of ``self.seq_attr`` and labelled with the
    matching entry of ``self.input_tokens``. Positive values are drawn in
    green and negative values in red.

    Args:
        show (bool): whether to show the plot directly or return the
            figure and axis
            Default: False

    Returns:
        None when ``show`` is True; otherwise the tuple ``(fig, ax)`` with
        the matplotlib figure and axes that hold the bar chart.
    """

    # detach first: numpy conversion is rejected for tensors requiring grad
    data = self.seq_attr.detach().cpu().numpy()
    positions = np.arange(data.shape[0])

    fig, ax = plt.subplots()

    ax.set_xticks(positions, labels=self.input_tokens)

    # Put the token labels on top of the plot
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    # Two passes so each sign gets its own color; the other sign is
    # flattened to a zero-height bar.
    # pos bar
    ax.bar(positions, np.maximum(data, 0), align="center", color="g")
    # neg bar
    ax.bar(positions, np.minimum(data, 0), align="center", color="r")

    ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom")

    if show:
        plt.show()
        return None
    return fig, ax
51 | 151 |
|
52 | 152 |
|
53 | 153 | class LLMAttribution(Attribution):
|
|
0 commit comments