@@ -109,6 +109,28 @@ def _normalize_attr(
109109 return _normalize_scale (attr_combined , threshold )
110110
111111
112+ def _create_default_plot (
113+ # pyre-fixme[2]: Parameter must be annotated.
114+ plt_fig_axis ,
115+ # pyre-fixme[2]: Parameter must be annotated.
116+ use_pyplot ,
117+ # pyre-fixme[2]: Parameter must be annotated.
118+ fig_size ,
119+ ** pyplot_kwargs : Any ,
120+ ) -> Tuple [Figure , Axes ]:
121+ # Create plot if figure, axis not provided
122+ if plt_fig_axis is not None :
123+ plt_fig , plt_axis = plt_fig_axis
124+ else :
125+ if use_pyplot :
126+ plt_fig , plt_axis = plt .subplots (figsize = fig_size , ** pyplot_kwargs )
127+ else :
128+ plt_fig = Figure (figsize = fig_size )
129+ plt_axis = plt_fig .subplots (** pyplot_kwargs )
130+ return plt_fig , plt_axis
131+ # Figure.subplots returns Axes or array of Axes
132+
133+
112134def _initialize_cmap_and_vmin_vmax (
113135 sign : str ,
114136) -> Tuple [Union [str , Colormap ], float , float ]:
@@ -338,16 +360,7 @@ def visualize_image_attr(
338360 >>> # Displays blended heat map visualization of computed attributions.
339361 >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
340362 """
341- # Create plot if figure, axis not provided
342- if plt_fig_axis is not None :
343- plt_fig , plt_axis = plt_fig_axis
344- else :
345- if use_pyplot :
346- plt_fig , plt_axis = plt .subplots (figsize = fig_size )
347- else :
348- plt_fig = Figure (figsize = fig_size )
349- plt_axis = plt_fig .subplots ()
350- # Figure.subplots returns Axes or array of Axes
363+ plt_fig , plt_axis = _create_default_plot (plt_fig_axis , use_pyplot , fig_size )
351364
352365 if original_image is not None :
353366 if np .max (original_image ) <= 1.0 :
@@ -362,8 +375,10 @@ def visualize_image_attr(
362375 )
363376
364377 # Remove ticks and tick labels from plot.
365- plt_axis .xaxis .set_ticks_position ("none" )
366- plt_axis .yaxis .set_ticks_position ("none" )
378+ if plt_axis .xaxis is not None :
379+ plt_axis .xaxis .set_ticks_position ("none" )
380+ if plt_axis .yaxis is not None :
381+ plt_axis .yaxis .set_ticks_position ("none" )
367382 plt_axis .set_yticklabels ([])
368383 plt_axis .set_xticklabels ([])
369384 plt_axis .grid (visible = False )
@@ -528,6 +543,161 @@ def visualize_image_attr_multiple(
528543 return plt_fig , plt_axis
529544
530545
546+ def _plot_attrs_as_axvspan (
547+ # pyre-fixme[2]: Parameter must be annotated.
548+ attr_vals ,
549+ # pyre-fixme[2]: Parameter must be annotated.
550+ x_vals ,
551+ # pyre-fixme[2]: Parameter must be annotated.
552+ ax ,
553+ # pyre-fixme[2]: Parameter must be annotated.
554+ x_values ,
555+ # pyre-fixme[2]: Parameter must be annotated.
556+ cmap ,
557+ # pyre-fixme[2]: Parameter must be annotated.
558+ cm_norm ,
559+ # pyre-fixme[2]: Parameter must be annotated.
560+ alpha_overlay ,
561+ ) -> None :
562+ # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
563+ half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
564+
565+ for icol , col_center in enumerate (x_vals ):
566+ left = col_center - half_col_width
567+ right = col_center + half_col_width
568+ ax .axvspan (
569+ xmin = left ,
570+ xmax = right ,
571+ # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
572+ facecolor = (cmap (cm_norm (attr_vals [icol ]))), # type: ignore
573+ edgecolor = None ,
574+ alpha = alpha_overlay ,
575+ )
576+
577+
578+ def _visualize_overlay_individual (
579+ # pyre-fixme[2]: Parameter must be annotated.
580+ num_channels ,
581+ # pyre-fixme[2]: Parameter must be annotated.
582+ plt_axis_list ,
583+ # pyre-fixme[2]: Parameter must be annotated.
584+ x_values ,
585+ # pyre-fixme[2]: Parameter must be annotated.
586+ data ,
587+ # pyre-fixme[2]: Parameter must be annotated.
588+ channel_labels ,
589+ # pyre-fixme[2]: Parameter must be annotated.
590+ norm_attr ,
591+ # pyre-fixme[2]: Parameter must be annotated.
592+ cmap ,
593+ # pyre-fixme[2]: Parameter must be annotated.
594+ cm_norm ,
595+ # pyre-fixme[2]: Parameter must be annotated.
596+ alpha_overlay ,
597+ # pyre-fixme[2]: Parameter must be annotated.
598+ ** kwargs : Any ,
599+ ) -> None :
600+ # helper method for visualize_timeseries_attr
601+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
602+ for chan in range (num_channels ):
603+ plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
604+ if channel_labels is not None :
605+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
606+
607+ _plot_attrs_as_axvspan (
608+ norm_attr [chan ],
609+ x_values ,
610+ plt_axis_list [chan ],
611+ x_values ,
612+ cmap ,
613+ cm_norm ,
614+ alpha_overlay ,
615+ )
616+
617+ plt .subplots_adjust (hspace = 0 )
618+ pass
619+
620+
621+ def _visualize_overlay_combined (
622+ # pyre-fixme[2]: Parameter must be annotated.
623+ num_channels ,
624+ # pyre-fixme[2]: Parameter must be annotated.
625+ plt_axis_list ,
626+ # pyre-fixme[2]: Parameter must be annotated.
627+ x_values ,
628+ # pyre-fixme[2]: Parameter must be annotated.
629+ data ,
630+ # pyre-fixme[2]: Parameter must be annotated.
631+ channel_labels ,
632+ # pyre-fixme[2]: Parameter must be annotated.
633+ norm_attr ,
634+ # pyre-fixme[2]: Parameter must be annotated.
635+ cmap ,
636+ # pyre-fixme[2]: Parameter must be annotated.
637+ cm_norm ,
638+ # pyre-fixme[2]: Parameter must be annotated.
639+ alpha_overlay ,
640+ ** kwargs : Any ,
641+ ) -> None :
642+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
643+
644+ cycler = plt .cycler ("color" , matplotlib .colormaps ["Dark2" ].colors ) # type: ignore
645+ plt_axis_list [0 ].set_prop_cycle (cycler )
646+
647+ for chan in range (num_channels ):
648+ label = channel_labels [chan ] if channel_labels else None
649+ plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
650+
651+ _plot_attrs_as_axvspan (
652+ norm_attr ,
653+ x_values ,
654+ plt_axis_list [0 ],
655+ x_values ,
656+ cmap ,
657+ cm_norm ,
658+ alpha_overlay ,
659+ )
660+
661+ plt_axis_list [0 ].legend (loc = "best" )
662+
663+
664+ def _visualize_colored_graph (
665+ # pyre-fixme[2]: Parameter must be annotated.
666+ num_channels ,
667+ # pyre-fixme[2]: Parameter must be annotated.
668+ plt_axis_list ,
669+ # pyre-fixme[2]: Parameter must be annotated.
670+ x_values ,
671+ # pyre-fixme[2]: Parameter must be annotated.
672+ data ,
673+ # pyre-fixme[2]: Parameter must be annotated.
674+ channel_labels ,
675+ # pyre-fixme[2]: Parameter must be annotated.
676+ norm_attr ,
677+ # pyre-fixme[2]: Parameter must be annotated.
678+ cmap ,
679+ # pyre-fixme[2]: Parameter must be annotated.
680+ cm_norm ,
681+ ** kwargs : Any ,
682+ ) -> None :
683+ # helper method for visualize_timeseries_attr
684+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
685+ for chan in range (num_channels ):
686+ points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
687+ segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
688+
689+ lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
690+ lc .set_array (norm_attr [chan , :])
691+ plt_axis_list [chan ].add_collection (lc )
692+ plt_axis_list [chan ].set_ylim (
693+ 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
694+ )
695+ if channel_labels is not None :
696+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
697+
698+ plt .subplots_adjust (hspace = 0 )
699+
700+
531701def visualize_timeseries_attr (
532702 attr : npt .NDArray ,
533703 data : npt .NDArray ,
@@ -686,8 +856,8 @@ def visualize_timeseries_attr(
686856
687857 num_subplots = num_channels
688858 if (
689- TimeseriesVisualizationMethod [method ]
690- == TimeseriesVisualizationMethod .overlay_combined
859+ TimeseriesVisualizationMethod [method ]. value
860+ == TimeseriesVisualizationMethod .overlay_combined . value
691861 ):
692862 num_subplots = 1
693863 attr = np .sum (attr , axis = 0 ) # Merge attributions across channels
@@ -700,17 +870,9 @@ def visualize_timeseries_attr(
700870 x_values = np .arange (timeseries_length )
701871
702872 # Create plot if figure, axis not provided
703- if plt_fig_axis is not None :
704- plt_fig , plt_axis = plt_fig_axis
705- else :
706- if use_pyplot :
707- plt_fig , plt_axis = plt .subplots ( # type: ignore
708- figsize = fig_size , nrows = num_subplots , sharex = True
709- )
710- else :
711- plt_fig = Figure (figsize = fig_size )
712- plt_axis = plt_fig .subplots (nrows = num_subplots , sharex = True ) # type: ignore
713- # Figure.subplots returns Axes or array of Axes
873+ plt_fig , plt_axis = _create_default_plot (
874+ plt_fig_axis , use_pyplot , fig_size , nrows = num_subplots , sharex = True
875+ )
714876
715877 if not isinstance (plt_axis , ndarray ):
716878 plt_axis_list = np .array ([plt_axis ])
@@ -720,91 +882,30 @@ def visualize_timeseries_attr(
720882 norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = None )
721883
722884 # Set default colormap and bounds based on sign.
723- if VisualizeSign [sign ] == VisualizeSign .all :
724- default_cmap : Union [str , LinearSegmentedColormap ] = (
725- LinearSegmentedColormap .from_list ("RdWhGn" , ["red" , "white" , "green" ])
726- )
727- vmin , vmax = - 1 , 1
728- elif VisualizeSign [sign ] == VisualizeSign .positive :
729- default_cmap = "Greens"
730- vmin , vmax = 0 , 1
731- elif VisualizeSign [sign ] == VisualizeSign .negative :
732- default_cmap = "Reds"
733- vmin , vmax = 0 , 1
734- elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
735- default_cmap = "Blues"
736- vmin , vmax = 0 , 1
737- else :
738- raise AssertionError ("Visualize Sign type is not valid." )
885+ default_cmap , vmin , vmax = _initialize_cmap_and_vmin_vmax (sign )
739886 cmap = cmap if cmap is not None else default_cmap
740887 cmap = cm .get_cmap (cmap ) # type: ignore
741888 cm_norm = colors .Normalize (vmin , vmax )
742889
743- # pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
744- # pyre-fixme[2]: Parameter must be annotated.
745- def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ) -> None :
746- # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
747- half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
748- for icol , col_center in enumerate (x_vals ):
749- left = col_center - half_col_width
750- right = col_center + half_col_width
751- ax .axvspan (
752- xmin = left ,
753- xmax = right ,
754- # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
755- facecolor = (cmap (cm_norm (attr_vals [icol ]))), # type: ignore
756- edgecolor = None ,
757- alpha = alpha_overlay ,
758- )
759-
760- if (
761- TimeseriesVisualizationMethod [method ]
762- == TimeseriesVisualizationMethod .overlay_individual
763- ):
764- for chan in range (num_channels ):
765- plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
766- if channel_labels is not None :
767- plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
768-
769- _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis_list [chan ])
770-
771- plt .subplots_adjust (hspace = 0 )
772-
773- elif (
774- TimeseriesVisualizationMethod [method ]
775- == TimeseriesVisualizationMethod .overlay_combined
776- ):
777- # Dark colors are better in this case
778- cycler = plt .cycler ("color" , matplotlib .colormaps ["Dark2" ]) # type: ignore
779- plt_axis_list [0 ].set_prop_cycle (cycler )
780-
781- for chan in range (num_channels ):
782- label = channel_labels [chan ] if channel_labels else None
783- plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
784-
785- _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis_list [0 ])
786-
787- plt_axis_list [0 ].legend (loc = "best" )
788-
789- elif (
790- TimeseriesVisualizationMethod [method ]
791- == TimeseriesVisualizationMethod .colored_graph
792- ):
793- for chan in range (num_channels ):
794- points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
795- segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
796-
797- lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
798- lc .set_array (norm_attr [chan , :])
799- plt_axis_list [chan ].add_collection (lc )
800- plt_axis_list [chan ].set_ylim (
801- 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
802- )
803- if channel_labels is not None :
804- plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
805-
806- plt .subplots_adjust (hspace = 0 )
807-
890+ visualization_methods : Dict [str , Callable [..., Union [None , AxesImage ]]] = {
891+ "overlay_individual" : _visualize_overlay_individual ,
892+ "overlay_combined" : _visualize_overlay_combined ,
893+ "colored_graph" : _visualize_colored_graph ,
894+ }
895+ kwargs = {
896+ "num_channels" : num_channels ,
897+ "plt_axis_list" : plt_axis_list ,
898+ "x_values" : x_values ,
899+ "data" : data ,
900+ "channel_labels" : channel_labels ,
901+ "norm_attr" : norm_attr ,
902+ "cmap" : cmap ,
903+ "cm_norm" : cm_norm ,
904+ "alpha_overlay" : alpha_overlay ,
905+ "pyplot_kwargs" : pyplot_kwargs ,
906+ }
907+ if method in visualization_methods :
908+ visualization_methods [method ](** kwargs )
808909 else :
809910 raise AssertionError ("Invalid visualization method: {}" .format (method ))
810911
0 commit comments