@@ -106,6 +106,28 @@ def _normalize_attr(
106106 return _normalize_scale (attr_combined , threshold )
107107
108108
109+ def _create_default_plot (
110+ # pyre-fixme[2]: Parameter must be annotated.
111+ plt_fig_axis ,
112+ # pyre-fixme[2]: Parameter must be annotated.
113+ use_pyplot ,
114+ # pyre-fixme[2]: Parameter must be annotated.
115+ fig_size ,
116+ ** pyplot_kwargs : Any ,
117+ ) -> Tuple [Figure , Axes ]:
118+ # Create plot if figure, axis not provided
119+ if plt_fig_axis is not None :
120+ plt_fig , plt_axis = plt_fig_axis
121+ else :
122+ if use_pyplot :
123+ plt_fig , plt_axis = plt .subplots (figsize = fig_size , ** pyplot_kwargs )
124+ else :
125+ plt_fig = Figure (figsize = fig_size )
126+ plt_axis = plt_fig .subplots (** pyplot_kwargs )
127+ return plt_fig , plt_axis
128+ # Figure.subplots returns Axes or array of Axes
129+
130+
109131def _initialize_cmap_and_vmin_vmax (
110132 sign : str ,
111133) -> Tuple [Union [str , Colormap ], float , float ]:
@@ -335,16 +357,7 @@ def visualize_image_attr(
335357 >>> # Displays blended heat map visualization of computed attributions.
336358 >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
337359 """
338- # Create plot if figure, axis not provided
339- if plt_fig_axis is not None :
340- plt_fig , plt_axis = plt_fig_axis
341- else :
342- if use_pyplot :
343- plt_fig , plt_axis = plt .subplots (figsize = fig_size )
344- else :
345- plt_fig = Figure (figsize = fig_size )
346- plt_axis = plt_fig .subplots ()
347- # Figure.subplots returns Axes or array of Axes
360+ plt_fig , plt_axis = _create_default_plot (plt_fig_axis , use_pyplot , fig_size )
348361
349362 if original_image is not None :
350363 if np .max (original_image ) <= 1.0 :
@@ -359,8 +372,10 @@ def visualize_image_attr(
359372 )
360373
361374 # Remove ticks and tick labels from plot.
362- plt_axis .xaxis .set_ticks_position ("none" )
363- plt_axis .yaxis .set_ticks_position ("none" )
375+ if plt_axis .xaxis is not None :
376+ plt_axis .xaxis .set_ticks_position ("none" )
377+ if plt_axis .yaxis is not None :
378+ plt_axis .yaxis .set_ticks_position ("none" )
364379 plt_axis .set_yticklabels ([])
365380 plt_axis .set_xticklabels ([])
366381 plt_axis .grid (visible = False )
@@ -525,6 +540,161 @@ def visualize_image_attr_multiple(
525540 return plt_fig , plt_axis
526541
527542
543+ def _plot_attrs_as_axvspan (
544+ # pyre-fixme[2]: Parameter must be annotated.
545+ attr_vals ,
546+ # pyre-fixme[2]: Parameter must be annotated.
547+ x_vals ,
548+ # pyre-fixme[2]: Parameter must be annotated.
549+ ax ,
550+ # pyre-fixme[2]: Parameter must be annotated.
551+ x_values ,
552+ # pyre-fixme[2]: Parameter must be annotated.
553+ cmap ,
554+ # pyre-fixme[2]: Parameter must be annotated.
555+ cm_norm ,
556+ # pyre-fixme[2]: Parameter must be annotated.
557+ alpha_overlay ,
558+ ) -> None :
559+ # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
560+ half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
561+
562+ for icol , col_center in enumerate (x_vals ):
563+ left = col_center - half_col_width
564+ right = col_center + half_col_width
565+ ax .axvspan (
566+ xmin = left ,
567+ xmax = right ,
568+ # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
569+ facecolor = (cmap (cm_norm (attr_vals [icol ]))), # type: ignore
570+ edgecolor = None ,
571+ alpha = alpha_overlay ,
572+ )
573+
574+
575+ def _visualize_overlay_individual (
576+ # pyre-fixme[2]: Parameter must be annotated.
577+ num_channels ,
578+ # pyre-fixme[2]: Parameter must be annotated.
579+ plt_axis_list ,
580+ # pyre-fixme[2]: Parameter must be annotated.
581+ x_values ,
582+ # pyre-fixme[2]: Parameter must be annotated.
583+ data ,
584+ # pyre-fixme[2]: Parameter must be annotated.
585+ channel_labels ,
586+ # pyre-fixme[2]: Parameter must be annotated.
587+ norm_attr ,
588+ # pyre-fixme[2]: Parameter must be annotated.
589+ cmap ,
590+ # pyre-fixme[2]: Parameter must be annotated.
591+ cm_norm ,
592+ # pyre-fixme[2]: Parameter must be annotated.
593+ alpha_overlay ,
594+ # pyre-fixme[2]: Parameter must be annotated.
595+ ** kwargs : Any ,
596+ ) -> None :
597+ # helper method for visualize_timeseries_attr
598+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
599+ for chan in range (num_channels ):
600+ plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
601+ if channel_labels is not None :
602+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
603+
604+ _plot_attrs_as_axvspan (
605+ norm_attr [chan ],
606+ x_values ,
607+ plt_axis_list [chan ],
608+ x_values ,
609+ cmap ,
610+ cm_norm ,
611+ alpha_overlay ,
612+ )
613+
614+ plt .subplots_adjust (hspace = 0 )
615+ pass
616+
617+
618+ def _visualize_overlay_combined (
619+ # pyre-fixme[2]: Parameter must be annotated.
620+ num_channels ,
621+ # pyre-fixme[2]: Parameter must be annotated.
622+ plt_axis_list ,
623+ # pyre-fixme[2]: Parameter must be annotated.
624+ x_values ,
625+ # pyre-fixme[2]: Parameter must be annotated.
626+ data ,
627+ # pyre-fixme[2]: Parameter must be annotated.
628+ channel_labels ,
629+ # pyre-fixme[2]: Parameter must be annotated.
630+ norm_attr ,
631+ # pyre-fixme[2]: Parameter must be annotated.
632+ cmap ,
633+ # pyre-fixme[2]: Parameter must be annotated.
634+ cm_norm ,
635+ # pyre-fixme[2]: Parameter must be annotated.
636+ alpha_overlay ,
637+ ** kwargs : Any ,
638+ ) -> None :
639+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
640+
641+ cycler = plt .cycler ("color" , matplotlib .colormaps ["Dark2" ].colors ) # type: ignore
642+ plt_axis_list [0 ].set_prop_cycle (cycler )
643+
644+ for chan in range (num_channels ):
645+ label = channel_labels [chan ] if channel_labels else None
646+ plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
647+
648+ _plot_attrs_as_axvspan (
649+ norm_attr ,
650+ x_values ,
651+ plt_axis_list [0 ],
652+ x_values ,
653+ cmap ,
654+ cm_norm ,
655+ alpha_overlay ,
656+ )
657+
658+ plt_axis_list [0 ].legend (loc = "best" )
659+
660+
661+ def _visualize_colored_graph (
662+ # pyre-fixme[2]: Parameter must be annotated.
663+ num_channels ,
664+ # pyre-fixme[2]: Parameter must be annotated.
665+ plt_axis_list ,
666+ # pyre-fixme[2]: Parameter must be annotated.
667+ x_values ,
668+ # pyre-fixme[2]: Parameter must be annotated.
669+ data ,
670+ # pyre-fixme[2]: Parameter must be annotated.
671+ channel_labels ,
672+ # pyre-fixme[2]: Parameter must be annotated.
673+ norm_attr ,
674+ # pyre-fixme[2]: Parameter must be annotated.
675+ cmap ,
676+ # pyre-fixme[2]: Parameter must be annotated.
677+ cm_norm ,
678+ ** kwargs : Any ,
679+ ) -> None :
680+ # helper method for visualize_timeseries_attr
681+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
682+ for chan in range (num_channels ):
683+ points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
684+ segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
685+
686+ lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
687+ lc .set_array (norm_attr [chan , :])
688+ plt_axis_list [chan ].add_collection (lc )
689+ plt_axis_list [chan ].set_ylim (
690+ 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
691+ )
692+ if channel_labels is not None :
693+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
694+
695+ plt .subplots_adjust (hspace = 0 )
696+
697+
528698def visualize_timeseries_attr (
529699 attr : ndarray ,
530700 data : ndarray ,
@@ -683,8 +853,8 @@ def visualize_timeseries_attr(
683853
684854 num_subplots = num_channels
685855 if (
686- TimeseriesVisualizationMethod [method ]
687- == TimeseriesVisualizationMethod .overlay_combined
856+ TimeseriesVisualizationMethod [method ]. value
857+ == TimeseriesVisualizationMethod .overlay_combined . value
688858 ):
689859 num_subplots = 1
690860 attr = np .sum (attr , axis = 0 ) # Merge attributions across channels
@@ -697,17 +867,9 @@ def visualize_timeseries_attr(
697867 x_values = np .arange (timeseries_length )
698868
699869 # Create plot if figure, axis not provided
700- if plt_fig_axis is not None :
701- plt_fig , plt_axis = plt_fig_axis
702- else :
703- if use_pyplot :
704- plt_fig , plt_axis = plt .subplots ( # type: ignore
705- figsize = fig_size , nrows = num_subplots , sharex = True
706- )
707- else :
708- plt_fig = Figure (figsize = fig_size )
709- plt_axis = plt_fig .subplots (nrows = num_subplots , sharex = True ) # type: ignore
710- # Figure.subplots returns Axes or array of Axes
870+ plt_fig , plt_axis = _create_default_plot (
871+ plt_fig_axis , use_pyplot , fig_size , nrows = num_subplots , sharex = True
872+ )
711873
712874 if not isinstance (plt_axis , ndarray ):
713875 plt_axis_list = np .array ([plt_axis ])
@@ -717,91 +879,30 @@ def visualize_timeseries_attr(
717879 norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = None )
718880
719881 # Set default colormap and bounds based on sign.
720- if VisualizeSign [sign ] == VisualizeSign .all :
721- default_cmap : Union [str , LinearSegmentedColormap ] = (
722- LinearSegmentedColormap .from_list ("RdWhGn" , ["red" , "white" , "green" ])
723- )
724- vmin , vmax = - 1 , 1
725- elif VisualizeSign [sign ] == VisualizeSign .positive :
726- default_cmap = "Greens"
727- vmin , vmax = 0 , 1
728- elif VisualizeSign [sign ] == VisualizeSign .negative :
729- default_cmap = "Reds"
730- vmin , vmax = 0 , 1
731- elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
732- default_cmap = "Blues"
733- vmin , vmax = 0 , 1
734- else :
735- raise AssertionError ("Visualize Sign type is not valid." )
882+ default_cmap , vmin , vmax = _initialize_cmap_and_vmin_vmax (sign )
736883 cmap = cmap if cmap is not None else default_cmap
737884 cmap = cm .get_cmap (cmap ) # type: ignore
738885 cm_norm = colors .Normalize (vmin , vmax )
739886
740- # pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
741- # pyre-fixme[2]: Parameter must be annotated.
742- def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ) -> None :
743- # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
744- half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
745- for icol , col_center in enumerate (x_vals ):
746- left = col_center - half_col_width
747- right = col_center + half_col_width
748- ax .axvspan (
749- xmin = left ,
750- xmax = right ,
751- # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
752- facecolor = (cmap (cm_norm (attr_vals [icol ]))), # type: ignore
753- edgecolor = None ,
754- alpha = alpha_overlay ,
755- )
756-
757- if (
758- TimeseriesVisualizationMethod [method ]
759- == TimeseriesVisualizationMethod .overlay_individual
760- ):
761- for chan in range (num_channels ):
762- plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
763- if channel_labels is not None :
764- plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
765-
766- _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis_list [chan ])
767-
768- plt .subplots_adjust (hspace = 0 )
769-
770- elif (
771- TimeseriesVisualizationMethod [method ]
772- == TimeseriesVisualizationMethod .overlay_combined
773- ):
774- # Dark colors are better in this case
775- cycler = plt .cycler ("color" , matplotlib .colormaps ["Dark2" ]) # type: ignore
776- plt_axis_list [0 ].set_prop_cycle (cycler )
777-
778- for chan in range (num_channels ):
779- label = channel_labels [chan ] if channel_labels else None
780- plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
781-
782- _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis_list [0 ])
783-
784- plt_axis_list [0 ].legend (loc = "best" )
785-
786- elif (
787- TimeseriesVisualizationMethod [method ]
788- == TimeseriesVisualizationMethod .colored_graph
789- ):
790- for chan in range (num_channels ):
791- points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
792- segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
793-
794- lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
795- lc .set_array (norm_attr [chan , :])
796- plt_axis_list [chan ].add_collection (lc )
797- plt_axis_list [chan ].set_ylim (
798- 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
799- )
800- if channel_labels is not None :
801- plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
802-
803- plt .subplots_adjust (hspace = 0 )
804-
887+ visualization_methods : Dict [str , Callable [..., Union [None , AxesImage ]]] = {
888+ "overlay_individual" : _visualize_overlay_individual ,
889+ "overlay_combined" : _visualize_overlay_combined ,
890+ "colored_graph" : _visualize_colored_graph ,
891+ }
892+ kwargs = {
893+ "num_channels" : num_channels ,
894+ "plt_axis_list" : plt_axis_list ,
895+ "x_values" : x_values ,
896+ "data" : data ,
897+ "channel_labels" : channel_labels ,
898+ "norm_attr" : norm_attr ,
899+ "cmap" : cmap ,
900+ "cm_norm" : cm_norm ,
901+ "alpha_overlay" : alpha_overlay ,
902+ "pyplot_kwargs" : pyplot_kwargs ,
903+ }
904+ if method in visualization_methods :
905+ visualization_methods [method ](** kwargs )
805906 else :
806907 raise AssertionError ("Invalid visualization method: {}" .format (method ))
807908
0 commit comments