Skip to content

Commit 8743239

Browse files
jjunchofacebook-github-bot
authored andcommitted
'visualize_timeseries_attr' is too complex (#1384)
Summary: This diff addresses the C901 in visualization.py by breaking down the method Differential Revision: D64513163
1 parent 8dc48a5 commit 8743239

File tree

1 file changed

+207
-106
lines changed

1 file changed

+207
-106
lines changed

captum/attr/_utils/visualization.py

Lines changed: 207 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,28 @@ def _normalize_attr(
106106
return _normalize_scale(attr_combined, threshold)
107107

108108

109+
def _create_default_plot(
110+
# pyre-fixme[2]: Parameter must be annotated.
111+
plt_fig_axis,
112+
# pyre-fixme[2]: Parameter must be annotated.
113+
use_pyplot,
114+
# pyre-fixme[2]: Parameter must be annotated.
115+
fig_size,
116+
**pyplot_kwargs: Any,
117+
) -> Tuple[Figure, Axes]:
118+
# Create plot if figure, axis not provided
119+
if plt_fig_axis is not None:
120+
plt_fig, plt_axis = plt_fig_axis
121+
else:
122+
if use_pyplot:
123+
plt_fig, plt_axis = plt.subplots(figsize=fig_size, **pyplot_kwargs)
124+
else:
125+
plt_fig = Figure(figsize=fig_size)
126+
plt_axis = plt_fig.subplots(**pyplot_kwargs)
127+
return plt_fig, plt_axis
128+
# Figure.subplots returns Axes or array of Axes
129+
130+
109131
def _initialize_cmap_and_vmin_vmax(
110132
sign: str,
111133
) -> Tuple[Union[str, Colormap], float, float]:
@@ -335,16 +357,7 @@ def visualize_image_attr(
335357
>>> # Displays blended heat map visualization of computed attributions.
336358
>>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
337359
"""
338-
# Create plot if figure, axis not provided
339-
if plt_fig_axis is not None:
340-
plt_fig, plt_axis = plt_fig_axis
341-
else:
342-
if use_pyplot:
343-
plt_fig, plt_axis = plt.subplots(figsize=fig_size)
344-
else:
345-
plt_fig = Figure(figsize=fig_size)
346-
plt_axis = plt_fig.subplots()
347-
# Figure.subplots returns Axes or array of Axes
360+
plt_fig, plt_axis = _create_default_plot(plt_fig_axis, use_pyplot, fig_size)
348361

349362
if original_image is not None:
350363
if np.max(original_image) <= 1.0:
@@ -359,8 +372,10 @@ def visualize_image_attr(
359372
)
360373

361374
# Remove ticks and tick labels from plot.
362-
plt_axis.xaxis.set_ticks_position("none")
363-
plt_axis.yaxis.set_ticks_position("none")
375+
if plt_axis.xaxis is not None:
376+
plt_axis.xaxis.set_ticks_position("none")
377+
if plt_axis.yaxis is not None:
378+
plt_axis.yaxis.set_ticks_position("none")
364379
plt_axis.set_yticklabels([])
365380
plt_axis.set_xticklabels([])
366381
plt_axis.grid(visible=False)
@@ -525,6 +540,161 @@ def visualize_image_attr_multiple(
525540
return plt_fig, plt_axis
526541

527542

543+
def _plot_attrs_as_axvspan(
544+
# pyre-fixme[2]: Parameter must be annotated.
545+
attr_vals,
546+
# pyre-fixme[2]: Parameter must be annotated.
547+
x_vals,
548+
# pyre-fixme[2]: Parameter must be annotated.
549+
ax,
550+
# pyre-fixme[2]: Parameter must be annotated.
551+
x_values,
552+
# pyre-fixme[2]: Parameter must be annotated.
553+
cmap,
554+
# pyre-fixme[2]: Parameter must be annotated.
555+
cm_norm,
556+
# pyre-fixme[2]: Parameter must be annotated.
557+
alpha_overlay,
558+
) -> None:
559+
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
560+
half_col_width = (x_values[1] - x_values[0]) / 2.0
561+
562+
for icol, col_center in enumerate(x_vals):
563+
left = col_center - half_col_width
564+
right = col_center + half_col_width
565+
ax.axvspan(
566+
xmin=left,
567+
xmax=right,
568+
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
569+
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
570+
edgecolor=None,
571+
alpha=alpha_overlay,
572+
)
573+
574+
575+
def _visualize_overlay_individual(
576+
# pyre-fixme[2]: Parameter must be annotated.
577+
num_channels,
578+
# pyre-fixme[2]: Parameter must be annotated.
579+
plt_axis_list,
580+
# pyre-fixme[2]: Parameter must be annotated.
581+
x_values,
582+
# pyre-fixme[2]: Parameter must be annotated.
583+
data,
584+
# pyre-fixme[2]: Parameter must be annotated.
585+
channel_labels,
586+
# pyre-fixme[2]: Parameter must be annotated.
587+
norm_attr,
588+
# pyre-fixme[2]: Parameter must be annotated.
589+
cmap,
590+
# pyre-fixme[2]: Parameter must be annotated.
591+
cm_norm,
592+
# pyre-fixme[2]: Parameter must be annotated.
593+
alpha_overlay,
594+
# pyre-fixme[2]: Parameter must be annotated.
595+
**kwargs: Any,
596+
) -> None:
597+
# helper method for visualize_timeseries_attr
598+
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
599+
for chan in range(num_channels):
600+
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
601+
if channel_labels is not None:
602+
plt_axis_list[chan].set_ylabel(channel_labels[chan])
603+
604+
_plot_attrs_as_axvspan(
605+
norm_attr[chan],
606+
x_values,
607+
plt_axis_list[chan],
608+
x_values,
609+
cmap,
610+
cm_norm,
611+
alpha_overlay,
612+
)
613+
614+
plt.subplots_adjust(hspace=0)
615+
pass
616+
617+
618+
def _visualize_overlay_combined(
619+
# pyre-fixme[2]: Parameter must be annotated.
620+
num_channels,
621+
# pyre-fixme[2]: Parameter must be annotated.
622+
plt_axis_list,
623+
# pyre-fixme[2]: Parameter must be annotated.
624+
x_values,
625+
# pyre-fixme[2]: Parameter must be annotated.
626+
data,
627+
# pyre-fixme[2]: Parameter must be annotated.
628+
channel_labels,
629+
# pyre-fixme[2]: Parameter must be annotated.
630+
norm_attr,
631+
# pyre-fixme[2]: Parameter must be annotated.
632+
cmap,
633+
# pyre-fixme[2]: Parameter must be annotated.
634+
cm_norm,
635+
# pyre-fixme[2]: Parameter must be annotated.
636+
alpha_overlay,
637+
**kwargs: Any,
638+
) -> None:
639+
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
640+
641+
cycler = plt.cycler("color", matplotlib.colormaps["Dark2"].colors) # type: ignore
642+
plt_axis_list[0].set_prop_cycle(cycler)
643+
644+
for chan in range(num_channels):
645+
label = channel_labels[chan] if channel_labels else None
646+
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
647+
648+
_plot_attrs_as_axvspan(
649+
norm_attr,
650+
x_values,
651+
plt_axis_list[0],
652+
x_values,
653+
cmap,
654+
cm_norm,
655+
alpha_overlay,
656+
)
657+
658+
plt_axis_list[0].legend(loc="best")
659+
660+
661+
def _visualize_colored_graph(
662+
# pyre-fixme[2]: Parameter must be annotated.
663+
num_channels,
664+
# pyre-fixme[2]: Parameter must be annotated.
665+
plt_axis_list,
666+
# pyre-fixme[2]: Parameter must be annotated.
667+
x_values,
668+
# pyre-fixme[2]: Parameter must be annotated.
669+
data,
670+
# pyre-fixme[2]: Parameter must be annotated.
671+
channel_labels,
672+
# pyre-fixme[2]: Parameter must be annotated.
673+
norm_attr,
674+
# pyre-fixme[2]: Parameter must be annotated.
675+
cmap,
676+
# pyre-fixme[2]: Parameter must be annotated.
677+
cm_norm,
678+
**kwargs: Any,
679+
) -> None:
680+
# helper method for visualize_timeseries_attr
681+
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
682+
for chan in range(num_channels):
683+
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
684+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
685+
686+
lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
687+
lc.set_array(norm_attr[chan, :])
688+
plt_axis_list[chan].add_collection(lc)
689+
plt_axis_list[chan].set_ylim(
690+
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
691+
)
692+
if channel_labels is not None:
693+
plt_axis_list[chan].set_ylabel(channel_labels[chan])
694+
695+
plt.subplots_adjust(hspace=0)
696+
697+
528698
def visualize_timeseries_attr(
529699
attr: ndarray,
530700
data: ndarray,
@@ -683,8 +853,8 @@ def visualize_timeseries_attr(
683853

684854
num_subplots = num_channels
685855
if (
686-
TimeseriesVisualizationMethod[method]
687-
== TimeseriesVisualizationMethod.overlay_combined
856+
TimeseriesVisualizationMethod[method].value
857+
== TimeseriesVisualizationMethod.overlay_combined.value
688858
):
689859
num_subplots = 1
690860
attr = np.sum(attr, axis=0) # Merge attributions across channels
@@ -697,17 +867,9 @@ def visualize_timeseries_attr(
697867
x_values = np.arange(timeseries_length)
698868

699869
# Create plot if figure, axis not provided
700-
if plt_fig_axis is not None:
701-
plt_fig, plt_axis = plt_fig_axis
702-
else:
703-
if use_pyplot:
704-
plt_fig, plt_axis = plt.subplots( # type: ignore
705-
figsize=fig_size, nrows=num_subplots, sharex=True
706-
)
707-
else:
708-
plt_fig = Figure(figsize=fig_size)
709-
plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) # type: ignore
710-
# Figure.subplots returns Axes or array of Axes
870+
plt_fig, plt_axis = _create_default_plot(
871+
plt_fig_axis, use_pyplot, fig_size, nrows=num_subplots, sharex=True
872+
)
711873

712874
if not isinstance(plt_axis, ndarray):
713875
plt_axis_list = np.array([plt_axis])
@@ -717,91 +879,30 @@ def visualize_timeseries_attr(
717879
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)
718880

719881
# Set default colormap and bounds based on sign.
720-
if VisualizeSign[sign] == VisualizeSign.all:
721-
default_cmap: Union[str, LinearSegmentedColormap] = (
722-
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
723-
)
724-
vmin, vmax = -1, 1
725-
elif VisualizeSign[sign] == VisualizeSign.positive:
726-
default_cmap = "Greens"
727-
vmin, vmax = 0, 1
728-
elif VisualizeSign[sign] == VisualizeSign.negative:
729-
default_cmap = "Reds"
730-
vmin, vmax = 0, 1
731-
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
732-
default_cmap = "Blues"
733-
vmin, vmax = 0, 1
734-
else:
735-
raise AssertionError("Visualize Sign type is not valid.")
882+
default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign)
736883
cmap = cmap if cmap is not None else default_cmap
737884
cmap = cm.get_cmap(cmap) # type: ignore
738885
cm_norm = colors.Normalize(vmin, vmax)
739886

740-
# pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
741-
# pyre-fixme[2]: Parameter must be annotated.
742-
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax) -> None:
743-
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
744-
half_col_width = (x_values[1] - x_values[0]) / 2.0
745-
for icol, col_center in enumerate(x_vals):
746-
left = col_center - half_col_width
747-
right = col_center + half_col_width
748-
ax.axvspan(
749-
xmin=left,
750-
xmax=right,
751-
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
752-
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
753-
edgecolor=None,
754-
alpha=alpha_overlay,
755-
)
756-
757-
if (
758-
TimeseriesVisualizationMethod[method]
759-
== TimeseriesVisualizationMethod.overlay_individual
760-
):
761-
for chan in range(num_channels):
762-
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
763-
if channel_labels is not None:
764-
plt_axis_list[chan].set_ylabel(channel_labels[chan])
765-
766-
_plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis_list[chan])
767-
768-
plt.subplots_adjust(hspace=0)
769-
770-
elif (
771-
TimeseriesVisualizationMethod[method]
772-
== TimeseriesVisualizationMethod.overlay_combined
773-
):
774-
# Dark colors are better in this case
775-
cycler = plt.cycler("color", matplotlib.colormaps["Dark2"]) # type: ignore
776-
plt_axis_list[0].set_prop_cycle(cycler)
777-
778-
for chan in range(num_channels):
779-
label = channel_labels[chan] if channel_labels else None
780-
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
781-
782-
_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis_list[0])
783-
784-
plt_axis_list[0].legend(loc="best")
785-
786-
elif (
787-
TimeseriesVisualizationMethod[method]
788-
== TimeseriesVisualizationMethod.colored_graph
789-
):
790-
for chan in range(num_channels):
791-
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
792-
segments = np.concatenate([points[:-1], points[1:]], axis=1)
793-
794-
lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
795-
lc.set_array(norm_attr[chan, :])
796-
plt_axis_list[chan].add_collection(lc)
797-
plt_axis_list[chan].set_ylim(
798-
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
799-
)
800-
if channel_labels is not None:
801-
plt_axis_list[chan].set_ylabel(channel_labels[chan])
802-
803-
plt.subplots_adjust(hspace=0)
804-
887+
visualization_methods: Dict[str, Callable[..., Union[None, AxesImage]]] = {
888+
"overlay_individual": _visualize_overlay_individual,
889+
"overlay_combined": _visualize_overlay_combined,
890+
"colored_graph": _visualize_colored_graph,
891+
}
892+
kwargs = {
893+
"num_channels": num_channels,
894+
"plt_axis_list": plt_axis_list,
895+
"x_values": x_values,
896+
"data": data,
897+
"channel_labels": channel_labels,
898+
"norm_attr": norm_attr,
899+
"cmap": cmap,
900+
"cm_norm": cm_norm,
901+
"alpha_overlay": alpha_overlay,
902+
"pyplot_kwargs": pyplot_kwargs,
903+
}
904+
if method in visualization_methods:
905+
visualization_methods[method](**kwargs)
805906
else:
806907
raise AssertionError("Invalid visualization method: {}".format(method))
807908

0 commit comments

Comments
 (0)