Skip to content

Commit 5eb4fb5

Browse files
jjunchofacebook-github-bot
authored andcommitted
'visualize_image_attr' is too complex (#1372)
Summary: Pull Request resolved: #1372 This diff addresses the C901 in visualization.py by breaking down the method Reviewed By: craymichael Differential Revision: D64404009 fbshipit-source-id: 2578c8ebcd620b2c73a838e4caea8ccd968af186
1 parent c68334b commit 5eb4fb5

File tree

1 file changed

+145
-77
lines changed

1 file changed

+145
-77
lines changed

captum/attr/_utils/visualization.py

Lines changed: 145 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44
import warnings
55
from enum import Enum
6-
from typing import Any, Iterable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
77

88
import matplotlib
99

@@ -13,6 +13,7 @@
1313
from matplotlib.collections import LineCollection
1414
from matplotlib.colors import Colormap, LinearSegmentedColormap
1515
from matplotlib.figure import Figure
16+
from matplotlib.image import AxesImage
1617
from mpl_toolkits.axes_grid1 import make_axes_locatable
1718
from numpy import ndarray
1819

@@ -88,26 +89,129 @@ def _normalize_attr(
8889
attr_combined = np.sum(attr, axis=reduction_axis)
8990

9091
# Choose appropriate signed values and rescale, removing given outlier percentage.
91-
if VisualizeSign[sign] == VisualizeSign.all:
92+
if VisualizeSign[sign].value == VisualizeSign.all.value:
9293
threshold = _cumulative_sum_threshold(
9394
np.abs(attr_combined), 100.0 - outlier_perc
9495
)
95-
elif VisualizeSign[sign] == VisualizeSign.positive:
96+
elif VisualizeSign[sign].value == VisualizeSign.positive.value:
9697
attr_combined = (attr_combined > 0) * attr_combined
9798
threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc)
98-
elif VisualizeSign[sign] == VisualizeSign.negative:
99+
elif VisualizeSign[sign].value == VisualizeSign.negative.value:
99100
attr_combined = (attr_combined < 0) * attr_combined
100101
threshold = -1 * _cumulative_sum_threshold(
101102
np.abs(attr_combined), 100.0 - outlier_perc
102103
)
103-
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
104+
elif VisualizeSign[sign].value == VisualizeSign.absolute_value.value:
104105
attr_combined = np.abs(attr_combined)
105106
threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc)
106107
else:
107108
raise AssertionError("Visualize Sign type is not valid.")
108109
return _normalize_scale(attr_combined, threshold)
109110

110111

