From 2b6f924431d58b8d15e61ef0a5d26041e8e58b67 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:44 -0700 Subject: [PATCH 1/6] Add additional overload signatures for shared methods to resolve pyre errors (#1406) Summary: Add a few additional overload signatures to shared methods for resolving pyre errors Also remove separate cases for typing Literal since the split was necessary due to previous support for Python < 3.8 Reviewed By: csauper Differential Revision: D64677349 --- captum/_utils/common.py | 4 ++++ captum/_utils/typing.py | 16 +--------------- captum/attr/_utils/common.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index f1b5fd9a7e..6336fc4a8a 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... def _is_tuple(inputs: Tensor) -> Literal[False]: ... +@typing.overload +def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ... + + def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: return isinstance(inputs, tuple) diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 598c031b2c..5381350033 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -2,25 +2,11 @@ # pyre-strict -from typing import ( - List, - Optional, - overload, - Protocol, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union from torch import Tensor from torch.nn import Module -if TYPE_CHECKING: - from typing import Literal -else: - Literal = {True: bool, False: bool, (True, False): bool, "pt": str} - TensorOrTupleOfTensorsGeneric = TypeVar( "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] ) diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 92c1ccafb7..9cb38b10d8 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -82,6 +82,12 @@ def _format_input_baseline( # type: ignore ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... +@typing.overload +def _format_input_baseline( # type: ignore + inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType +) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... + + def _format_input_baseline( inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: @@ -236,6 +242,21 @@ def _compute_conv_delta_and_format_attrs( ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... +@typing.overload +def _compute_conv_delta_and_format_attrs( + attr_algo: "GradientAttribution", + return_convergence_delta: bool, + attributions: Tuple[Tensor, ...], + start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], + end_point: Union[Tensor, Tuple[Tensor, ...]], + additional_forward_args: Any, + target: TargetType, + is_inputs_tuple: bool = False, +) -> Union[ + Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor] +]: ... + + # FIXME: GradientAttribution is provided as a string due to a circular import. # This should be fixed when common is refactored into separate files. 
def _compute_conv_delta_and_format_attrs( From 4bf1abac9f4732c15b6862a19a26713a7ff70a06 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:44 -0700 Subject: [PATCH 2/6] Fix pyre errors in Shapley Values (#1405) Summary: Initial work on fixing Pyre errors in Shapley Values Reviewed By: craymichael Differential Revision: D64677339 --- captum/attr/_core/shapley_value.py | 96 +++++++++--------------------- 1 file changed, 29 insertions(+), 67 deletions(-) diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 081bd75067..8f8d791377 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -5,7 +5,7 @@ import itertools import math import warnings -from typing import Any, Callable, cast, Iterable, Sequence, Tuple, Union +from typing import Callable, cast, Iterable, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -56,9 +56,7 @@ def _shape_feature_mask( f"input shape: {inp.shape}, feature mask shape {mask.shape}" ) if mask.dim() < inp.dim(): - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, - # ...]` and `Size`. - mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + mask.shape) + mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + tuple(mask.shape)) mask_list.append(mask) @@ -89,8 +87,7 @@ class ShapleyValueSampling(PerturbationAttribution): https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None: r""" Args: @@ -111,8 +108,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, n_samples: int = 25, perturbations_per_eval: int = 1, @@ -301,45 +297,25 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs, baselines = _format_input_baseline(inputs, baselines) + inputs_tuple, baselines = _format_input_baseline(inputs, baselines) additional_forward_args = _format_additional_forward_args( additional_forward_args ) - # pyre-fixme[9]: feature_mask has type - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - feature_mask = _format_feature_mask(feature_mask, inputs) - # pyre-fixme[9]: feature_mask has type - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`. - # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. 
- feature_mask = _shape_feature_mask(feature_mask, inputs) + formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple) + reshaped_feature_mask = _shape_feature_mask( + formatted_feature_mask, inputs_tuple + ) assert ( isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 ), "Ablations per evaluation must be at least 1." with torch.no_grad(): - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - baselines = _tensorize_baseline(inputs, baselines) - num_examples = inputs[0].shape[0] + baselines = _tensorize_baseline(inputs_tuple, baselines) + num_examples = inputs_tuple[0].shape[0] - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`. - total_features = _get_max_feature_index(feature_mask) + 1 + total_features = _get_max_feature_index(reshaped_feature_mask) + 1 if show_progress: attr_progress = progress( @@ -362,7 +338,7 @@ def attribute( initial_eval, num_examples, perturbations_per_eval, - feature_mask, + reshaped_feature_mask, allow_multi_outputs=True, ) @@ -372,11 +348,11 @@ def attribute( # attr shape (*output_shape, *input_feature_shape) total_attrib = [ torch.zeros( - output_shape + input.shape[1:], + tuple(output_shape) + tuple(input.shape[1:]), dtype=torch.float, - device=inputs[0].device, + device=inputs_tuple[0].device, ) - for input in inputs + for input in inputs_tuple ] iter_count = 0 @@ -393,17 +369,11 @@ def attribute( current_target, current_masks, ) in self._perturbation_generator( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` - # but got `TensorOrTupleOfTensorsGeneric`. - inputs, + inputs_tuple, additional_forward_args, target, baselines, - # pyre-fixme[6]: For 5th argument expected - # `TensorOrTupleOfTensorsGeneric` but got - # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, - # typing.Tuple[Tensor, ...]]]]`. - feature_mask, + reshaped_feature_mask, feature_permutation, perturbations_per_eval, ): @@ -445,10 +415,8 @@ def attribute( # have the same dim as the mask tensor. formatted_eval_diff = eval_diff.reshape( (-1,) - # pyre-fixme[58]: `+` is not supported for operand types - # `Tuple[int]` and `Size`. - + output_shape - + (len(inputs[j].shape) - 1) * (1,) + + tuple(output_shape) + + (len(inputs_tuple[j].shape) - 1) * (1,) ) # mask in shape (n_perturb, *mask_shape_broadcastable_to_input) @@ -460,11 +428,9 @@ def attribute( # ) cur_mask = current_masks[j] cur_mask = cur_mask.reshape( - cur_mask.shape[:2] + tuple(cur_mask.shape[:2]) + (len(output_shape) - 1) * (1,) - # pyre-fixme[58]: `+` is not supported for operand types - # `Tuple[int, ...]` and `Size`. - + cur_mask.shape[2:] + + tuple(cur_mask.shape[2:]) ) # aggregate n_perturb @@ -495,18 +461,16 @@ def attribute_future(self) -> Callable: "attribute_future is not implemented for ShapleyValueSampling" ) - # pyre-fixme[3]: Return annotation cannot contain `Any`. def _perturbation_generator( self, inputs: Tuple[Tensor, ...], - # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
- additional_args: Any, + additional_args: object, target: TargetType, baselines: Tuple[Tensor, ...], input_masks: TensorOrTupleOfTensorsGeneric, feature_permutation: Sequence[int], perturbations_per_eval: int, - ) -> Iterable[Tuple[Tuple[Tensor, ...], Any, TargetType, Tuple[Tensor, ...]]]: + ) -> Iterable[Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...]]]: """ This method is a generator which yields each perturbation to be evaluated including inputs, additional_forward_args, targets, and mask. @@ -578,9 +542,9 @@ def _perturbation_generator( combined_masks, ) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): + def _get_n_evaluations( + self, total_features: int, n_samples: int, perturbations_per_eval: int + ) -> int: """return the total number of forward evaluations needed""" return math.ceil(total_features / perturbations_per_eval) * n_samples @@ -642,8 +606,7 @@ class ShapleyValues(ShapleyValueSampling): evaluations, and we plan to add this approach in the future. """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Union[int, float, Tensor]]) -> None: r""" Args: @@ -664,8 +627,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, From 8921ef703759535fa6c4a172c0217f4461b34a54 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:44 -0700 Subject: [PATCH 3/6] Fix pyre errors in Saliency (#1404) Summary: Initial work on fixing Pyre errors in Shapley Values Reviewed By: craymichael Differential Revision: D64677352 --- captum/attr/_core/saliency.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/captum/attr/_core/saliency.py b/captum/attr/_core/saliency.py index 29205725c0..8698099db7 100644 --- a/captum/attr/_core/saliency.py +++ b/captum/attr/_core/saliency.py @@ -2,7 +2,7 @@ # pyre-strict -from typing import Any, Callable +from typing import Callable import torch from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple @@ -13,6 +13,7 @@ from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.log import log_usage +from torch import Tensor class Saliency(GradientAttribution): @@ -25,8 +26,7 @@ class Saliency(GradientAttribution): https://arxiv.org/abs/1312.6034 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -41,8 +41,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, abs: bool = True, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, ) -> TensorOrTupleOfTensorsGeneric: r""" Args: @@ -124,29 +123,21 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. 
- # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + inputs_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(inputs_tuple) # No need to format additional_forward_args here. # They are being formated in the `_run_forward` function in `common.py` gradients = self.gradient_func( - self.forward_func, inputs, target, additional_forward_args + self.forward_func, inputs_tuple, target, additional_forward_args ) if abs: attributions = tuple(torch.abs(gradient) for gradient in gradients) else: attributions = gradients - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(inputs_tuple, gradient_mask) # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) From b84d7baebfd3ad05617ebe160f701c68921eaee9 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:44 -0700 Subject: [PATCH 4/6] Fix pyre errors in Occlusion (#1403) Summary: Initial work on fixing Pyre errors in Occlusion Reviewed By: craymichael Differential Revision: D64677342 --- captum/attr/_core/occlusion.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 33c1531108..62ac38e84d 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -37,8 +37,7 @@ class Occlusion(FeatureAblation): /tensorflow/methods.py#L401 """ - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def __init__(self, forward_func: Callable) -> None: + def __init__(self, forward_func: Callable[..., Tensor]) -> None: r""" Args: @@ -58,8 +57,7 @@ def attribute( # type: ignore ] = None, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, perturbations_per_eval: int = 1, show_progress: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -377,9 +375,7 @@ def _occlusion_mask( padded_tensor = torch.nn.functional.pad( sliding_window_tsr, tuple(pad_values) # type: ignore ) - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and - # `Size`. - return padded_tensor.reshape((1,) + padded_tensor.shape) + return padded_tensor.reshape((1,) + tuple(padded_tensor.shape)) def _get_feature_range_and_mask( self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any @@ -389,8 +385,7 @@ def _get_feature_range_and_mask( def _get_feature_counts( self, - # pyre-fixme[2]: Parameter must be annotated. 
- inputs, + inputs: TensorOrTupleOfTensorsGeneric, feature_mask: Tuple[Tensor, ...], **kwargs: Any, ) -> Tuple[int, ...]: From e7bba25d0cd206cc9a47ad9d30216f252fcf33db Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:44 -0700 Subject: [PATCH 5/6] Fix pyre errors in NoiseTunnel (#1402) Summary: Initial work on fixing Pyre errors in Noise Tunnel Reviewed By: craymichael Differential Revision: D64677341 --- captum/attr/_core/noise_tunnel.py | 49 ++++++++++++------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index 7247ccc00d..5d9eb19626 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -27,8 +27,7 @@ class NoiseTunnelType(Enum): vargrad = 3 -# pyre-fixme[5]: Global expression must be annotated. -SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys()) +SUPPORTED_NOISE_TUNNEL_TYPES: List[str] = list(NoiseTunnelType.__members__.keys()) class NoiseTunnel(Attribution): @@ -58,6 +57,10 @@ class NoiseTunnel(Attribution): It is assumed that the batch size is the first dimension of input tensors. """ + is_delta_supported: bool + _multiply_by_inputs: bool + is_gradient_method: bool + def __init__(self, attribution_method: Attribution) -> None: r""" Args: @@ -66,19 +69,15 @@ def __init__(self, attribution_method: Attribution) -> None: Conductance or Saliency. """ self.attribution_method = attribution_method - # pyre-fixme[4]: Attribute must be annotated. self.is_delta_supported = self.attribution_method.has_convergence_delta() - # pyre-fixme[4]: Attribute must be annotated. self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs - # pyre-fixme[4]: Attribute must be annotated. self.is_gradient_method = isinstance( self.attribution_method, GradientAttribution ) Attribution.__init__(self, self.attribution_method.forward_func) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @log_usage() @@ -205,9 +204,10 @@ def attribute( nt_samples_batch_size, kwargs_copy, inputs, draw_baseline_from_distrib ) - sum_attributions: List[Union[None, Tensor]] = [] - sum_attributions_sq: List[Union[None, Tensor]] = [] + sum_attributions: Sequence[Union[None, Tensor]] = [] + sum_attributions_sq: Sequence[Union[None, Tensor]] = [] delta_partial_list: List[Tensor] = [] + is_attrib_tuple = is_inputs_tuple for _ in range(nt_samples_partition): inputs_with_noise = self._add_noise_to_inputs( @@ -225,11 +225,7 @@ def attribute( ) if len(sum_attributions) == 0: - # pyre-fixme[9]: sum_attributions has type - # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions = [None] * len(attributions_partial) - # pyre-fixme[9]: sum_attributions_sq has type - # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions_sq = [None] * len(attributions_partial) self._update_partial_attribution_and_delta( @@ -297,7 +293,6 @@ def attribute( return self._apply_checks_and_return_attributions( attributions, - # pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined. 
is_attrib_tuple, return_convergence_delta, delta, @@ -348,9 +343,7 @@ def _add_noise_to_input( bsz = input.shape[0] # expand input size by the number of drawn samples - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` - # and `Size`. - input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:] + input_expanded_size = (bsz * nt_samples_partition,) + tuple(input.shape[1:]) # expand stdev for the shape of the input and number of drawn samples stdev_expanded = torch.tensor(stdev, device=input.device).repeat( @@ -375,14 +368,13 @@ def _update_sum_attribution_and_sq( bsz = attribution.shape[0] // nt_samples_batch_size_inter attribution_shape = cast(Tuple[int, ...], (bsz, nt_samples_batch_size_inter)) if len(attribution.shape) > 1: - # pyre-fixme[22]: The cast is redundant. - attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:])) + attribution_shape += tuple(attribution.shape[1:]) attribution = attribution.view(attribution_shape) current_attribution_sum = attribution.sum(dim=1, keepdim=False) - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False) + current_attribution_sq = torch.sum( + torch.pow(attribution, 2), dim=1, keepdim=False + ) sum_attribution[i] = ( current_attribution_sum @@ -398,8 +390,7 @@ def _update_sum_attribution_and_sq( def _compute_partial_attribution( self, inputs_with_noise_partition: Tuple[Tensor, ...], - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - kwargs_partition: Any, + kwargs_partition: object, is_inputs_tuple: bool, return_convergence_delta: bool, ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]: @@ -505,14 +496,12 @@ def _apply_checks_and_return_attributions( ) -> Union[ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] ]: - # pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor, - # ...]]`, expected a tuple. 
- attributions = _format_output(is_attrib_tuple, attributions) + attributions_tuple = _format_output(is_attrib_tuple, attributions) ret = ( - (attributions, cast(Tensor, delta)) + (attributions_tuple, cast(Tensor, delta)) if self.is_delta_supported and return_convergence_delta - else attributions + else attributions_tuple ) ret = cast( # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: From 9e4c91e0f201bb541fd65a4d8ce28a14800acc3f Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Tue, 22 Oct 2024 10:07:44 -0700 Subject: [PATCH 6/6] Fix pyre errors in LRP (#1401) Summary: Initial work on fixing Pyre errors in LRP Reviewed By: craymichael Differential Revision: D64677351 --- captum/attr/_core/lrp.py | 72 ++++++++++------------------------------ 1 file changed, 18 insertions(+), 54 deletions(-) diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 06c2fd5ae7..d08b7b4de8 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -4,7 +4,7 @@ import typing from collections import defaultdict -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Literal, Tuple, Union import torch.nn as nn from captum._utils.common import ( @@ -18,7 +18,7 @@ apply_gradient_requirements, undo_gradient_requirements, ) -from captum._utils.typing import Literal, TargetType, TensorOrTupleOfTensorsGeneric +from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import GradientAttribution from captum.attr._utils.common import _sum_rows from captum.attr._utils.custom_modules import Addition_Module @@ -43,6 +43,12 @@ class LRP(GradientAttribution): Ancona et al. [https://openreview.net/forum?id=Sy21R9JAW]. """ + verbose: bool = False + _original_state_dict: Dict[str, Any] = {} + layers: List[Module] = [] + backward_handles: List[RemovableHandle] = [] + forward_handles: List[RemovableHandle] = [] + def __init__(self, model: Module) -> None: r""" Args: @@ -62,33 +68,22 @@ def multiplies_by_inputs(self) -> bool: return True @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `75`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, + additional_forward_args: object = None, *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], verbose: bool = False, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload - # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `65`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. - # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + additional_forward_args: object = None, return_convergence_delta: Literal[False] = False, verbose: bool = False, ) -> TensorOrTupleOfTensorsGeneric: ... 
@@ -100,7 +95,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - additional_forward_args: Any = None, + additional_forward_args: object = None, return_convergence_delta: bool = False, verbose: bool = False, ) -> Union[ @@ -199,35 +194,22 @@ def attribute( >>> attribution = lrp.attribute(input, target=5) """ - # pyre-fixme[16]: `LRP` has no attribute `verbose`. self.verbose = verbose - # pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`. self._original_state_dict = self.model.state_dict() - # pyre-fixme[16]: `LRP` has no attribute `layers`. - self.layers: List[Module] = [] + self.layers = [] self._get_layers(self.model) self._check_and_attach_rules() - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. self.backward_handles: List[RemovableHandle] = [] - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles: List[RemovableHandle] = [] - # pyre-fixme[6]: For 1st argument expected `Tensor` but got - # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) - # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as - # `Tuple[Tensor, ...]`. - inputs = _format_tensor_into_tuples(inputs) - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - gradient_mask = apply_gradient_requirements(inputs) + input_tuple = _format_tensor_into_tuples(inputs) + gradient_mask = apply_gradient_requirements(input_tuple) try: # 1. Forward pass: Change weights of layers according to selected rules. output = self._compute_output_and_change_weights( - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but - # got `TensorOrTupleOfTensorsGeneric`. - inputs, + input_tuple, target, additional_forward_args, ) @@ -235,7 +217,7 @@ def attribute( # propagation and execute back-propagation. self._register_forward_hooks() normalized_relevances = self.gradient_func( - self._forward_fn_wrapper, inputs, target, additional_forward_args + self._forward_fn_wrapper, input_tuple, target, additional_forward_args ) relevances = tuple( normalized_relevance @@ -245,9 +227,7 @@ def attribute( finally: self._restore_model() - # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got - # `TensorOrTupleOfTensorsGeneric`. - undo_gradient_requirements(inputs, gradient_mask) + undo_gradient_requirements(input_tuple, gradient_mask) if return_convergence_delta: # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... @@ -310,13 +290,11 @@ def compute_convergence_delta( def _get_layers(self, model: Module) -> None: for layer in model.children(): if len(list(layer.children())) == 0: - # pyre-fixme[16]: `LRP` has no attribute `layers`. self.layers.append(layer) else: self._get_layers(layer) def _check_and_attach_rules(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): layer.activations = {} # type: ignore @@ -355,50 +333,41 @@ def _check_rules(self) -> None: ) def _register_forward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if type(layer) in SUPPORTED_NON_LINEAR_LAYERS: backward_handles = _register_backward_hook( layer, PropagationRule.backward_hook_activation, self ) - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. 
self.backward_handles.extend(backward_handles) else: forward_handle = layer.register_forward_hook( layer.rule.forward_hook # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) - # pyre-fixme[16]: `LRP` has no attribute `verbose`. if self.verbose: print(f"Applied {layer.rule} on layer {layer}") def _register_weight_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_hook( layer.rule.forward_hook_weights # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _register_pre_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_pre_hook( layer.rule.forward_pre_hook_activations # type: ignore ) - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _compute_output_and_change_weights( self, inputs: Tuple[Tensor, ...], target: TargetType, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - additional_forward_args: Any, + additional_forward_args: object, ) -> Tensor: try: self._register_weight_hooks() @@ -416,15 +385,12 @@ def _compute_output_and_change_weights( return cast(Tensor, output) def _remove_forward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. for forward_handle in self.forward_handles: forward_handle.remove() def _remove_backward_hooks(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. for backward_handle in self.backward_handles: backward_handle.remove() - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer.rule, "_handle_input_hooks"): for handle in layer.rule._handle_input_hooks: # type: ignore @@ -433,13 +399,11 @@ def _remove_backward_hooks(self) -> None: layer.rule._handle_output_hook.remove() # type: ignore def _remove_rules(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): del layer.rule def _clear_properties(self) -> None: - # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "activation"): del layer.activation
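
The change running through these six patches is the same: broad `Any` and bare `Callable` annotations are replaced with precise types, and extra `typing.overload` signatures are added so that pyre can resolve calls made through the `TensorOrTupleOfTensorsGeneric` TypeVar without `# pyre-fixme` suppressions. The first patch also drops the old TYPE_CHECKING shim for `Literal`, since Python < 3.8 is no longer supported and `Literal` can be imported from `typing` directly. Below is a minimal, self-contained sketch of that overload pattern; the names mirror the helpers touched in the diffs above, but the snippet stands alone and is not Captum's actual module layout.

    # Illustrative sketch of the overload pattern applied in these patches.
    # Standalone example; not Captum's real module structure.
    from typing import Literal, Tuple, TypeVar, Union, overload

    import torch
    from torch import Tensor

    # A TypeVar constrained to either a single tensor or a tuple of tensors,
    # mirroring TensorOrTupleOfTensorsGeneric in captum/_utils/typing.py.
    TensorOrTupleOfTensorsGeneric = TypeVar(
        "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
    )

    @overload
    def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
    @overload
    def _is_tuple(inputs: Tensor) -> Literal[False]: ...
    @overload
    def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...

    def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
        # Single runtime implementation; the overloads above only guide the
        # type checker.
        return isinstance(inputs, tuple)

    def describe(inputs: TensorOrTupleOfTensorsGeneric) -> str:
        # Inside a generic caller the concrete type of `inputs` is not known
        # statically, so without the third overload the type checker finds no
        # matching signature and reports an error; with it, the call checks
        # cleanly while the Literal overloads still narrow concrete callers.
        return "tuple of tensors" if _is_tuple(inputs) else "single tensor"

    if __name__ == "__main__":
        t = torch.zeros(2, 3)
        print(describe(t))        # single tensor
        print(describe((t, t)))   # tuple of tensors

The new overloads for `_format_input_baseline` and `_compute_conv_delta_and_format_attrs` follow the same idea: each gains one extra signature covering the general case (the TypeVar for `inputs`, a plain `bool` for `is_inputs_tuple`), so generic `attribute` implementations such as those in shapley_value.py, saliency.py, and lrp.py can call these helpers without the `pyre-fixme[6]` and `pyre-fixme[9]` suppressions that the later patches delete.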