diff --git a/.conda/meta.yaml b/.conda/meta.yaml index c05884ec6a..5217866af7 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -13,10 +13,10 @@ build: requirements: host: - - python>=3.6 + - python>=3.8 run: - - numpy - - pytorch>=1.6 + - numpy<2.0 + - pytorch>=1.10 - matplotlib-base test: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index daa676d5c1..ab4d71bc6c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.12xlarge - docker-image: cimg/python:3.9 + docker-image: cimg/python:3.11 repository: pytorch/captum script: | sudo chmod -R 777 . diff --git a/.github/workflows/test-conda-cpu.yml b/.github/workflows/test-conda-cpu.yml index 9ae71846d1..3295edcca8 100644 --- a/.github/workflows/test-conda-cpu.yml +++ b/.github/workflows/test-conda-cpu.yml @@ -15,7 +15,7 @@ jobs: tests: strategy: matrix: - python_version: ["3.7", "3.8", "3.9", "3.10"] + python_version: ["3.8", "3.9", "3.10", "3.11"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: diff --git a/.github/workflows/test-pip-cpu-with-mypy.yml b/.github/workflows/test-pip-cpu-with-mypy.yml index 4a721cb85b..7e166261e4 100644 --- a/.github/workflows/test-pip-cpu-with-mypy.yml +++ b/.github/workflows/test-pip-cpu-with-mypy.yml @@ -17,7 +17,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.12xlarge - docker-image: cimg/python:3.6 + docker-image: cimg/python:3.11 repository: pytorch/captum script: | sudo chmod -R 777 . diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index 1db8e5769b..2acacdc0a9 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -12,18 +12,17 @@ jobs: tests: strategy: matrix: - pytorch_args: ["-v 1.6", "-v 1.7", "-v 1.8", "-v 1.9", "-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13"] - docker_img: ["cimg/python:3.6", "cimg/python:3.7"] - include: - - pytorch_args: "-v 2.0" - docker_img: "cimg/python:3.8" + pytorch_args: ["-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13", "-v 2.0.0", "-v 2.1.0", "-v 2.2.0", "-v 2.3.0"] + docker_img: ["cimg/python:3.8", "cimg/python:3.9", "cimg/python:3.10", "cimg/python:3.11"] exclude: + - pytorch_args: "-v 1.10" + docker_img: "cimg/python:3.10" + - pytorch_args: "-v 1.10" + docker_img: "cimg/python:3.11" - pytorch_args: "-v 1.11" - docker_img: "cimg/python:3.6" + docker_img: "cimg/python:3.11" - pytorch_args: "-v 1.12" - docker_img: "cimg/python:3.6" - - pytorch_args: "-v 1.13" - docker_img: "cimg/python:3.6" + docker_img: "cimg/python:3.11" fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: diff --git a/README.md b/README.md index b75fdcd025..a78f996561 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,8 @@ Captum can also be used by application engineers who are using trained models in ## Installation **Installation Requirements** -- Python >= 3.6 -- PyTorch >= 1.6 +- Python >= 3.8 +- PyTorch >= 1.10 ##### Installing the latest release diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 89761b9fd0..262fb29f21 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -733,7 +733,7 @@ def _compute_jacobian_wrt_params( inputs: Tuple[Any, ...], labels: Optional[Tensor] = None, loss_fn: Optional[Union[Module, Callable]] = None, - layer_modules: List[Module] = None, + layer_modules: Optional[List[Module]] = None, ) -> Tuple[Tensor, ...]: r""" Computes the Jacobian of a batch of test examples given a model, and optional @@ -805,7 +805,7 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick( labels: Optional[Tensor] = None, loss_fn: Optional[Union[Module, Callable]] = None, reduction_type: Optional[str] = "sum", - layer_modules: List[Module] = None, + layer_modules: Optional[List[Module]] = None, ) -> Tuple[Any, ...]: r""" Computes the Jacobian of a batch of test examples given a model, and optional diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py index 10ad1bc27a..cb08a15aed 100644 --- a/captum/_utils/progress.py +++ b/captum/_utils/progress.py @@ -3,7 +3,7 @@ import sys import warnings from time import time -from typing import cast, Iterable, Sized, TextIO +from typing import cast, Iterable, Optional, Sized, TextIO from captum._utils.typing import Literal @@ -51,7 +51,7 @@ class NullProgress: progress bars. """ - def __init__(self, iterable: Iterable = None, *args, **kwargs): + def __init__(self, iterable: Optional[Iterable] = None, *args, **kwargs): del args, kwargs self.iterable = iterable @@ -77,10 +77,10 @@ def close(self): class SimpleProgress: def __init__( self, - iterable: Iterable = None, - desc: str = None, - total: int = None, - file: TextIO = None, + iterable: Optional[Iterable] = None, + desc: Optional[str] = None, + total: Optional[int] = None, + file: Optional[TextIO] = None, mininterval: float = 0.5, ) -> None: """ @@ -155,11 +155,11 @@ def close(self): def progress( - iterable: Iterable = None, - desc: str = None, - total: int = None, + iterable: Optional[Iterable] = None, + desc: Optional[str] = None, + total: Optional[int] = None, use_tqdm=True, - file: TextIO = None, + file: Optional[TextIO] = None, mininterval: float = 0.5, **kwargs, ): diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index fdda434743..d16cec47a1 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 from enum import Enum -from typing import Any, cast, List, Tuple, Union +from typing import Any, cast, List, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -80,7 +80,7 @@ def attribute( inputs: Union[Tensor, Tuple[Tensor, ...]], nt_type: str = "smoothgrad", nt_samples: int = 5, - nt_samples_batch_size: int = None, + nt_samples_batch_size: Optional[int] = None, stdevs: Union[float, Tuple[float, ...]] = 1.0, draw_baseline_from_distrib: bool = False, **kwargs: Any, diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index f54cdde4ba..6e7d8d8326 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -45,7 +46,7 @@ def _scatter_itp_attr_by_mask( return attr -class InterpretableInput: +class InterpretableInput(ABC): """ InterpretableInput is an adapter for different kinds of model inputs to work in Captum's attribution methods. Generally, attribution methods of Captum @@ -94,6 +95,7 @@ class to create other types of customized input. is only allowed in certain attribution classes like LLMAttribution for now.) """ + @abstractmethod def to_tensor(self) -> Tensor: """ Return the interpretable representation of this input as a tensor @@ -104,6 +106,7 @@ def to_tensor(self) -> Tensor: """ pass + @abstractmethod def to_model_input(self, itp_tensor: Optional[Tensor] = None) -> Any: """ Get the (perturbed) input in the format required by the model diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index f1319ea98f..e00714a8c5 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -3,12 +3,14 @@ from enum import Enum from typing import Any, Iterable, List, Optional, Tuple, Union +import matplotlib + import numpy as np from matplotlib import cm, colors, pyplot as plt +from matplotlib.axes import Axes from matplotlib.collections import LineCollection -from matplotlib.colors import LinearSegmentedColormap +from matplotlib.colors import Colormap, LinearSegmentedColormap from matplotlib.figure import Figure -from matplotlib.pyplot import axis, figure from mpl_toolkits.axes_grid1 import make_axes_locatable from numpy import ndarray @@ -51,7 +53,8 @@ def _normalize_scale(attr: ndarray, scale_factor: float): warnings.warn( "Attempting to normalize by value approximately 0, visualized results" "may be misleading. This likely means that attribution values are all" - "close to 0." + "close to 0.", + stacklevel=2, ) attr_norm = attr / scale_factor return np.clip(attr_norm, -1, 1) @@ -80,18 +83,20 @@ def _normalize_attr( # Choose appropriate signed values and rescale, removing given outlier percentage. if VisualizeSign[sign] == VisualizeSign.all: - threshold = _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc) + threshold = _cumulative_sum_threshold( + np.abs(attr_combined), 100.0 - outlier_perc + ) elif VisualizeSign[sign] == VisualizeSign.positive: attr_combined = (attr_combined > 0) * attr_combined - threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc) + threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc) elif VisualizeSign[sign] == VisualizeSign.negative: attr_combined = (attr_combined < 0) * attr_combined threshold = -1 * _cumulative_sum_threshold( - np.abs(attr_combined), 100 - outlier_perc + np.abs(attr_combined), 100.0 - outlier_perc ) elif VisualizeSign[sign] == VisualizeSign.absolute_value: attr_combined = np.abs(attr_combined) - threshold = _cumulative_sum_threshold(attr_combined, 100 - outlier_perc) + threshold = _cumulative_sum_threshold(attr_combined, 100.0 - outlier_perc) else: raise AssertionError("Visualize Sign type is not valid.") return _normalize_scale(attr_combined, threshold) @@ -99,18 +104,18 @@ def _normalize_attr( def visualize_image_attr( attr: ndarray, - original_image: Union[None, ndarray] = None, + original_image: Optional[ndarray] = None, method: str = "heat_map", sign: str = "absolute_value", - plt_fig_axis: Union[None, Tuple[figure, axis]] = None, + plt_fig_axis: Optional[Tuple[Figure, Axes]] = None, outlier_perc: Union[int, float] = 2, - cmap: Union[None, str] = None, + cmap: Optional[Union[str, Colormap]] = None, alpha_overlay: float = 0.5, show_colorbar: bool = False, - title: Union[None, str] = None, + title: Optional[str] = None, fig_size: Tuple[int, int] = (6, 6), use_pyplot: bool = True, -): +) -> Tuple[Figure, Axes]: r""" Visualizes attribution for a given image by normalizing attribution values of the desired sign (positive, negative, absolute value, or all) and displaying @@ -231,7 +236,8 @@ def visualize_image_attr( plt_fig, plt_axis = plt.subplots(figsize=fig_size) else: plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots() + plt_axis = plt_fig.subplots() # type: ignore + # Figure.subplots returns Axes or array of Axes if original_image is not None: if np.max(original_image) <= 1.0: @@ -264,8 +270,8 @@ def visualize_image_attr( # Set default colormap and bounds based on sign. if VisualizeSign[sign] == VisualizeSign.all: - default_cmap = LinearSegmentedColormap.from_list( - "RdWhGn", ["red", "white", "green"] + default_cmap: Union[str, LinearSegmentedColormap] = ( + LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"]) ) vmin, vmax = -1, 1 elif VisualizeSign[sign] == VisualizeSign.positive: @@ -345,7 +351,7 @@ def visualize_image_attr_multiple( original_image: Union[None, ndarray], methods: List[str], signs: List[str], - titles: Union[None, List[str]] = None, + titles: Optional[List[str]] = None, fig_size: Tuple[int, int] = (8, 6), use_pyplot: bool = True, **kwargs: Any, @@ -423,9 +429,14 @@ def visualize_image_attr_multiple( plt_fig = Figure(figsize=fig_size) plt_axis = plt_fig.subplots(1, len(methods)) + plt_axis_list: List[Axes] = [] # When visualizing one if len(methods) == 1: - plt_axis = [plt_axis] + plt_axis_list = [plt_axis] # type: ignore + # Figure.subplots returns Axes or array of Axes + else: + plt_axis_list = plt_axis # type: ignore + # Figure.subplots returns Axes or array of Axes for i in range(len(methods)): visualize_image_attr( @@ -433,7 +444,7 @@ def visualize_image_attr_multiple( original_image=original_image, method=methods[i], sign=signs[i], - plt_fig_axis=(plt_fig, plt_axis[i]), + plt_fig_axis=(plt_fig, plt_axis_list[i]), use_pyplot=False, title=titles[i] if titles else None, **kwargs, @@ -452,12 +463,12 @@ def visualize_timeseries_attr( sign: str = "absolute_value", channel_labels: Optional[List[str]] = None, channels_last: bool = True, - plt_fig_axis: Union[None, Tuple[figure, axis]] = None, + plt_fig_axis: Optional[Tuple[Figure, Union[Axes, List[Axes]]]] = None, outlier_perc: Union[int, float] = 2, - cmap: Union[None, str] = None, + cmap: Optional[Union[str, Colormap]] = None, alpha_overlay: float = 0.7, show_colorbar: bool = False, - title: Union[None, str] = None, + title: Optional[str] = None, fig_size: Tuple[int, int] = (6, 6), use_pyplot: bool = True, **pyplot_kwargs, @@ -596,7 +607,8 @@ def visualize_timeseries_attr( if num_channels > timeseries_length: warnings.warn( "Number of channels ({}) greater than time series length ({}), " - "please verify input format".format(num_channels, timeseries_length) + "please verify input format".format(num_channels, timeseries_length), + stacklevel=2, ) num_subplots = num_channels @@ -624,17 +636,20 @@ def visualize_timeseries_attr( ) else: plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) + plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) # type: ignore + # Figure.subplots returns Axes or array of Axes if not isinstance(plt_axis, ndarray): - plt_axis = np.array([plt_axis]) + plt_axis_list = np.array([plt_axis]) + else: + plt_axis_list = plt_axis norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None) # Set default colormap and bounds based on sign. if VisualizeSign[sign] == VisualizeSign.all: - default_cmap = LinearSegmentedColormap.from_list( - "RdWhGn", ["red", "white", "green"] + default_cmap: Union[str, LinearSegmentedColormap] = ( + LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"]) ) vmin, vmax = -1, 1 elif VisualizeSign[sign] == VisualizeSign.positive: @@ -649,11 +664,10 @@ def visualize_timeseries_attr( else: raise AssertionError("Visualize Sign type is not valid.") cmap = cmap if cmap is not None else default_cmap - cmap = cm.get_cmap(cmap) + cmap = cm.get_cmap(cmap) # type: ignore cm_norm = colors.Normalize(vmin, vmax) def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): - half_col_width = (x_values[1] - x_values[0]) / 2.0 for icol, col_center in enumerate(x_vals): left = col_center - half_col_width @@ -670,14 +684,12 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): TimeseriesVisualizationMethod[method] == TimeseriesVisualizationMethod.overlay_individual ): - for chan in range(num_channels): - - plt_axis[chan].plot(x_values, data[chan, :], **pyplot_kwargs) + plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs) if channel_labels is not None: - plt_axis[chan].set_ylabel(channel_labels[chan]) + plt_axis_list[chan].set_ylabel(channel_labels[chan]) - _plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis[chan]) + _plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis_list[chan]) plt.subplots_adjust(hspace=0) @@ -685,37 +697,34 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): TimeseriesVisualizationMethod[method] == TimeseriesVisualizationMethod.overlay_combined ): - # Dark colors are better in this case - cycler = plt.cycler("color", cm.Dark2.colors) - plt_axis[0].set_prop_cycle(cycler) + cycler = plt.cycler("color", matplotlib.colormaps["Dark2"]) # type: ignore + plt_axis_list[0].set_prop_cycle(cycler) for chan in range(num_channels): label = channel_labels[chan] if channel_labels else None - plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs) + plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs) - _plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0]) + _plot_attrs_as_axvspan(norm_attr, x_values, plt_axis_list[0]) - plt_axis[0].legend(loc="best") + plt_axis_list[0].legend(loc="best") elif ( TimeseriesVisualizationMethod[method] == TimeseriesVisualizationMethod.colored_graph ): - for chan in range(num_channels): - points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs) lc.set_array(norm_attr[chan, :]) - plt_axis[chan].add_collection(lc) - plt_axis[chan].set_ylim( + plt_axis_list[chan].add_collection(lc) + plt_axis_list[chan].set_ylim( 1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :]) ) if channel_labels is not None: - plt_axis[chan].set_ylabel(channel_labels[chan]) + plt_axis_list[chan].set_ylabel(channel_labels[chan]) plt.subplots_adjust(hspace=0) @@ -725,7 +734,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): plt.xlim([x_values[0], x_values[-1]]) if show_colorbar: - axis_separator = make_axes_locatable(plt_axis[-1]) + axis_separator = make_axes_locatable(plt_axis_list[-1]) colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.4) colorbar_alpha = alpha_overlay if ( @@ -740,7 +749,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): alpha=colorbar_alpha, ) if title: - plt_axis[0].set_title(title) + plt_axis_list[0].set_title(title) if use_pyplot: plt.show() diff --git a/captum/concept/_core/cav.py b/captum/concept/_core/cav.py index 6aedb24fff..1b025c5b3c 100644 --- a/captum/concept/_core/cav.py +++ b/captum/concept/_core/cav.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch from captum.concept._core.concept import Concept @@ -21,7 +21,7 @@ def __init__( self, concepts: List[Concept], layer: str, - stats: Dict[str, Any] = None, + stats: Optional[Dict[str, Any]] = None, save_path: str = "./cav/", model_id: str = "default_model_id", ) -> None: diff --git a/captum/concept/_core/tcav.py b/captum/concept/_core/tcav.py index 3d9ca0fdf4..bc33826f3b 100644 --- a/captum/concept/_core/tcav.py +++ b/captum/concept/_core/tcav.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 from collections import defaultdict -from typing import Any, cast, Dict, List, Set, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -247,8 +247,8 @@ def __init__( model: Module, layers: Union[str, List[str]], model_id: str = "default_model_id", - classifier: Classifier = None, - layer_attr_method: LayerAttribution = None, + classifier: Optional[Classifier] = None, + layer_attr_method: Optional[LayerAttribution] = None, attribute_to_layer_input=False, save_path: str = "./cav/", **classifier_kwargs: Any, @@ -451,7 +451,7 @@ def compute_cavs( self, experimental_sets: List[List[Concept]], force_train: bool = False, - processes: int = None, + processes: Optional[int] = None, ): r""" This method computes CAVs for given `experiments_sets` and layers @@ -567,7 +567,7 @@ def interpret( experimental_sets: List[List[Concept]], target: TargetType = None, additional_forward_args: Any = None, - processes: int = None, + processes: Optional[int] = None, **kwargs: Any, ) -> Dict[str, Dict[str, Dict[str, Tensor]]]: r""" @@ -747,7 +747,7 @@ def interpret( attribs, cav_subset, classes_subset, - experimental_subset_sorted, + experimental_subset_sorted, # type: ignore ) i += 1 diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index 32b44506ba..459e0743bc 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -876,7 +876,7 @@ def __init__( test_loss_fn: Optional[Union[Module, Callable]] = None, vectorize: bool = False, nearest_neighbors: Optional[NearestNeighbors] = None, - projection_dim: int = None, + projection_dim: Optional[int] = None, seed: int = 0, ) -> None: r""" diff --git a/captum/insights/attr_vis/features.py b/captum/insights/attr_vis/features.py index fac17f8e80..6db140b392 100644 --- a/captum/insights/attr_vis/features.py +++ b/captum/insights/attr_vis/features.py @@ -8,12 +8,14 @@ from captum._utils.common import safe_div from captum.attr._utils import visualization as viz from captum.insights.attr_vis._utils.transforms import format_transforms +from matplotlib.figure import Figure from torch import Tensor + FeatureOutput = namedtuple("FeatureOutput", "name base modified type contribution") -def _convert_figure_base64(fig) -> str: +def _convert_figure_base64(fig: Figure) -> str: buff = BytesIO() with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -82,8 +84,8 @@ class ImageFeature(BaseFeature): def __init__( self, name: str, - baseline_transforms: Union[Callable, List[Callable]], - input_transforms: Union[Callable, List[Callable]], + baseline_transforms: Optional[Union[Callable, List[Callable]]], + input_transforms: Optional[Union[Callable, List[Callable]]], visualization_transform: Optional[Callable] = None, ) -> None: r""" @@ -157,9 +159,9 @@ class TextFeature(BaseFeature): def __init__( self, name: str, - baseline_transforms: Union[Callable, List[Callable]], - input_transforms: Union[Callable, List[Callable]], - visualization_transform: Callable, + baseline_transforms: Optional[Union[Callable, List[Callable]]], + input_transforms: Optional[Union[Callable, List[Callable]]], + visualization_transform: Optional[Callable], ) -> None: r""" Args: @@ -295,7 +297,7 @@ def __init__( def visualization_type() -> str: return "empty" - def visualize(self, _attribution, _data, contribution_frac) -> FeatureOutput: + def visualize(self, attribution, data, contribution_frac) -> FeatureOutput: return FeatureOutput( name=self.name, base=None, diff --git a/captum/insights/attr_vis/widget/widget.py b/captum/insights/attr_vis/widget/widget.py index d9629d78cb..b1ed8c9ee4 100644 --- a/captum/insights/attr_vis/widget/widget.py +++ b/captum/insights/attr_vis/widget/widget.py @@ -22,7 +22,7 @@ class CaptumInsights(widgets.DOMWidget): label_details = Dict().tag(sync=True) attribution = Dict().tag(sync=True) config = Dict().tag(sync=True) - output = List().tag(sync=True) + output = List().tag(sync=True) # type: ignore def __init__(self, **kwargs) -> None: super(CaptumInsights, self).__init__(**kwargs) diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py index e887e7e440..d80f28291a 100644 --- a/captum/metrics/_core/infidelity.py +++ b/captum/metrics/_core/infidelity.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from typing import Any, Callable, cast, Tuple, Union +from typing import Any, Callable, cast, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -121,7 +121,7 @@ def infidelity( additional_forward_args: Any = None, target: TargetType = None, n_perturb_samples: int = 10, - max_examples_per_batch: int = None, + max_examples_per_batch: Optional[int] = None, normalize: bool = False, ) -> Tensor: r""" diff --git a/captum/metrics/_core/sensitivity.py b/captum/metrics/_core/sensitivity.py index d7e0ab54bc..f50eaf74e1 100644 --- a/captum/metrics/_core/sensitivity.py +++ b/captum/metrics/_core/sensitivity.py @@ -2,7 +2,7 @@ from copy import deepcopy from inspect import signature -from typing import Any, Callable, cast, Tuple, Union +from typing import Any, Callable, cast, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -64,7 +64,7 @@ def sensitivity_max( perturb_radius: float = 0.02, n_perturb_samples: int = 10, norm_ord: str = "fro", - max_examples_per_batch: int = None, + max_examples_per_batch: Optional[int] = None, **kwargs: Any, ) -> Tensor: r""" diff --git a/captum/metrics/_utils/batching.py b/captum/metrics/_utils/batching.py index 83a773bda3..40e1efd784 100644 --- a/captum/metrics/_utils/batching.py +++ b/captum/metrics/_utils/batching.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import warnings -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple import torch from torch import Tensor @@ -12,7 +12,7 @@ def _divide_and_aggregate_metrics( n_perturb_samples: int, metric_func: Callable, agg_func: Callable = torch.add, - max_examples_per_batch: int = None, + max_examples_per_batch: Optional[int] = None, ) -> Tensor: r""" This function is used to slice large number of samples `n_perturb_samples` per diff --git a/captum/robust/_core/pgd.py b/captum/robust/_core/pgd.py index 5391b39cfb..c926a55345 100644 --- a/captum/robust/_core/pgd.py +++ b/captum/robust/_core/pgd.py @@ -38,7 +38,7 @@ class PGD(Perturbation): def __init__( self, forward_func: Callable, - loss_func: Callable = None, + loss_func: Optional[Callable] = None, lower_bound: float = float("-inf"), upper_bound: float = float("inf"), ) -> None: diff --git a/environment.yml b/environment.yml index 61de9e0096..e2a5275b45 100644 --- a/environment.yml +++ b/environment.yml @@ -2,5 +2,5 @@ name: captum channels: - pytorch dependencies: - - numpy - - pytorch>=1.6 + - numpy<2.0 + - pytorch>=1.10 diff --git a/setup.py b/setup.py index 31b55c485f..9744eeaae9 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from setuptools import find_packages, setup REQUIRED_MAJOR = 3 -REQUIRED_MINOR = 6 +REQUIRED_MINOR = 8 # Check for python version if sys.version_info < (REQUIRED_MAJOR, REQUIRED_MINOR): @@ -146,8 +146,8 @@ def get_package_files(root, subdirs): ], long_description=long_description, long_description_content_type="text/markdown", - python_requires=">=3.6", - install_requires=["matplotlib", "numpy", "torch>=1.6", "tqdm"], + python_requires=">=3.8", + install_requires=["matplotlib", "numpy<2.0", "torch>=1.10", "tqdm"], packages=find_packages(exclude=("tests", "tests.*")), extras_require={ "dev": DEV_REQUIRES, diff --git a/tests/attr/layer/test_layer_gradient_shap.py b/tests/attr/layer/test_layer_gradient_shap.py index 6c57d67d3b..34344f7a06 100644 --- a/tests/attr/layer/test_layer_gradient_shap.py +++ b/tests/attr/layer/test_layer_gradient_shap.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import unittest -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch @@ -175,7 +175,7 @@ def _assert_attributions( Tuple[List[float], ...], Tuple[List[List[float]], ...], ], - expected_delta: Tensor = None, + expected_delta: Optional[Tensor] = None, n_samples: int = 5, attribute_to_layer_input: bool = False, add_args: Any = None, diff --git a/tests/attr/models/test_pytext.py b/tests/attr/models/test_pytext.py index fa7394945b..4f3b3cb9f6 100644 --- a/tests/attr/models/test_pytext.py +++ b/tests/attr/models/test_pytext.py @@ -45,7 +45,7 @@ def __init__(self) -> None: class TestWordEmbeddings(unittest.TestCase): def setUp(self) -> None: if not HAS_PYTEXT: - return self.skipTest("Skip the test since PyText is not installed") + raise unittest.SkipTest("Skip the test since PyText is not installed") self.embedding_file, self.embedding_path = tempfile.mkstemp() self.word_embedding_file, self.word_embedding_path = tempfile.mkstemp() diff --git a/tests/attr/neuron/test_neuron_gradient.py b/tests/attr/neuron/test_neuron_gradient.py index d14b56eaa6..3d79a35ca5 100644 --- a/tests/attr/neuron/test_neuron_gradient.py +++ b/tests/attr/neuron/test_neuron_gradient.py @@ -141,7 +141,7 @@ def _gradient_matching_test_assert( while len(neuron) < len(out.shape) - 1: neuron = neuron + (0,) input_attrib = Saliency( - lambda x: _forward_layer_eval( + lambda x, neuron=neuron: _forward_layer_eval( model, x, output_layer, grad_enabled=True )[0][(slice(None), *neuron)] ) diff --git a/tests/attr/test_deeplift_basic.py b/tests/attr/test_deeplift_basic.py index 09499f4348..6017409a27 100644 --- a/tests/attr/test_deeplift_basic.py +++ b/tests/attr/test_deeplift_basic.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 from inspect import signature -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from captum.attr._core.deep_lift import DeepLift, DeepLiftShap @@ -301,7 +301,7 @@ def _deeplift_assert( attr_method: Union[DeepLift, DeepLiftShap], inputs: Tuple[Tensor, ...], baselines, - custom_attr_func: Callable[..., Tuple[Tensor, ...]] = None, + custom_attr_func: Optional[Callable[..., Tuple[Tensor, ...]]] = None, ) -> None: input_bsz = len(inputs[0]) if callable(baselines): diff --git a/tests/attr/test_integrated_gradients_basic.py b/tests/attr/test_integrated_gradients_basic.py index 4b0ce2aa81..a269733314 100644 --- a/tests/attr/test_integrated_gradients_basic.py +++ b/tests/attr/test_integrated_gradients_basic.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest -from typing import Any, cast, Tuple, Union +from typing import Any, cast, Optional, Tuple, Union import torch from captum._utils.common import _zeros @@ -337,7 +337,7 @@ def _assert_batched_tensor_multi_input( self, type: str, approximation_method: str = "gausslegendre", - nt_samples_batch_size: int = None, + nt_samples_batch_size: Optional[int] = None, ) -> None: model = BasicModel_MultiLayer() input = ( @@ -361,7 +361,7 @@ def _assert_n_samples_batched_size( self, type: str, approximation_method: str = "gausslegendre", - nt_samples_batch_size: int = None, + nt_samples_batch_size: Optional[int] = None, ) -> None: model = BasicModel_MultiLayer() input = ( diff --git a/tests/helpers/influence/common.py b/tests/helpers/influence/common.py index 9f2df072c6..fcaa068865 100644 --- a/tests/helpers/influence/common.py +++ b/tests/helpers/influence/common.py @@ -424,19 +424,19 @@ def get_random_model_and_data( # turn input into a single tensor for use by least squares tensor_hessian_samples = ( - hessian_dataset.samples + hessian_dataset.samples # type: ignore if not unpack_inputs - else torch.cat(hessian_dataset.samples, dim=1) + else torch.cat(hessian_dataset.samples, dim=1) # type: ignore ) version = _parse_version(torch.__version__) if version < (1, 9): - theta = torch.lstsq( - tensor_hessian_samples, hessian_dataset.labels + theta = torch.lstsq( # type: ignore + tensor_hessian_samples, hessian_dataset.labels # type: ignore ).solution[0:1] else: # run least squares to get optimal trained parameters theta = torch.linalg.lstsq( - hessian_dataset.labels, + hessian_dataset.labels, # type: ignore tensor_hessian_samples, ).solution # the first `n` rows of `theta` contains the least squares solution, where @@ -459,7 +459,7 @@ def get_random_model_and_data( net_adjusted = _adjust_model(net, gpu_setting) # train model using several optimization steps on Hessian data - batch = next(iter(DataLoader(hessian_dataset, batch_size=len(hessian_dataset)))) + batch = next(iter(DataLoader(hessian_dataset, batch_size=len(hessian_dataset)))) # type: ignore # noqa: E501 line too long optimizer = torch.optim.Adam(net.parameters()) num_steps = 200 @@ -480,9 +480,9 @@ def get_random_model_and_data( train_dataset, ) - hessian_data = (hessian_dataset.samples, hessian_dataset.labels) + hessian_data = (hessian_dataset.samples, hessian_dataset.labels) # type: ignore - test_data = (test_dataset.samples, test_dataset.labels) + test_data = (test_dataset.samples, test_dataset.labels) # type: ignore if return_test_data: if not return_hessian_data: diff --git a/tests/insights/test_contribution.py b/tests/insights/test_contribution.py index cf8f2b8aff..264b330f62 100644 --- a/tests/insights/test_contribution.py +++ b/tests/insights/test_contribution.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest -from typing import Callable, List, Union +from typing import Any, Callable, Generator, List, Tuple, Union import torch import torch.nn as nn @@ -9,6 +9,8 @@ from captum.insights.attr_vis.app import FilterConfig from captum.insights.attr_vis.features import BaseFeature, FeatureOutput, ImageFeature from tests.helpers import BaseTest +from torch import Tensor +from torch.utils.data import DataLoader class RealFeature(BaseFeature): @@ -26,7 +28,8 @@ def __init__( visualization_transform=None, ) - def visualization_type(self) -> str: + @staticmethod + def visualization_type() -> str: return "real" def visualize(self, attribution, data, contribution_frac) -> FeatureOutput: @@ -103,7 +106,7 @@ def _labelled_img_data( height: int = 8, depth: int = 3, num_labels: int = 10, -): +) -> Generator[Tuple[Tensor, Tensor], Any, Any]: for _ in range(num_samples): yield torch.empty(depth, height, width).uniform_(0, 1), torch.randint( num_labels, (1,) @@ -150,8 +153,8 @@ def test_one_feature(self) -> None: # NOTE: using DataLoader to batch the inputs # since AttributionVisualizer requires the input to be of size `B x ...` - data_loader = torch.utils.data.DataLoader( - list(dataset), batch_size=batch_size, shuffle=False, num_workers=0 + data_loader: DataLoader = torch.utils.data.DataLoader( + list(dataset), batch_size=batch_size, shuffle=False, num_workers=0 # type: ignore # noqa: E501 line too long ) visualizer = AttributionVisualizer( @@ -188,8 +191,8 @@ def test_multi_features(self) -> None: ) # NOTE: using DataLoader to batch the inputs since # AttributionVisualizer requires the input to be of size `N x ...` - data_loader = torch.utils.data.DataLoader( - list(dataset), batch_size=batch_size, shuffle=False, num_workers=0 + data_loader: DataLoader = torch.utils.data.DataLoader( + list(dataset), batch_size=batch_size, shuffle=False, num_workers=0 # type: ignore # noqa: E501 line too long ) visualizer = AttributionVisualizer( diff --git a/tests/insights/test_features.py b/tests/insights/test_features.py index 917249189f..76b6d9e05b 100644 --- a/tests/insights/test_features.py +++ b/tests/insights/test_features.py @@ -17,7 +17,12 @@ class TestTextFeature(BaseTest): FEATURE_NAME = "question" def test_text_feature_returns_text_as_visualization_type(self) -> None: - feature = TextFeature(self.FEATURE_NAME, None, None, None) + feature = TextFeature( + name=self.FEATURE_NAME, + baseline_transforms=None, + input_transforms=None, + visualization_transform=None, + ) self.assertEqual(feature.visualization_type(), "text") def test_text_feature_uses_visualization_transform_if_provided(self) -> None: diff --git a/tests/metrics/test_infidelity.py b/tests/metrics/test_infidelity.py index 3a0da03553..bba5933dd2 100644 --- a/tests/metrics/test_infidelity.py +++ b/tests/metrics/test_infidelity.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import typing -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, List, Optional, Tuple, Union import torch from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric @@ -395,11 +395,11 @@ def basic_model_assert( model: Module, inputs: TensorOrTupleOfTensorsGeneric, expected: Tensor, - n_perturb_samples: int = 10, - max_batch_size: int = None, - perturb_func: Callable = _local_perturb_func, - multiply_by_inputs: bool = False, - normalize: bool = False, + n_perturb_samples: Optional[int] = 10, + max_batch_size: Optional[int] = None, + perturb_func: Optional[Callable] = _local_perturb_func, + multiply_by_inputs: Optional[bool] = False, + normalize: Optional[bool] = False, ) -> Tensor: ig = IntegratedGradients(model) if multiply_by_inputs: @@ -429,12 +429,12 @@ def basic_model_global_assert( model: Module, inputs: TensorOrTupleOfTensorsGeneric, expected: Tensor, - additional_args: Any = None, - target: TargetType = None, - n_perturb_samples: int = 10, - max_batch_size: int = None, - perturb_func: Callable = _global_perturb_func1, - normalize: bool = False, + additional_args: Optional[Any] = None, + target: Optional[TargetType] = None, + n_perturb_samples: Optional[int] = 10, + max_batch_size: Optional[int] = None, + perturb_func: Optional[Callable] = _global_perturb_func1, + normalize: Optional[bool] = False, ) -> Tensor: attrs = attr_algo.attribute( inputs, additional_forward_args=additional_args, target=target @@ -459,14 +459,14 @@ def infidelity_assert( attributions: TensorOrTupleOfTensorsGeneric, inputs: TensorOrTupleOfTensorsGeneric, expected: Tensor, - additional_args: Any = None, - baselines: BaselineType = None, - n_perturb_samples: int = 10, - target: TargetType = None, - max_batch_size: int = None, - multi_input: bool = True, - perturb_func: Callable = _local_perturb_func, - normalize: bool = False, + additional_args: Optional[Any] = None, + baselines: Optional[BaselineType] = None, + n_perturb_samples: Optional[int] = 10, + target: Optional[TargetType] = None, + max_batch_size: Optional[int] = None, + multi_input: Optional[bool] = True, + perturb_func: Optional[Callable] = _local_perturb_func, + normalize: Optional[bool] = False, **kwargs: Any, ) -> Tensor: infid = infidelity( diff --git a/tests/metrics/test_sensitivity.py b/tests/metrics/test_sensitivity.py index 2152e0d7ef..81f4d32d21 100644 --- a/tests/metrics/test_sensitivity.py +++ b/tests/metrics/test_sensitivity.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import typing -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, List, Optional, Tuple, Union import torch from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric @@ -281,10 +281,10 @@ def sensitivity_max_assert( expected_sensitivity: Tensor, perturb_func: Callable = _perturb_func, n_perturb_samples: int = 5, - max_examples_per_batch: int = None, - baselines: BaselineType = None, - target: TargetType = None, - additional_forward_args: Any = None, + max_examples_per_batch: Optional[int] = None, + baselines: Optional[BaselineType] = None, + target: Optional[TargetType] = None, + additional_forward_args: Optional[Any] = None, ) -> Tensor: if baselines is None: sens = sensitivity_max(