diff --git a/captum/_utils/common.py b/captum/_utils/common.py index f1b5fd9a7e..6336fc4a8a 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... def _is_tuple(inputs: Tensor) -> Literal[False]: ... +@typing.overload +def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ... + + def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: return isinstance(inputs, tuple) diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 598c031b2c..5381350033 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -2,25 +2,11 @@ # pyre-strict -from typing import ( - List, - Optional, - overload, - Protocol, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union from torch import Tensor from torch.nn import Module -if TYPE_CHECKING: - from typing import Literal -else: - Literal = {True: bool, False: bool, (True, False): bool, "pt": str} - TensorOrTupleOfTensorsGeneric = TypeVar( "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] ) diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py index 60359bc0dd..5dc444e455 100644 --- a/captum/attr/_core/guided_backprop_deconvnet.py +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -2,7 +2,7 @@ # pyre-strict import warnings -from typing import Any, Callable, List, Tuple, Union +from typing import Callable, List, Tuple, Union import torch import torch.nn.functional as F @@ -45,8 +45,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Computes attribution by overriding relu gradients. Based on constructor @@ -58,16 +57,10 @@ def attribute( # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) # set hooks for overriding ReLU gradients warnings.warn( @@ -79,14 +72,12 @@ def attribute( self.model.apply(self._register_hooks) gradients = self.gradient_func( - self.forward_func, inputs, target, additional_forward_args + self.forward_func, inputs_tuple, target, additional_forward_args ) finally: self._remove_hooks() - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, gradients) @@ -155,8 +146,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -265,8 +255,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: diff --git a/captum/attr/_core/guided_grad_cam.py b/captum/attr/_core/guided_grad_cam.py index bb9beb6a0b..d4d1978496 100644 --- a/captum/attr/_core/guided_grad_cam.py +++ b/captum/attr/_core/guided_grad_cam.py @@ -2,7 +2,7 @@ # pyre-strict import warnings -from typing import Any, List, Union +from typing import List, Union import torch from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple @@ -72,8 +72,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, interpolate_mode: str = "nearest", attribute_to_layer_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -181,15 +180,11 @@ def attribute( >>> # attribution size matches input size, Nx3x32x32 >>> attribution = guided_gc.attribute(input, 3) """ - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) grad_cam_attr = self.grad_cam.attribute.__wrapped__( self.grad_cam, # self - inputs=inputs, + inputs=inputs_tuple, target=target, additional_forward_args=additional_forward_args, attribute_to_layer_input=attribute_to_layer_input, @@ -204,20 +199,18 @@ def attribute( guided_backprop_attr = self.guided_backprop.attribute.__wrapped__( self.guided_backprop, # self - inputs=inputs, + inputs=inputs_tuple, target=target, additional_forward_args=additional_forward_args, ) output_attr: List[Tensor] = [] - for i in range(len(inputs)): + for i in range(len(inputs_tuple)): try: output_attr.append( guided_backprop_attr[i] * LayerAttribution.interpolate( grad_cam_attr, - # pyre-fixme[6]: For 2nd argument expected `Union[int, - # typing.Tuple[int, ...]]` but got `Size`. - inputs[i].shape[2:], + tuple(inputs_tuple[i].shape[2:]), interpolate_mode=interpolate_mode, ) ) diff --git a/captum/attr/_core/input_x_gradient.py b/captum/attr/_core/input_x_gradient.py index 86115bb03b..bfaa75def2 100644 --- a/captum/attr/_core/input_x_gradient.py +++ b/captum/attr/_core/input_x_gradient.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable +from typing import Callable from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple from captum._utils.gradient import ( @@ -11,6 +11,7 @@ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.log import log_usage +from torch import Tensor class InputXGradient(GradientAttribution): @@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution): https://arxiv.org/abs/1605.01713 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -35,8 +35,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -113,28 +112,20 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) gradients = self.gradient_func( - self.forward_func, inputs, target, additional_forward_args + self.forward_func, inputs_tuple, target, additional_forward_args ) attributions = tuple( - input * gradient for input, gradient in zip(inputs, gradients) + input * gradient for input, gradient in zip(inputs_tuple, gradients) ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py index e803262937..1abbcc69f7 100644 --- a/captum/attr/_core/integrated_gradients.py +++ b/captum/attr/_core/integrated_gradients.py @@ -2,7 +2,7 @@ # pyre-strict import typing -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Literal, Tuple, Union import torch from captum._utils.common import ( @@ -12,12 +12,7 @@ _format_output, _is_tuple, ) -from captum._utils.typing import ( - BaselineType, - Literal, - TargetType, - TensorOrTupleOfTensorsGeneric, -) +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.approximation_methods import approximation_parameters from captum.attr._utils.attribution import GradientAttribution from captum.attr._utils.batching import _batch_attribution @@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], multiply_by_inputs: bool = True, ) -> None: r""" @@ -80,21 +74,16 @@ def __init__( # and when return_convergence_delta is True, the return type is # a tuple with both attributions and deltas. @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `95`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @@ -111,9 +100,6 @@ def attribute( n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -275,37 +261,35 @@ def attribute( # type: ignore """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as # `Tuple[Tensor, ...]`. - inputs, baselines = _format_input_baseline(inputs, baselines) + formatted_inputs, formatted_baselines = _format_input_baseline( + inputs, baselines + ) # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got # `TensorOrTupleOfTensorsGeneric`. - _validate_input(inputs, baselines, n_steps, method) + _validate_input(formatted_inputs, formatted_baselines, n_steps, method) if internal_batch_size is not None: - num_examples = inputs[0].shape[0] + num_examples = formatted_inputs[0].shape[0] attributions = _batch_attribution( self, num_examples, internal_batch_size, n_steps, - inputs=inputs, - baselines=baselines, + inputs=formatted_inputs, + baselines=formatted_baselines, target=target, additional_forward_args=additional_forward_args, method=method, ) else: attributions = self._attribute( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but - # got `TensorOrTupleOfTensorsGeneric`. - inputs=inputs, - baselines=baselines, + inputs=formatted_inputs, + baselines=formatted_baselines, target=target, additional_forward_args=additional_forward_args, n_steps=n_steps, @@ -344,8 +328,7 @@ def _attribute( inputs: Tuple[Tensor, ...], baselines: Tuple[Union[Tensor, int, float], ...], target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, n_steps: int = 50, method: str = "gausslegendre", step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None, diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py index 8b6fb44cbf..89d22990d8 100644 --- a/captum/attr/_core/kernel_shap.py +++ b/captum/attr/_core/kernel_shap.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import Any, Callable, Generator, Tuple, Union +from typing import Callable, cast, Generator, Tuple, Union import torch from captum._utils.models.linear_model import SkLearnLinearRegression @@ -27,8 +27,7 @@ class KernelShap(Lime): https://arxiv.org/abs/1705.07874 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -50,8 +49,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, @@ -279,10 +277,7 @@ def attribute( # type: ignore ) num_features_list = torch.arange(num_interp_features, dtype=torch.float) denom = num_features_list * (num_interp_features - num_features_list) - # pyre-fixme[58]: `/` is not supported for operand types - # `int` and `torch._tensor.Tensor`. - probs = (num_interp_features - 1) / denom - # pyre-fixme[16]: `float` has no attribute `__setitem__`. + probs = torch.tensor((num_interp_features - 1)) / denom probs[0] = 0.0 return self._attribute_kwargs( inputs=inputs, @@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel( _, __, interpretable_sample: Tensor, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Tensor: assert ( "num_interp_features" in kwargs @@ -332,8 +326,7 @@ def kernel_shap_similarity_kernel( def kernel_shap_perturb_generator( self, original_inp: Union[Tensor, Tuple[Tensor, ...]], - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Generator[Tensor, None, None]: r""" Perturbations are sampled by the following process: @@ -361,11 +354,13 @@ def kernel_shap_perturb_generator( device = original_inp.device else: device = original_inp[0].device - num_features = kwargs["num_interp_features"] + num_features = cast(int, kwargs["num_interp_features"]) yield torch.ones(1, num_features, device=device, dtype=torch.long) yield torch.zeros(1, num_features, device=device, dtype=torch.long) while True: - num_selected_features = kwargs["num_select_distribution"].sample() + num_selected_features = cast( + Categorical, kwargs["num_select_distribution"] + ).sample() rand_vals = torch.randn(1, num_features) threshold = torch.kthvalue( rand_vals, num_features - num_selected_features diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 21bae8677e..152bae2c3a 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -6,7 +6,7 @@ import typing import warnings from collections.abc import Iterator -from typing import Any, Callable, cast, List, Optional, Tuple, Union +from typing import Any, Callable, cast, List, Literal, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -23,12 +23,7 @@ from captum._utils.models.linear_model import SkLearnLasso from captum._utils.models.model import Model from captum._utils.progress import progress -from captum._utils.typing import ( - BaselineType, - Literal, - TargetType, - TensorOrTupleOfTensorsGeneric, -) +from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution from captum.attr._utils.batching import _batch_example_iterator from captum.attr._utils.common import ( @@ -73,18 +68,18 @@ class LimeBase(PerturbationAttribution): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], interpretable_model: Model, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - similarity_func: Callable, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - perturb_func: Callable, + similarity_func: Callable[ + ..., + Union[float, Tensor], + ], + perturb_func: Callable[..., object], perturb_interpretable_space: bool, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - from_interp_rep_transform: Optional[Callable], - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - to_interp_rep_transform: Optional[Callable], + from_interp_rep_transform: Optional[ + Callable[..., Union[Tensor, Tuple[Tensor, ...]]] + ], + to_interp_rep_transform: Optional[Callable[..., Tensor]], ) -> None: r""" @@ -249,13 +244,11 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, n_samples: int = 50, perturbations_per_eval: int = 1, show_progress: bool = False, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> Tensor: r""" This method attributes the output of the model with given target index @@ -551,7 +544,7 @@ def generate_perturbation() -> ( curr_sample, inputs, **kwargs ) - return interpretable_inp, curr_model_input + return interpretable_inp, curr_model_input # type: ignore return generate_perturbation @@ -568,8 +561,7 @@ def _evaluate_batch( self, curr_model_inputs: List[TensorOrTupleOfTensorsGeneric], expanded_target: TargetType, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - expanded_additional_args: Any, + expanded_additional_args: object, device: torch.device, ) -> Tensor: model_out = _run_forward( @@ -630,8 +622,7 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs): def get_exp_kernel_similarity_function( distance_mode: str = "cosine", kernel_width: float = 1.0, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. -) -> Callable: +) -> Callable[..., float]: r""" This method constructs an appropriate similarity function to compute weights for perturbed sample in LIME. Distance between the original @@ -680,8 +671,9 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): return default_exp_kernel -# pyre-fixme[2]: Parameter must be annotated. -def default_perturb_func(original_inp, **kwargs) -> Tensor: +def default_perturb_func( + original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object +) -> Tensor: assert ( "num_interp_features" in kwargs ), "Must provide num_interp_features to use default interpretable sampling function" @@ -690,7 +682,7 @@ def default_perturb_func(original_inp, **kwargs) -> Tensor: else: device = original_inp[0].device - probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5 + probs = torch.ones(1, cast(int, kwargs["num_interp_features"])) * 0.5 return torch.bernoulli(probs).to(device=device).long() @@ -698,17 +690,17 @@ def construct_feature_mask( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]], formatted_inputs: Tuple[Tensor, ...], ) -> Tuple[Tuple[Tensor, ...], int]: + feature_mask_tuple: Tuple[Tensor, ...] if feature_mask is None: - feature_mask, num_interp_features = _construct_default_feature_mask( + feature_mask_tuple, num_interp_features = _construct_default_feature_mask( formatted_inputs ) else: - feature_mask = _format_tensor_into_tuples(feature_mask) + feature_mask_tuple = _format_tensor_into_tuples(feature_mask) min_interp_features = int( min( torch.min(single_mask).item() - # pyre-fixme[16]: `None` has no attribute `__iter__`. - for single_mask in feature_mask + for single_mask in feature_mask_tuple if single_mask.numel() ) ) @@ -718,14 +710,12 @@ def construct_feature_mask( " start at 0.", stacklevel=2, ) - feature_mask = tuple( - single_mask - min_interp_features for single_mask in feature_mask + feature_mask_tuple = tuple( + single_mask - min_interp_features for single_mask in feature_mask_tuple ) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `Optional[typing.Tuple[typing.Any, ...]]`. - num_interp_features = _get_max_feature_index(feature_mask) + 1 - return feature_mask, num_interp_features + num_interp_features = _get_max_feature_index(feature_mask_tuple) + 1 + return feature_mask_tuple, num_interp_features class Lime(LimeBase): @@ -766,8 +756,7 @@ class Lime(LimeBase): def __init__( self, - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - forward_func: Callable, + forward_func: Callable[..., Tensor], interpretable_model: Optional[Model] = None, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. similarity_func: Optional[Callable] = None, @@ -887,8 +876,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, @@ -1133,18 +1121,14 @@ def _attribute_kwargs( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, return_input_shape: bool = True, show_progress: bool = False, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, + **kwargs: object, ) -> TensorOrTupleOfTensorsGeneric: - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) bsz = formatted_inputs[0].shape[0] @@ -1263,33 +1247,35 @@ def _attribute_kwargs( # type: ignore return coefs @typing.overload - # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept - # all possible arguments of overload defined on line `1201`. def _convert_output_shape( self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], coefs: Tensor, num_interp_features: int, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[True], ) -> Tuple[Tensor, ...]: ... @typing.overload - # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept - # all possible arguments of overload defined on line `1211`. def _convert_output_shape( # type: ignore self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], coefs: Tensor, num_interp_features: int, - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[False], ) -> Tensor: ... + @typing.overload + def _convert_output_shape( + self, + formatted_inp: Tuple[Tensor, ...], + feature_mask: Tuple[Tensor, ...], + coefs: Tensor, + num_interp_features: int, + is_inputs_tuple: bool, + ) -> Union[Tensor, Tuple[Tensor, ...]]: ... + def _convert_output_shape( self, formatted_inp: Tuple[Tensor, ...], diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 06c2fd5ae7..d08b7b4de8 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -4,7 +4,7 @@ import typing from collections import defaultdict -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Union import torch.nn as nn from captum._utils.common import ( @@ -18,7 +18,7 @@ apply_gradient_requirements, undo_gradient_requirements, ) -from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.attr._utils.common import _sum_rows from captum.attr._utils.custom_modules import Addition_Module @@ -43,6 +43,12 @@ class LRP(GradientAttribution): Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW]. """ + verbose: bool = False + _original_state_dict: Dict[str, Any] = {} + layers: List[Module] = [] + backward_handles: List[RemovableHandle] = [] + forward_handles: List[RemovableHandle] = [] + def __init__(self, model: Module) -> None: r""" Args: @@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool: return True @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `75`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], verbose: bool = False, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `65`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + additional_forward_args: object = None, return_convergence_delta: Literal[False] = False, verbose: bool = False, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -100,7 +95,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - additional_forward_args: Any = None, + additional_forward_args: object = None, return_convergence_delta: bool = False, verbose: bool = False, ) -> Union[ @@ -199,35 +194,22 @@ def attribute( >>> attribution = lrp.attribute(input, target=5) """ - # pyre-fixme[16]: `LRP` has no attribute `verbose`. self.verbose = verbose - # pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`. self._original_state_dict = self.model.state_dict() - # pyre-fixme[16]: `LRP` has no attribute `layers`. - self.layers: List[Module] = [] + self.layers = [] self._get_layers(self.model) self._check_and_attach_rules() - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. self.backward_handles: List[RemovableHandle] = [] - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles: List[RemovableHandle] = [] - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + input_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(input_tuple) try: # 1. Forward pass: Change weights of layers according to selected rules. output = self._compute_output_and_change_weights( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but - # got `TensorOrTupleOfTensorsGeneric`. - inputs, + input_tuple, target, additional_forward_args, ) @@ -235,7 +217,7 @@ def attribute( # propagation and execute back-propagation. self._register_forward_hooks() normalized_relevances = self.gradient_func( - self._forward_fn_wrapper, inputs, target, additional_forward_args + self._forward_fn_wrapper, input_tuple, target, additional_forward_args ) relevances = tuple( normalized_relevance @@ -245,9 +227,7 @@ def attribute( finally: self._restore_model() - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(input_tuple, gradient_mask) if return_convergence_delta: # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... @@ -310,13 +290,11 @@ def compute_convergence_delta( def _get_layers(self, model: Module) -> None: for layer in model.children(): if len(list(layer.children())) == 0: - # pyre-fixme[16]: `LRP` has no attribute `layers`. self.layers.append(layer) else: self._get_layers(layer) def _check_and_attach_rules(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): layer.activations = {} # type: ignore @@ -355,50 +333,41 @@ def _check_rules(self) -> None: ) def _register_forward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if type(layer) in SUPPORTED_NON_LINEAR_LAYERS: backward_handles = _register_backward_hook( layer, PropagationRule.backward_hook_activation, self ) - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. self.backward_handles.extend(backward_handles) else: forward_handle = layer.register_forward_hook( layer.rule.forward_hook # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) - # pyre-fixme[16]: `LRP` has no attribute `verbose`. if self.verbose: print(f"Applied {layer.rule} on layer {layer}") def _register_weight_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_hook( layer.rule.forward_hook_weights # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _register_pre_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_pre_hook( layer.rule.forward_pre_hook_activations # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _compute_output_and_change_weights( self, inputs: Tuple[Tensor, ...], target: TargetType, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any, + additional_forward_args: object, ) -> Tensor: try: self._register_weight_hooks() @@ -416,15 +385,12 @@ def _compute_output_and_change_weights( return cast(Tensor, output) def _remove_forward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. for forward_handle in self.forward_handles: forward_handle.remove() def _remove_backward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. for backward_handle in self.backward_handles: backward_handle.remove() - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer.rule, "_handle_input_hooks"): for handle in layer.rule._handle_input_hooks: # type: ignore @@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None: layer.rule._handle_output_hook.remove() # type: ignore def _remove_rules(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): del layer.rule def _clear_properties(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "activation"): del layer.activation diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index 7247ccc00d..5d9eb19626 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -27,8 +27,7 @@ class NoiseTunnelType(Enum): vargrad = 3 -# pyre-fixme[5]: Global expression must be annotated. -SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys()) +SUPPORTED_NOISE_TUNNEL_TYPES: List[str] = list(NoiseTunnelType.__members__.keys()) class NoiseTunnel(Attribution): @@ -58,6 +57,10 @@ class NoiseTunnel(Attribution): It is assumed that the batch size is the first dimension of input tensors. """ + is_delta_supported: bool + _multiply_by_inputs: bool + is_gradient_method: bool + def __init__(self, attribution_method: Attribution) -> None: r""" Args: @@ -66,19 +69,15 @@ def __init__(self, attribution_method: Attribution) -> None: Conductance or Saliency. """ self.attribution_method = attribution_method - # pyre-fixme[4]: Attribute must be annotated. self.is_delta_supported = self.attribution_method.has_convergence_delta() - # pyre-fixme[4]: Attribute must be annotated. self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs - # pyre-fixme[4]: Attribute must be annotated. self.is_gradient_method = isinstance( self.attribution_method, GradientAttribution ) Attribution.__init__(self, self.attribution_method.forward_func) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @log_usage() @@ -205,9 +204,10 @@ def attribute( nt_samples_batch_size, kwargs_copy, inputs, draw_baseline_from_distrib ) - sum_attributions: List[Union[None, Tensor]] = [] - sum_attributions_sq: List[Union[None, Tensor]] = [] + sum_attributions: Sequence[Union[None, Tensor]] = [] + sum_attributions_sq: Sequence[Union[None, Tensor]] = [] delta_partial_list: List[Tensor] = [] + is_attrib_tuple = is_inputs_tuple for _ in range(nt_samples_partition): inputs_with_noise = self._add_noise_to_inputs( @@ -225,11 +225,7 @@ def attribute( ) if len(sum_attributions) == 0: - # pyre-fixme[9]: sum_attributions has type - # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions = [None] * len(attributions_partial) - # pyre-fixme[9]: sum_attributions_sq has type - # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions_sq = [None] * len(attributions_partial) self._update_partial_attribution_and_delta( @@ -297,7 +293,6 @@ def attribute( return self._apply_checks_and_return_attributions( attributions, - # pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined. is_attrib_tuple, return_convergence_delta, delta, @@ -348,9 +343,7 @@ def _add_noise_to_input( bsz = input.shape[0] # expand input size by the number of drawn samples - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` - # and `Size`. - input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:] + input_expanded_size = (bsz * nt_samples_partition,) + tuple(input.shape[1:]) # expand stdev for the shape of the input and number of drawn samples stdev_expanded = torch.tensor(stdev, device=input.device).repeat( @@ -375,14 +368,13 @@ def _update_sum_attribution_and_sq( bsz = attribution.shape[0] // nt_samples_batch_size_inter attribution_shape = cast(Tuple[int, ...], (bsz, nt_samples_batch_size_inter)) if len(attribution.shape) > 1: - # pyre-fixme[22]: The cast is redundant. - attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:])) + attribution_shape += tuple(attribution.shape[1:]) attribution = attribution.view(attribution_shape) current_attribution_sum = attribution.sum(dim=1, keepdim=False) - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False) + current_attribution_sq = torch.sum( + torch.pow(attribution, 2), dim=1, keepdim=False + ) sum_attribution[i] = ( current_attribution_sum @@ -398,8 +390,7 @@ def _update_sum_attribution_and_sq( def _compute_partial_attribution( self, inputs_with_noise_partition: Tuple[Tensor, ...], - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - kwargs_partition: Any, + kwargs_partition: object, is_inputs_tuple: bool, return_convergence_delta: bool, ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]: @@ -505,14 +496,12 @@ def _apply_checks_and_return_attributions( ) -> Union[ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] ]: - # pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor, - # ...]]`, expected a tuple. - attributions = _format_output(is_attrib_tuple, attributions) + attributions_tuple = _format_output(is_attrib_tuple, attributions) ret = ( - (attributions, cast(Tensor, delta)) + (attributions_tuple, cast(Tensor, delta)) if self.is_delta_supported and return_convergence_delta - else attributions + else attributions_tuple ) ret = cast( # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 33c1531108..62ac38e84d 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -37,8 +37,7 @@ class Occlusion(FeatureAblation): /tensorflow/methods.py#L401 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -58,8 +57,7 @@ def attribute( # type: ignore ] = None, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, perturbations_per_eval: int = 1, show_progress: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -377,9 +375,7 @@ def _occlusion_mask( padded_tensor = torch.nn.functional.pad( sliding_window_tsr, tuple(pad_values) # type: ignore ) - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and - # `Size`. - return padded_tensor.reshape((1,) + padded_tensor.shape) + return padded_tensor.reshape((1,) + tuple(padded_tensor.shape)) def _get_feature_range_and_mask( self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any @@ -389,8 +385,7 @@ def _get_feature_range_and_mask( def _get_feature_counts( self, - # pyre-fixme[2]: Parameter must be annotated. - inputs, + inputs: TensorOrTupleOfTensorsGeneric, feature_mask: Tuple[Tensor, ...], **kwargs: Any, ) -> Tuple[int, ...]: diff --git a/captum/attr/_core/saliency.py b/captum/attr/_core/saliency.py index 29205725c0..8698099db7 100644 --- a/captum/attr/_core/saliency.py +++ b/captum/attr/_core/saliency.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import Any, Callable +from typing import Callable import torch from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple @@ -13,6 +13,7 @@ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.log import log_usage +from torch import Tensor class Saliency(GradientAttribution): @@ -25,8 +26,7 @@ class Saliency(GradientAttribution): https://arxiv.org/abs/1312.6034 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -41,8 +41,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, abs: bool = True, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -124,29 +123,21 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) # No need to format additional_forward_args here. # They are being formated in the `_run_forward` function in `common.py` gradients = self.gradient_func( - self.forward_func, inputs, target, additional_forward_args + self.forward_func, inputs_tuple, target, additional_forward_args ) if abs: attributions = tuple(torch.abs(gradient) for gradient in gradients) else: attributions = gradients - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 081bd75067..8f8d791377 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -5,7 +5,7 @@ import itertools import math import warnings -from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union +from typing import Callable, cast, Iterable, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -56,9 +56,7 @@ def _shape_feature_mask( f"input shape: {inp.shape}, feature mask shape {mask.shape}" ) if mask.dim() < inp.dim(): - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, - # ...]` and `Size`. - mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + mask.shape) + mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + tuple(mask.shape)) mask_list.append(mask) @@ -89,8 +87,7 @@ class ShapleyValueSampling(PerturbationAttribution): https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None: r""" Args: @@ -111,8 +108,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, n_samples: int = 25, perturbations_per_eval: int = 1, @@ -301,45 +297,25 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs, baselines = _format_input_baseline(inputs, baselines) + inputs_tuple, baselines = _format_input_baseline(inputs, baselines) additional_forward_args = _format_additional_forward_args( additional_forward_args ) - # pyre-fixme[9]: feature_mask has type - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - feature_mask = _format_feature_mask(feature_mask, inputs) - # pyre-fixme[9]: feature_mask has type - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`. - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - feature_mask = _shape_feature_mask(feature_mask, inputs) + formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple) + reshaped_feature_mask = _shape_feature_mask( + formatted_feature_mask, inputs_tuple + ) assert ( isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 ), "Ablations per evaluation must be at least 1." with torch.no_grad(): - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - baselines = _tensorize_baseline(inputs, baselines) - num_examples = inputs[0].shape[0] + baselines = _tensorize_baseline(inputs_tuple, baselines) + num_examples = inputs_tuple[0].shape[0] - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`. - total_features = _get_max_feature_index(feature_mask) + 1 + total_features = _get_max_feature_index(reshaped_feature_mask) + 1 if show_progress: attr_progress = progress( @@ -362,7 +338,7 @@ def attribute( initial_eval, num_examples, perturbations_per_eval, - feature_mask, + reshaped_feature_mask, allow_multi_outputs=True, ) @@ -372,11 +348,11 @@ def attribute( # attr shape (*output_shape, *input_feature_shape) total_attrib = [ torch.zeros( - output_shape + input.shape[1:], + tuple(output_shape) + tuple(input.shape[1:]), dtype=torch.float, - device=inputs[0].device, + device=inputs_tuple[0].device, ) - for input in inputs + for input in inputs_tuple ] iter_count = 0 @@ -393,17 +369,11 @@ def attribute( current_target, current_masks, ) in self._perturbation_generator( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` - # but got `TensorOrTupleOfTensorsGeneric`. - inputs, + inputs_tuple, additional_forward_args, target, baselines, - # pyre-fixme[6]: For 5th argument expected - # `TensorOrTupleOfTensorsGeneric` but got - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`. - feature_mask, + reshaped_feature_mask, feature_permutation, perturbations_per_eval, ): @@ -445,10 +415,8 @@ def attribute( # have the same dim as the mask tensor. formatted_eval_diff = eval_diff.reshape( (-1,) - # pyre-fixme[58]: `+` is not supported for operand types - # `Tuple[int]` and `Size`. - + output_shape - + (len(inputs[j].shape) - 1) * (1,) + + tuple(output_shape) + + (len(inputs_tuple[j].shape) - 1) * (1,) ) # mask in shape (n_perturb, *mask_shape_broadcastable_to_input) @@ -460,11 +428,9 @@ def attribute( # ) cur_mask = current_masks[j] cur_mask = cur_mask.reshape( - cur_mask.shape[:2] + tuple(cur_mask.shape[:2]) + (len(output_shape) - 1) * (1,) - # pyre-fixme[58]: `+` is not supported for operand types - # `Tuple[int, ...]` and `Size`. - + cur_mask.shape[2:] + + tuple(cur_mask.shape[2:]) ) # aggregate n_perturb @@ -495,18 +461,16 @@ def attribute_future(self) -> Callable: "attribute_future is not implemented for ShapleyValueSampling" ) - # pyre-fixme[3]: Return annotation cannot contain `Any`. def _perturbation_generator( self, inputs: Tuple[Tensor, ...], - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_args: Any, + additional_args: object, target: TargetType, baselines: Tuple[Tensor, ...], input_masks: TensorOrTupleOfTensorsGeneric, feature_permutation: Sequence[int], perturbations_per_eval: int, - ) -> Iterable[Tuple[Tuple[Tensor, ...], Any, TargetType, Tuple[Tensor, ...]]]: + ) -> Iterable[Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...]]]: """ This method is a generator which yields each perturbation to be evaluated including inputs, additional_forward_args, targets, and mask. @@ -578,9 +542,9 @@ def _perturbation_generator( combined_masks, ) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): + def _get_n_evaluations( + self, total_features: int, n_samples: int, perturbations_per_eval: int + ) -> int: """return the total number of forward evaluations needed""" return math.ceil(total_features / perturbations_per_eval) * n_samples @@ -642,8 +606,7 @@ class ShapleyValues(ShapleyValueSampling): evaluations, and we plan to add this approach in the future. """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None: r""" Args: @@ -664,8 +627,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 92c1ccafb7..9cb38b10d8 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -82,6 +82,12 @@ def _format_input_baseline( # type: ignore ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... +@typing.overload +def _format_input_baseline( # type: ignore + inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType +) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... + + def _format_input_baseline( inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: @@ -236,6 +242,21 @@ def _compute_conv_delta_and_format_attrs( ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... +@typing.overload +def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: bool = False, +) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] +]: ... + + # FIXME: GradientAttribution is provided as a string due to a circular import. # This should be fixed when common is refactored into separate files. def _compute_conv_delta_and_format_attrs(