|
5 | 5 |
|
6 | 6 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union |
7 | 7 |
|
| 8 | +import matplotlib.colors as mcolors |
| 9 | + |
8 | 10 | import matplotlib.pyplot as plt |
9 | 11 | import numpy as np |
10 | 12 |
|
@@ -92,13 +94,28 @@ def plot_token_attr( |
92 | 94 | fig.set_size_inches( |
93 | 95 | max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8) |
94 | 96 | ) |
| 97 | + colors = [ |
| 98 | + "#00429d", |
| 99 | + "#4772b3", |
| 100 | + "#73a3ca", |
| 101 | + "#a4d6e1", |
| 102 | + "#ffffff", |
| 103 | + "#ffbdc3", |
| 104 | + "#f57789", |
| 105 | + "#d0365b", |
| 106 | + "#93003a", |
| 107 | + ] |
| 108 | + |
95 | 109 | im = ax.imshow( |
96 | 110 | data, |
97 | 111 | vmax=max_abs_attr_val, |
98 | 112 | vmin=-max_abs_attr_val, |
99 | | - cmap="RdYlGn", |
| 113 | + cmap=mcolors.LinearSegmentedColormap.from_list( |
| 114 | + name="colors", colors=colors |
| 115 | + ), |
100 | 116 | aspect="auto", |
101 | 117 | ) |
| 118 | + fig.set_facecolor("white") |
102 | 119 |
|
103 | 120 | # Create colorbar |
104 | 121 | cbar = fig.colorbar(im, ax=ax) # type: ignore |
@@ -154,22 +171,37 @@ def plot_seq_attr( |
154 | 171 |
|
155 | 172 | data = self.seq_attr.cpu().numpy() |
156 | 173 |
|
| 174 | + fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8)) |
| 175 | + |
157 | 176 | shortened_tokens = [ |
158 | 177 | shorten(t, width=50, placeholder="...") for t in self.input_tokens |
159 | 178 | ] |
160 | 179 | ax.set_xticks(range(data.shape[0]), labels=shortened_tokens) |
161 | 180 |
|
162 | 181 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) |
163 | 182 |
|
164 | | - plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") |
| 183 | + plt.setp( |
| 184 | + ax.get_xticklabels(), |
| 185 | + rotation=-30, |
| 186 | + ha="right", |
| 187 | + rotation_mode="anchor", |
| 188 | + ) |
| 189 | + |
| 190 | + fig.set_facecolor("white") |
165 | 191 |
|
166 | 192 | # pos bar |
167 | 193 | ax.bar( |
168 | | - range(data.shape[0]), [max(v, 0) for v in data], align="center", color="g" |
| 194 | + range(data.shape[0]), |
| 195 | + [max(v, 0) for v in data], |
| 196 | + align="center", |
| 197 | + color="#4772b3", |
169 | 198 | ) |
170 | 199 | # neg bar |
171 | 200 | ax.bar( |
172 | | - range(data.shape[0]), [min(v, 0) for v in data], align="center", color="r" |
| 201 | + range(data.shape[0]), |
| 202 | + [min(v, 0) for v in data], |
| 203 | + align="center", |
| 204 | + color="#d0365b", |
173 | 205 | ) |
174 | 206 |
|
175 | 207 | ax.set_ylabel("Sequence Attribuiton", rotation=90, va="bottom") |
|
0 commit comments