112+
def _initialize_cmap_and_vmin_vmax(
113+
sign: str,
114+
) -> Tuple[Union[str, Colormap], float, float]:
115+
if VisualizeSign[sign].value == VisualizeSign.all.value:
116+
default_cmap: Union[str, LinearSegmentedColormap] = (
117+
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
118+
)
119+
vmin, vmax = -1, 1
120+
elif VisualizeSign[sign].value == VisualizeSign.positive.value:
121+
default_cmap = "Greens"
122+
vmin, vmax = 0, 1
123+
elif VisualizeSign[sign].value == VisualizeSign.negative.value:
124+
default_cmap = "Reds"
125+
vmin, vmax = 0, 1
126+
elif VisualizeSign[sign].value == VisualizeSign.absolute_value.value:
127+
default_cmap = "Blues"
128+
vmin, vmax = 0, 1
129+
else:
130+
raise AssertionError("Visualize Sign type is not valid.")
131+
return default_cmap, vmin, vmax
132+
133+
134+
def _visualize_original_image(
135+
plt_axis: Axes,
136+
original_image: Optional[ndarray],
137+
**kwargs: Any,
138+
) -> None:
139+
assert (
140+
original_image is not None
141+
), "Original image expected for original_image method."
142+
if len(original_image.shape) > 2 and original_image.shape[2] == 1:
143+
original_image = np.squeeze(original_image, axis=2)
144+
plt_axis.imshow(original_image)
145+
146+
147+
def _visualize_heat_map(
148+
plt_axis: Axes,
149+
norm_attr: ndarray,
150+
cmap: Union[str, Colormap],
151+
vmin: float,
152+
vmax: float,
153+
**kwargs: Any,
154+
) -> AxesImage:
155+
heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax)
156+
return heat_map
157+
158+
159+
def _visualize_blended_heat_map(
160+
plt_axis: Axes,
161+
original_image: ndarray,
162+
norm_attr: ndarray,
163+
cmap: Union[str, Colormap],
164+
vmin: float,
165+
vmax: float,
166+
alpha_overlay: float,
167+
**kwargs: Any,
168+
) -> AxesImage:
169+
assert (
170+
original_image is not None
171+
), "Original Image expected for blended_heat_map method."
172+
plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray")
173+
heat_map = plt_axis.imshow(
174+
norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay
175+
)
176+
return heat_map
177+
178+
179+
def _visualize_masked_image(
180+
plt_axis: Axes,
181+
sign: str,
182+
original_image: ndarray,
183+
norm_attr: ndarray,
184+
**kwargs: Any,
185+
) -> None:
186+
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
187+
"Cannot display masked image with both positive and negative "
188+
"attributions, choose a different sign option."
189+
)
190+
plt_axis.imshow(_prepare_image(original_image * np.expand_dims(norm_attr, 2)))
191+
192+
193+
def _visualize_alpha_scaling(
194+
plt_axis: Axes,
195+
sign: str,
196+
original_image: ndarray,
197+
norm_attr: ndarray,
198+
**kwargs: Any,
199+
) -> None:
200+
assert VisualizeSign[sign].value != VisualizeSign.all.value, (
201+
"Cannot display alpha scaling with both positive and negative "
202+
"attributions, choose a different sign option."
203+
)
204+
plt_axis.imshow(
205+
np.concatenate(
206+
[
207+
original_image,
208+
_prepare_image(np.expand_dims(norm_attr, 2) * 255),
209+
],
210+
axis=2,
211+
)
212+
)
213+
214+
111215
def visualize_image_attr(
112216
attr: ndarray,
113217
original_image: Optional[ndarray] = None,
@@ -242,15 +346,18 @@ def visualize_image_attr(
242346
plt_fig, plt_axis = plt.subplots(figsize=fig_size)
243347
else:
244348
plt_fig = Figure(figsize=fig_size)
245-
plt_axis = plt_fig.subplots() # type: ignore
349+
plt_axis = plt_fig.subplots()
246350
# Figure.subplots returns Axes or array of Axes
247351

248352
if original_image is not None:
249353
if np.max(original_image) <= 1.0:
250354
original_image = _prepare_image(original_image * 255)
251-
elif ImageVisualizationMethod[method] != ImageVisualizationMethod.heat_map:
355+
elif (
356+
ImageVisualizationMethod[method].value
357+
!= ImageVisualizationMethod.heat_map.value
358+
):
252359
raise ValueError(
253-
"Original Image must be provided for"
360+
"Original Image must be provided for "
254361
"any visualization other than heatmap."
255362
)
256363

@@ -261,76 +368,37 @@ def visualize_image_attr(
261368
plt_axis.set_xticklabels([])
262369
plt_axis.grid(visible=False)
263370

264-
heat_map = None
265-
# Show original image
266-
if ImageVisualizationMethod[method] == ImageVisualizationMethod.original_image:
267-
assert (
268-
original_image is not None
269-
), "Original image expected for original_image method."
270-
if len(original_image.shape) > 2 and original_image.shape[2] == 1:
271-
original_image = np.squeeze(original_image, axis=2)
272-
plt_axis.imshow(original_image)
273-
else:
274-
# Choose appropriate signed attributions and normalize.
275-
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=2)
371+
heat_map: Optional[AxesImage] = None
276372

277-
# Set default colormap and bounds based on sign.
278-
if VisualizeSign[sign] == VisualizeSign.all:
279-
default_cmap: Union[str, LinearSegmentedColormap] = (
280-
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
281-
)
282-
vmin, vmax = -1, 1
283-
elif VisualizeSign[sign] == VisualizeSign.positive:
284-
default_cmap = "Greens"
285-
vmin, vmax = 0, 1
286-
elif VisualizeSign[sign] == VisualizeSign.negative:
287-
default_cmap = "Reds"
288-
vmin, vmax = 0, 1
289-
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
290-
default_cmap = "Blues"
291-
vmin, vmax = 0, 1
292-
else:
293-
raise AssertionError("Visualize Sign type is not valid.")
294-
cmap = cmap if cmap is not None else default_cmap
295-
296-
# Show appropriate image visualization.
297-
if ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map:
298-
heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax)
299-
elif (
300-
ImageVisualizationMethod[method]
301-
== ImageVisualizationMethod.blended_heat_map
302-
):
303-
assert (
304-
original_image is not None
305-
), "Original Image expected for blended_heat_map method."
306-
plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray")
307-
heat_map = plt_axis.imshow(
308-
norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay
309-
)
310-
elif ImageVisualizationMethod[method] == ImageVisualizationMethod.masked_image:
311-
assert VisualizeSign[sign] != VisualizeSign.all, (
312-
"Cannot display masked image with both positive and negative "
313-
"attributions, choose a different sign option."
314-
)
315-
plt_axis.imshow(
316-
_prepare_image(original_image * np.expand_dims(norm_attr, 2))
317-
)
318-
elif ImageVisualizationMethod[method] == ImageVisualizationMethod.alpha_scaling:
319-
assert VisualizeSign[sign] != VisualizeSign.all, (
320-
"Cannot display alpha scaling with both positive and negative "
321-
"attributions, choose a different sign option."
322-
)
323-
plt_axis.imshow(
324-
np.concatenate(
325-
[
326-
original_image,
327-
_prepare_image(np.expand_dims(norm_attr, 2) * 255),
328-
],
329-
axis=2,
330-
)
331-
)
332-
else:
333-
raise AssertionError("Visualize Method type is not valid.")
373+
# pyre-ignore[33]: prohibited Any
374+
visualization_methods: Dict[str, Callable[..., Any]] = {
375+
"heat_map": _visualize_heat_map,
376+
"blended_heat_map": _visualize_blended_heat_map,
377+
"masked_image": _visualize_masked_image,
378+
"alpha_scaling": _visualize_alpha_scaling,
379+
"original_image": _visualize_original_image,
380+
}
381+
# Choose appropriate signed attributions and normalize.
382+
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=2)
383+
384+
# Set default colormap and bounds based on sign.
385+
default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign)
386+
cmap = cmap if cmap is not None else default_cmap
387+
388+
kwargs = {
389+
"plt_axis": plt_axis,
390+
"original_image": original_image,
391+
"sign": sign,
392+
"cmap": cmap,
393+
"alpha_overlay": alpha_overlay,
394+
"vmin": vmin,
395+
"vmax": vmax,
396+
"norm_attr": norm_attr,
397+
}
398+
if method in visualization_methods:
399+
heat_map = visualization_methods[method](**kwargs)
400+
else:
401+
raise AssertionError("Visualize Method type is not valid.")
334402

335403
# Add colorbar. If given method is not a heatmap and no colormap is relevant,
336404
# then a colormap axis is created and hidden. This is necessary for appropriate

0 commit comments

Comments
 (0)