|
3 | 3 | # pyre-strict |
4 | 4 | import warnings |
5 | 5 | from enum import Enum |
6 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
| 6 | +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union |
7 | 7 |
|
8 | 8 | import matplotlib |
9 | 9 |
|
@@ -444,7 +444,7 @@ def visualize_image_attr_multiple( |
444 | 444 | fig_size: Tuple[int, int] = (8, 6), |
445 | 445 | use_pyplot: bool = True, |
446 | 446 | **kwargs: Any, |
447 | | -) -> Tuple[Figure, Axes]: |
| 447 | +) -> Tuple[Figure, Union[Axes, List[Axes]]]: |
448 | 448 | r""" |
449 | 449 | Visualizes attribution using multiple visualization methods displayed |
450 | 450 | in a 1 x k grid, where k is the number of desired visualizations. |
@@ -516,15 +516,19 @@ def visualize_image_attr_multiple( |
516 | 516 | plt_fig = plt.figure(figsize=fig_size) |
517 | 517 | else: |
518 | 518 | plt_fig = Figure(figsize=fig_size) |
519 | | - plt_axis = plt_fig.subplots(1, len(methods)) |
| 519 | + plt_axis_np = plt_fig.subplots(1, len(methods), squeeze=True) |
520 | 520 |
|
| 521 | + plt_axis: Union[Axes, List[Axes]] |
521 | 522 | plt_axis_list: List[Axes] = [] |
522 | 523 | # When visualizing one |
523 | 524 | if len(methods) == 1: |
524 | | - plt_axis_list = [plt_axis] # type: ignore |
| 525 | + plt_axis = cast(Axes, plt_axis_np) |
| 526 | + plt_axis_list = [plt_axis] |
525 | 527 | # Figure.subplots returns Axes or array of Axes |
526 | 528 | else: |
527 | | - plt_axis_list = plt_axis # type: ignore |
| 529 | + # https://github.com/numpy/numpy/issues/24738 |
| 530 | + plt_axis = cast(List[Axes], cast(npt.NDArray, plt_axis_np).tolist()) |
| 531 | + plt_axis_list = plt_axis |
528 | 532 | # Figure.subplots returns Axes or array of Axes |
529 | 533 |
|
530 | 534 | for i in range(len(methods)): |
|
0 commit comments