From 3f0eabcbe003a18fe63d6e23eb0b62179a3a7c16 Mon Sep 17 00:00:00 2001
From: Yuanshuo Cui
Date: Tue, 6 Aug 2024 17:06:44 -0700
Subject: [PATCH] Enable pyre for Captum open source part 2/2 (#1319)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1319

Continues the work in the stack of D59693589.

Ref: https://www.internalfb.com/intern/wiki/Python/Type-annotations-in-python/How_To:_Migrate_to_Pyre_Fast_By_Default_Per_Target_Type_Checking/

Differential Revision: D60837583
---
 captum/_utils/av.py | 9 ++
 captum/_utils/common.py | 98 +++++++++++++-
 captum/_utils/gradient.py | 97 +++++++++++++-
 captum/_utils/models/__init__.py | 1 +
 captum/_utils/models/linear_model/__init__.py | 1 +
 captum/_utils/models/linear_model/model.py | 37 ++++++
 captum/_utils/models/linear_model/train.py | 20 +++
 captum/_utils/models/model.py | 7 +-
 captum/_utils/progress.py | 41 ++++++
 captum/_utils/sample_gradient.py | 7 +
 captum/_utils/typing.py | 3 +
 captum/attr/__init__.py | 2 +
 captum/attr/_core/dataloader_attr.py | 28 +++-
 captum/attr/_core/deep_lift.py | 125 +++++++++++++++++-
 captum/attr/_core/feature_ablation.py | 56 +++++++-
 captum/attr/_core/feature_permutation.py | 10 +-
 captum/attr/_core/gradient_shap.py | 42 ++++++
 .../attr/_core/guided_backprop_deconvnet.py | 18 +++
 captum/attr/_core/guided_grad_cam.py | 11 ++
 captum/attr/_core/input_x_gradient.py | 15 +++
 captum/attr/_core/integrated_gradients.py | 28 ++++
 captum/attr/_core/kernel_shap.py | 16 ++-
 captum/attr/_core/layer/grad_cam.py | 19 +++
 captum/attr/_core/layer/internal_influence.py | 13 +-
 captum/attr/_core/layer/layer_activation.py | 7 +
 captum/attr/_core/layer/layer_conductance.py | 20 +++
 captum/attr/_core/layer/layer_deep_lift.py | 42 +++++-
 .../_core/layer/layer_feature_ablation.py | 10 ++
 .../_core/layer/layer_feature_permutation.py | 9 ++
 .../attr/_core/layer/layer_gradient_shap.py | 38 ++++++
 .../layer/layer_gradient_x_activation.py | 9 ++
 .../_core/layer/layer_integrated_gradients.py | 40 +++++-
 captum/attr/_core/layer/layer_lrp.py | 38 +++++-
 captum/attr/_core/lime.py | 46 ++++++-
 captum/attr/_core/llm_attr.py | 51 +++++++
 captum/attr/_core/lrp.py | 50 ++++++-
 .../attr/_core/neuron/neuron_conductance.py | 19 +++
 captum/attr/_core/neuron/neuron_deep_lift.py | 8 ++
 .../_core/neuron/neuron_feature_ablation.py | 6 +
 captum/attr/_core/neuron/neuron_gradient.py | 15 +++
 .../attr/_core/neuron/neuron_gradient_shap.py | 6 +
 .../neuron_guided_backprop_deconvnet.py | 6 +
 .../neuron/neuron_integrated_gradients.py | 6 +
 captum/attr/_core/noise_tunnel.py | 43 +++++-
 captum/attr/_core/occlusion.py | 8 ++
 captum/attr/_core/saliency.py | 14 ++
 captum/attr/_core/shapley_value.py | 53 +++++++-
 captum/attr/_models/base.py | 16 +++
 captum/attr/_models/pytext.py | 44 +++++-
 captum/attr/_utils/approximation_methods.py | 7 +
 captum/attr/_utils/attribution.py | 20 +++
 captum/attr/_utils/baselines.py | 7 +
 captum/attr/_utils/batching.py | 33 ++++-
 captum/attr/_utils/class_summarizer.py | 7 +
 captum/attr/_utils/common.py | 18 +++
 captum/attr/_utils/custom_modules.py | 4 +
 captum/attr/_utils/input_layer_wrapper.py | 7 +
 captum/attr/_utils/interpretable_input.py | 37 ++++++
 captum/attr/_utils/lrp_rules.py | 49 +++++++
 captum/attr/_utils/stat.py | 52 +++++++-
 captum/attr/_utils/summarizer.py | 19 +++
 captum/attr/_utils/visualization.py | 42 ++++++
 62 files changed, 1577 insertions(+), 33 deletions(-)

diff --git a/captum/_utils/av.py b/captum/_utils/av.py
index c5ecc5325f..376bac1f46 100644
--- a/captum/_utils/av.py
+++ b/captum/_utils/av.py
@@ -1,5
+1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import glob import os import re @@ -66,12 +68,14 @@ def __init__( which the activation vectors are computed """ + # pyre-fixme[4]: Attribute must be annotated. self.av_filesearch = AV._construct_file_search( path, model_id, identifier, layer, num_id ) files = glob.glob(self.av_filesearch) + # pyre-fixme[4]: Attribute must be annotated. self.files = AV.sort_files(files) def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]: @@ -346,6 +350,7 @@ def _compute_and_save_activations( inputs: Union[Tensor, Tuple[Tensor, ...]], identifier: str, num_id: str, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, load_from_disk: bool = True, ) -> None: @@ -395,6 +400,8 @@ def _compute_and_save_activations( AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id) @staticmethod + # pyre-fixme[3]: Return annotation cannot be `Any`. + # pyre-fixme[2]: Parameter annotation cannot be `Any`. def _unpack_data(data: Union[Any, Tuple[Any, Any]]) -> Any: r""" Helper to extract input from labels when getting items from a Dataset. Assumes @@ -490,6 +497,8 @@ def sort_files(files: List[str]) -> List[str]: lexigraphical sort. """ + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def split_alphanum(s): r""" Splits string into a list of strings and numbers diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 052ccb49e8..e225a33379 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from enum import Enum from functools import reduce @@ -68,10 +70,18 @@ def safe_div( @typing.overload +# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`) +# is incompatible with the return type of the implementation (`bool`). +# pyre-fixme[31]: Expression `Literal[False]` is not a valid type. +# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. def _is_tuple(inputs: Tensor) -> Literal[False]: ... @typing.overload +# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`) +# is incompatible with the return type of the implementation (`bool`). +# pyre-fixme[31]: Expression `Literal[True]` is not a valid type. +# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... @@ -230,6 +240,8 @@ def _format_tensor_into_tuples( return inputs +# pyre-fixme[3]: Return annotation cannot be `Any`. +# pyre-fixme[2]: Parameter annotation cannot be `Any`. def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any: return ( inputs @@ -257,16 +269,21 @@ def _format_additional_forward_args(additional_forward_args: None) -> None: ... @overload def _format_additional_forward_args( + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. additional_forward_args: Union[Tensor, Tuple] + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. ) -> Tuple: ... @overload def _format_additional_forward_args( + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. ) -> Union[None, Tuple]: ... +# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. 
def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]: if additional_forward_args is not None and not isinstance( additional_forward_args, tuple @@ -276,9 +293,11 @@ def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, def _expand_additional_forward_args( + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, n_steps: int, expansion_type: ExpansionTypes = ExpansionTypes.repeat, + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. ) -> Union[None, Tuple]: def _expand_tensor_forward_arg( additional_forward_arg: Tensor, @@ -343,9 +362,12 @@ def _expand_target( return target +# pyre-fixme[3]: Return type must be annotated. def _expand_feature_mask( feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int ): + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, + # typing.Tuple[Tensor, ...]]`. is_feature_mask_tuple = _is_tuple(feature_mask) feature_mask = _format_tensor_into_tuples(feature_mask) feature_mask_new = tuple( @@ -359,12 +381,17 @@ def _expand_feature_mask( return _format_output(is_feature_mask_tuple, feature_mask_new) +# pyre-fixme[3]: Return type must be annotated. def _expand_and_update_baselines( inputs: Tuple[Tensor, ...], n_samples: int, + # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use + # `typing.Dict[, ]` to avoid runtime subscripting errors. kwargs: dict, draw_baseline_from_distrib: bool = False, ): + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def get_random_baseline_indices(bsz, baseline): num_ref_samples = baseline.shape[0] return np.random.choice(num_ref_samples, n_samples * bsz).tolist() @@ -404,6 +431,9 @@ def get_random_baseline_indices(bsz, baseline): kwargs["baselines"] = baselines +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use +# `typing.Dict[, ]` to avoid runtime subscripting errors. def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict): if "additional_forward_args" not in kwargs: return @@ -420,6 +450,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict): kwargs["additional_forward_args"] = additional_forward_args +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use +# `typing.Dict[, ]` to avoid runtime subscripting errors. def _expand_and_update_target(n_samples: int, kwargs: dict): if "target" not in kwargs: return @@ -431,6 +464,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict): kwargs["target"] = target +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use +# `typing.Dict[, ]` to avoid runtime subscripting errors. def _expand_and_update_feature_mask(n_samples: int, kwargs: dict): if "feature_mask" not in kwargs: return @@ -444,14 +480,24 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict): @typing.overload +# pyre-fixme[43]: The implementation of `_format_output` does not accept all +# possible arguments of overload defined on line `449`. def _format_output( - is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...] + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + is_inputs_tuple: Literal[True], + output: Tuple[Tensor, ...], ) -> Tuple[Tensor, ...]: ... 
@typing.overload +# pyre-fixme[43]: The implementation of `_format_output` does not accept all +# possible arguments of overload defined on line `455`. def _format_output( - is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...] + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + is_inputs_tuple: Literal[False], + output: Tuple[Tensor, ...], ) -> Tensor: ... @@ -474,18 +520,30 @@ def _format_output( "The input is a single tensor however the output isn't." "The number of output tensors is: {}".format(len(output)) ) + # pyre-fixme[7]: Expected `Union[Tensor, typing.Tuple[Tensor, ...]]` but got + # `Union[tuple[Tensor], Tensor]`. return output if is_inputs_tuple else output[0] @typing.overload +# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all +# possible arguments of overload defined on line `483`. def _format_outputs( - is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]] + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + is_multiple_inputs: Literal[False], + outputs: List[Tuple[Tensor, ...]], ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @typing.overload +# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all +# possible arguments of overload defined on line `489`. def _format_outputs( - is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]] + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. + is_multiple_inputs: Literal[True], + outputs: List[Tuple[Tensor, ...]], ) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ... @@ -512,9 +570,12 @@ def _format_outputs( def _run_forward( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. inputs: Any, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> Union[Tensor, Future[Tensor]]: forward_func_args = signature(forward_func).parameters @@ -529,6 +590,8 @@ def _run_forward( output = forward_func( *( + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + # tuples: `*inputs, *additional_forward_args`. (*inputs, *additional_forward_args) if additional_forward_args is not None else inputs @@ -606,6 +669,8 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor: elif isinstance(target[0], tuple): return torch.stack( [ + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type + # parameter. output[(i,) + cast(Tuple, targ_elem)] for i, targ_elem in enumerate(target) ] @@ -639,9 +704,11 @@ def _verify_select_column( def _verify_select_neuron( layer_output: Tuple[Tensor, ...], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. selector: Union[int, Tuple[Union[int, slice], ...], Callable], ) -> Tensor: if callable(selector): + # pyre-fixme[7]: Expected `Tensor` but got `object`. return selector(layer_output if len(layer_output) > 1 else layer_output[0]) assert len(layer_output) == 1, ( @@ -688,6 +755,9 @@ def _extract_device( def _reduce_list( val_list: Sequence[TupleOrTensorOrBoolGeneric], + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 
+ # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use + # `typing.List[]` to avoid runtime subscripting errors. red_func: Callable[[List], Any] = torch.cat, ) -> TupleOrTensorOrBoolGeneric: """ @@ -702,14 +772,20 @@ def _reduce_list( """ assert len(val_list) > 0, "Cannot reduce empty list!" if isinstance(val_list[0], torch.Tensor): + # pyre-fixme[16]: `bool` has no attribute `device`. first_device = val_list[0].device + # pyre-fixme[16]: `bool` has no attribute `to`. return red_func([elem.to(first_device) for elem in val_list]) elif isinstance(val_list[0], bool): + # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`. return any(val_list) elif isinstance(val_list[0], tuple): final_out = [] + # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]` + # but got `TupleOrTensorOrBoolGeneric`. for i in range(len(val_list[0])): final_out.append( + # pyre-fixme[16]: `bool` has no attribute `__getitem__`. _reduce_list([val_elem[i] for val_elem in val_list], red_func) ) else: @@ -717,6 +793,7 @@ def _reduce_list( "Elements to be reduced can only be" "either Tensors or tuples containing Tensors." ) + # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`. return tuple(final_out) @@ -756,6 +833,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor: return torch.cat([single_inp.flatten() for single_inp in inp]) +# pyre-fixme[3]: Return annotation cannot be `Any`. def _get_module_from_name(model: Module, layer_name: str) -> Any: r""" Returns the module (layer) object, given its (string) name @@ -772,7 +850,11 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any: def _register_backward_hook( - module: Module, hook: Callable, attr_obj: Any + module: Module, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + hook: Callable, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + attr_obj: Any, ) -> List[torch.utils.hooks.RemovableHandle]: grad_out: Dict[device, Tensor] = {} @@ -784,6 +866,7 @@ def forward_hook( nonlocal grad_out grad_out = {} + # pyre-fixme[53]: Captured variable `grad_out` is not annotated. def output_tensor_hook(output_grad: Tensor) -> None: grad_out[output_grad.device] = output_grad @@ -795,7 +878,11 @@ def output_tensor_hook(output_grad: Tensor) -> None: else: out.register_hook(output_tensor_hook) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def pre_hook(module, inp): + # pyre-fixme[53]: Captured variable `module` is not annotated. + # pyre-fixme[3]: Return type must be annotated. def input_tensor_hook(input_grad: Tensor): if len(grad_out) == 0: return @@ -820,6 +907,7 @@ def input_tensor_hook(input_grad: Tensor): ] +# pyre-fixme[3]: Return type must be annotated. def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]): """ Returns the max feature mask index diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 84301e1e82..7c9104d88c 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import threading import typing import warnings @@ -88,9 +90,11 @@ def undo_gradient_requirements( def compute_gradients( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], target_ind: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
additional_forward_args: Any = None, ) -> Tuple[Tensor, ...]: r""" @@ -130,6 +134,7 @@ def _neuron_gradients( inputs: Union[Tensor, Tuple[Tensor, ...]], saved_layer: Dict[device, Tuple[Tensor, ...]], key_list: List[device], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], ) -> Tuple[Tensor, ...]: with torch.autograd.set_grad_enabled(True): @@ -153,10 +158,14 @@ def _neuron_gradients( @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all +# possible arguments of overload defined on line `158`. def _forward_layer_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: Module, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, @@ -165,7 +174,10 @@ def _forward_layer_eval( @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all +# possible arguments of overload defined on line `170`. def _forward_layer_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: List[Module], @@ -177,6 +189,7 @@ def _forward_layer_eval( def _forward_layer_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: ModuleOrModuleList, @@ -188,6 +201,8 @@ def _forward_layer_eval( return _forward_layer_eval_with_neuron_grads( forward_fn, inputs, + # pyre-fixme[6]: For 3rd argument expected `Module` but got + # `ModuleOrModuleList`. layer, additional_forward_args=additional_forward_args, gradient_neuron_selector=None, @@ -198,20 +213,31 @@ def _forward_layer_eval( @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_distributed_eval` does not +# accept all possible arguments of overload defined on line `203`. def _forward_layer_distributed_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. inputs: Any, layer: ModuleOrModuleList, target_ind: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_layer_input: bool = False, + # pyre-fixme[9]: forward_hook_with_return has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. forward_hook_with_return: Literal[False] = False, require_layer_grads: bool = False, ) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: ... @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_distributed_eval` does not +# accept all possible arguments of overload defined on line `216`. def _forward_layer_distributed_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Any, layer: ModuleOrModuleList, @@ -219,12 +245,15 @@ def _forward_layer_distributed_eval( additional_forward_args: Any = None, attribute_to_layer_input: bool = False, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. 
forward_hook_with_return: Literal[True], require_layer_grads: bool = False, ) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: ... def _forward_layer_distributed_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Any, layer: ModuleOrModuleList, @@ -249,13 +278,22 @@ def _forward_layer_distributed_eval( """ saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]] = defaultdict(dict) lock = threading.Lock() + # pyre-fixme[9]: all_layers has type `List[Module]`; used as + # `Union[List[Variable[ModuleOrModuleList <: [Module, List[Module]]]], + # Variable[ModuleOrModuleList <: [Module, List[Module]]]]`. all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer # Set a forward hook on specified module and run forward pass to # get layer output tensor(s). # For DataParallel models, each partition adds entry to dictionary # with key as device and value as corresponding Tensor. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def hook_wrapper(original_module): + # pyre-fixme[53]: Captured variable `lock` is not annotated. + # pyre-fixme[53]: Captured variable `original_module` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward_hook(module, inp, out=None): eval_tsrs = inp if attribute_to_layer_input else out is_eval_tuple = isinstance(eval_tsrs, tuple) @@ -339,6 +377,7 @@ def _gather_distributed_tensors( def _extract_device_ids( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, saved_layer: Dict[Module, Dict[device, Tuple[Tensor, ...]]], device_ids: Union[None, List[int]], @@ -358,8 +397,10 @@ def _extract_device_ids( ): if ( hasattr(forward_fn, "device_ids") + # pyre-fixme[33]: Given annotation cannot be `Any`. and cast(Any, forward_fn).device_ids is not None ): + # pyre-fixme[33]: Given annotation cannot be `Any`. device_ids = cast(Any, forward_fn).device_ids else: raise AssertionError( @@ -373,12 +414,17 @@ def _extract_device_ids( @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does +# not accept all possible arguments of overload defined on line `378`. def _forward_layer_eval_with_neuron_grads( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: Module, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, *, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], grad_enabled: bool = False, device_ids: Union[None, List[int]] = None, @@ -387,7 +433,10 @@ def _forward_layer_eval_with_neuron_grads( @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does +# not accept all possible arguments of overload defined on line `392`. def _forward_layer_eval_with_neuron_grads( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: Module, @@ -400,7 +449,10 @@ def _forward_layer_eval_with_neuron_grads( @typing.overload +# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does +# not accept all possible arguments of overload defined on line `405`. 
def _forward_layer_eval_with_neuron_grads( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: List[Module], @@ -413,10 +465,12 @@ def _forward_layer_eval_with_neuron_grads( def _forward_layer_eval_with_neuron_grads( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], layer: ModuleOrModuleList, additional_forward_args: Any = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. gradient_neuron_selector: Union[ None, int, Tuple[Union[int, slice], ...], Callable ] = None, @@ -481,23 +535,32 @@ def _forward_layer_eval_with_neuron_grads( @typing.overload +# pyre-fixme[43]: The implementation of `compute_layer_gradients_and_eval` does not +# accept all possible arguments of overload defined on line `486`. def compute_layer_gradients_and_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, layer: Module, inputs: Union[Tensor, Tuple[Tensor, ...]], target_ind: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, *, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. output_fn: Union[None, Callable] = None, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], Tuple[Tensor, ...]]: ... @typing.overload +# pyre-fixme[43]: The implementation of `compute_layer_gradients_and_eval` does not +# accept all possible arguments of overload defined on line `502`. def compute_layer_gradients_and_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, layer: List[Module], inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -506,13 +569,17 @@ def compute_layer_gradients_and_eval( gradient_neuron_selector: None = None, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. output_fn: Union[None, Callable] = None, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[Tuple[Tensor, ...]], List[Tuple[Tensor, ...]]]: ... @typing.overload +# pyre-fixme[43]: The implementation of `compute_layer_gradients_and_eval` does not +# accept all possible arguments of overload defined on line `517`. def compute_layer_gradients_and_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, layer: Module, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -521,22 +588,26 @@ def compute_layer_gradients_and_eval( gradient_neuron_selector: None = None, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. output_fn: Union[None, Callable] = None, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ... def compute_layer_gradients_and_eval( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
forward_fn: Callable, layer: ModuleOrModuleList, inputs: Union[Tensor, Tuple[Tensor, ...]], target_ind: TargetType = None, additional_forward_args: Any = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. gradient_neuron_selector: Union[ None, int, Tuple[Union[int, slice], ...], Callable ] = None, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. output_fn: Union[None, Callable] = None, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ @@ -603,6 +674,7 @@ def compute_layer_gradients_and_eval( target_ind=target_ind, additional_forward_args=additional_forward_args, attribute_to_layer_input=attribute_to_layer_input, + # pyre-fixme[6]: For 7th argument expected `Literal[]` but got `bool`. forward_hook_with_return=True, require_layer_grads=True, ) @@ -611,6 +683,8 @@ def compute_layer_gradients_and_eval( " take gradient with respect to multiple outputs." ) + # pyre-fixme[6]: For 2nd argument expected `Dict[Module, Dict[device, + # typing.Tuple[Tensor, ...]]]` but got `Module`. device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids) # Identifies correct device ordering based on device ids. @@ -645,6 +719,9 @@ def compute_layer_gradients_and_eval( ) for single_layer in layer ] + # pyre-fixme[9]: all_layers has type `List[Module]`; used as + # `Union[List[Variable[ModuleOrModuleList <: [Module, List[Module]]]], + # Variable[ModuleOrModuleList <: [Module, List[Module]]]]`. all_layers: List[Module] = [layer] if isinstance(layer, Module) else layer grad_inputs = tuple( layer_tensor @@ -653,6 +730,7 @@ def compute_layer_gradients_and_eval( for layer_tensor in saved_layer[single_layer][device_id] ) saved_grads = torch.autograd.grad( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Module`. outputs=torch.unbind(output), inputs=grad_inputs, **grad_kwargs or {}, @@ -698,14 +776,18 @@ def compute_layer_gradients_and_eval( def construct_neuron_grad_fn( layer: Module, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], device_ids: Union[None, List[int]] = None, attribute_to_neuron_input: bool = False, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. ) -> Callable: def grad_fn( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: TensorOrTupleOfTensorsGeneric, target_ind: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> Tuple[Tensor, ...]: _, grads = _forward_layer_eval_with_neuron_grads( @@ -722,6 +804,8 @@ def grad_fn( return grad_fn +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def _extract_parameters_from_layers(layer_modules): layer_parameters = [] if layer_modules is not None: @@ -738,8 +822,10 @@ def _extract_parameters_from_layers(layer_modules): def _compute_jacobian_wrt_params( model: Module, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. inputs: Tuple[Any, ...], labels: Optional[Tensor] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
loss_fn: Optional[Union[Module, Callable]] = None, layer_modules: Optional[List[Module]] = None, ) -> Tuple[Tensor, ...]: @@ -795,6 +881,8 @@ def _compute_jacobian_wrt_params( outputs=out[i], inputs=cast( Union[Tensor, Sequence[Tensor]], + # pyre-fixme[61]: `layer_parameters` is undefined, or not always + # defined. model.parameters() if layer_modules is None else layer_parameters, ), grad_outputs=torch.ones_like(out[i]), @@ -807,10 +895,13 @@ def _compute_jacobian_wrt_params( return tuple(grads) +# pyre-fixme[3]: Return annotation cannot contain `Any`. def _compute_jacobian_wrt_params_with_sample_wise_trick( model: Module, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. inputs: Tuple[Any, ...], labels: Optional[Tensor] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. loss_fn: Optional[Union[Module, Callable]] = None, reduction_type: Optional[str] = "sum", layer_modules: Optional[List[Module]] = None, @@ -902,7 +993,11 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick( grads = tuple( param.sample_grad # type: ignore for param in ( - model.parameters() if layer_modules is None else layer_parameters + model.parameters() + if layer_modules is None + # pyre-fixme[61]: `layer_parameters` is undefined, or not always + # defined. + else layer_parameters ) if hasattr(param, "sample_grad") ) diff --git a/captum/_utils/models/__init__.py b/captum/_utils/models/__init__.py index 6c936bc955..3ce0193126 100644 --- a/captum/_utils/models/__init__.py +++ b/captum/_utils/models/__init__.py @@ -1,3 +1,4 @@ +# pyre-strict from captum._utils.models.model import Model __all__ = [ diff --git a/captum/_utils/models/linear_model/__init__.py b/captum/_utils/models/linear_model/__init__.py index d4f50d2146..64b77741ec 100644 --- a/captum/_utils/models/linear_model/__init__.py +++ b/captum/_utils/models/linear_model/__init__.py @@ -1,3 +1,4 @@ +# pyre-strict from captum._utils.models.linear_model.model import ( LinearModel, SGDLasso, diff --git a/captum/_utils/models/linear_model/model.py b/captum/_utils/models/linear_model/model.py index 24302d540c..6008fe983d 100644 --- a/captum/_utils/models/linear_model/model.py +++ b/captum/_utils/models/linear_model/model.py @@ -1,3 +1,4 @@ +# pyre-strict from typing import Callable, cast, List, Optional import torch.nn as nn @@ -9,6 +10,8 @@ class LinearModel(nn.Module, Model): SUPPORTED_NORMS: List[Optional[str]] = [None, "batch_norm", "layer_norm"] + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, train_fn: Callable, **kwargs) -> None: r""" Constructs a linear model with a training function and additional @@ -35,8 +38,10 @@ def __init__(self, train_fn: Callable, **kwargs) -> None: self.norm: Optional[nn.Module] = None self.linear: Optional[nn.Linear] = None self.train_fn = train_fn + # pyre-fixme[4]: Attribute must be annotated. self.construct_kwargs = kwargs + # pyre-fixme[3]: Return type must be annotated. def _construct_model_params( self, in_features: Optional[int] = None, @@ -114,8 +119,11 @@ def _construct_model_params( self.linear.bias.data = bias_value if classes is not None: + # pyre-fixme[16]: `Optional` has no attribute `classes`. self.linear.classes = classes + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
def fit(self, train_data: DataLoader, **kwargs): r""" Calls `self.train_fn` @@ -131,6 +139,7 @@ def forward(self, x: Tensor) -> Tensor: assert self.linear is not None if self.norm is not None: x = self.norm(x) + # pyre-fixme[29]: `Optional[nn.modules.linear.Linear]` is not a function. return self.linear(x) def representation(self) -> Tensor: @@ -156,6 +165,7 @@ def classes(self) -> Optional[Tensor]: class SGDLinearModel(LinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class. Construct a a `LinearModel` with the @@ -174,6 +184,7 @@ def __init__(self, **kwargs) -> None: class SGDLasso(SGDLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class to train a `LinearModel` with SGD @@ -186,6 +197,8 @@ def __init__(self, **kwargs) -> None: """ super().__init__(**kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): # avoid cycles from captum._utils.models.linear_model.train import l2_loss @@ -194,6 +207,7 @@ def fit(self, train_data: DataLoader, **kwargs): class SGDRidge(SGDLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class to train a `LinearModel` with SGD @@ -203,6 +217,8 @@ def __init__(self, **kwargs) -> None: """ super().__init__(**kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): # avoid cycles from captum._utils.models.linear_model.train import l2_loss @@ -211,6 +227,7 @@ def fit(self, train_data: DataLoader, **kwargs): class SGDLinearRegression(SGDLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class to train a `LinearModel` with SGD @@ -219,6 +236,8 @@ def __init__(self, **kwargs) -> None: """ super().__init__(**kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): # avoid cycles from captum._utils.models.linear_model.train import l2_loss @@ -229,6 +248,7 @@ def fit(self, train_data: DataLoader, **kwargs): class SkLearnLinearModel(LinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, sklearn_module: str, **kwargs) -> None: r""" Factory class to construct a `LinearModel` with sklearn training method. @@ -259,6 +279,8 @@ def __init__(self, sklearn_module: str, **kwargs) -> None: self.sklearn_module = sklearn_module + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): r""" Args: @@ -273,6 +295,7 @@ def fit(self, train_data: DataLoader, **kwargs): class SkLearnLasso(SkLearnLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class. Trains a `LinearModel` model with @@ -281,11 +304,14 @@ def __init__(self, **kwargs) -> None: """ super().__init__(sklearn_module="linear_model.Lasso", **kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): return super().fit(train_data=train_data, **kwargs) class SkLearnRidge(SkLearnLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class. 
Trains a model with `sklearn.linear_model.Ridge`. @@ -295,11 +321,14 @@ def __init__(self, **kwargs) -> None: """ super().__init__(sklearn_module="linear_model.Ridge", **kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): return super().fit(train_data=train_data, **kwargs) class SkLearnLinearRegression(SkLearnLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class. Trains a model with `sklearn.linear_model.LinearRegression`. @@ -309,11 +338,14 @@ def __init__(self, **kwargs) -> None: """ super().__init__(sklearn_module="linear_model.LinearRegression", **kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): return super().fit(train_data=train_data, **kwargs) class SkLearnLogisticRegression(SkLearnLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class. Trains a model with `sklearn.linear_model.LogisticRegression`. @@ -323,11 +355,14 @@ def __init__(self, **kwargs) -> None: """ super().__init__(sklearn_module="linear_model.LogisticRegression", **kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): return super().fit(train_data=train_data, **kwargs) class SkLearnSGDClassifier(SkLearnLinearModel): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, **kwargs) -> None: r""" Factory class. Trains a model with `sklearn.linear_model.SGDClassifier(`. @@ -337,5 +372,7 @@ def __init__(self, **kwargs) -> None: """ super().__init__(sklearn_module="linear_model.SGDClassifier", **kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fit(self, train_data: DataLoader, **kwargs): return super().fit(train_data=train_data, **kwargs) diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 70b6ef3d5e..64d79153f1 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -1,3 +1,4 @@ +# pyre-strict import time import warnings from typing import Any, Callable, Dict, List, Optional @@ -8,6 +9,8 @@ from torch.utils.data import DataLoader +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def l2_loss(x1, x2, weights=None): if weights is None: return torch.mean((x1 - x2) ** 2) / 2.0 @@ -23,6 +26,7 @@ def sgd_train_linear_model( reduce_lr: bool = True, initial_lr: float = 0.01, alpha: float = 1.0, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. loss_fn: Callable = l2_loss, reg_term: Optional[int] = 1, patience: int = 10, @@ -104,6 +108,8 @@ def sgd_train_linear_model( convergence_counter = 0 converged = False + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def get_point(datapoint): if len(datapoint) == 2: x, y = datapoint @@ -137,10 +143,12 @@ def get_point(datapoint): with torch.no_grad(): if init_scheme == "xavier": + # pyre-fixme[16]: `Optional` has no attribute `weight`. torch.nn.init.xavier_uniform_(model.linear.weight) else: model.linear.weight.zero_() + # pyre-fixme[16]: `Optional` has no attribute `bias`. 
if model.linear.bias is not None: model.linear.bias.zero_() @@ -201,6 +209,7 @@ def get_point(datapoint): loss.backward() optim.step() model.zero_grad() + # pyre-fixme[61]: `scheduler` is undefined, or not always defined. if scheduler: scheduler.step(average_loss) @@ -226,22 +235,30 @@ def get_point(datapoint): class NormLayer(nn.Module): + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, mean, std, n=None, eps=1e-8) -> None: super().__init__() + # pyre-fixme[4]: Attribute must be annotated. self.mean = mean + # pyre-fixme[4]: Attribute must be annotated. self.std = std + # pyre-fixme[4]: Attribute must be annotated. self.eps = eps + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, x): return (x - self.mean) / (self.std + self.eps) +# pyre-fixme[3]: Return type must be annotated. def sklearn_train_linear_model( model: LinearModel, dataloader: DataLoader, construct_kwargs: Dict[str, Any], sklearn_trainer: str = "Lasso", norm_input: bool = False, + # pyre-fixme[2]: Parameter must be annotated. **fit_kwargs, ): r""" @@ -320,6 +337,7 @@ def sklearn_train_linear_model( x /= std t1 = time.time() + # pyre-fixme[29]: `str` is not a function. sklearn_model = reduce( lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") )(**construct_kwargs) @@ -358,6 +376,8 @@ def sklearn_train_linear_model( ) if norm_input: + # pyre-fixme[61]: `mean` is undefined, or not always defined. + # pyre-fixme[61]: `std` is undefined, or not always defined. model.norm = NormLayer(mean, std) return {"train_time": t2 - t1} diff --git a/captum/_utils/models/model.py b/captum/_utils/models/model.py index 9e8a98db04..f6cb6600f0 100644 --- a/captum/_utils/models/model.py +++ b/captum/_utils/models/model.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + from abc import ABC, abstractmethod from typing import Dict, Optional, Union @@ -18,7 +20,10 @@ class Model(ABC): @abstractmethod def fit( - self, train_data: DataLoader, **kwargs + self, + train_data: DataLoader, + # pyre-fixme[2]: Parameter must be annotated. + **kwargs, ) -> Optional[Dict[str, Union[int, float, Tensor]]]: r""" Override this method to actually train your model. diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py index cb08a15aed..1e5891cc80 100644 --- a/captum/_utils/progress.py +++ b/captum/_utils/progress.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import sys import warnings from time import time @@ -21,10 +23,14 @@ def __init__(self, wrapped: TextIO) -> None: """ self._wrapped = wrapped + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def __getattr__(self, name): return getattr(self._wrapped, name) @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _wrapped_run(func, *args, **kwargs): try: return func(*args, **kwargs) @@ -35,9 +41,13 @@ def _wrapped_run(func, *args, **kwargs): if "closed" not in str(e): raise + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def write(self, *args, **kwargs): return self._wrapped_run(self._wrapped.write, *args, **kwargs) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def flush(self, *args, **kwargs): return self._wrapped_run(self._wrapped.flush, *args, **kwargs) @@ -51,25 +61,37 @@ class NullProgress: progress bars. """ + # pyre-fixme[3]: Return type must be annotated. 
+ # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, iterable: Optional[Iterable] = None, *args, **kwargs): del args, kwargs self.iterable = iterable + # pyre-fixme[3]: Return type must be annotated. def __enter__(self): return self + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]: + # pyre-fixme[7]: Expected `Literal[]` but got `bool`. return False + # pyre-fixme[3]: Return type must be annotated. def __iter__(self): if not self.iterable: return + # pyre-fixme[16]: `Optional` has no attribute `__iter__`. for it in self.iterable: yield it + # pyre-fixme[3]: Return type must be annotated. def update(self, amount: int = 1): pass + # pyre-fixme[3]: Return type must be annotated. def close(self): pass @@ -77,6 +99,7 @@ def close(self): class SimpleProgress: def __init__( self, + # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. iterable: Optional[Iterable] = None, desc: Optional[str] = None, total: Optional[int] = None, @@ -99,6 +122,8 @@ def __init__( self.desc = desc + # pyre-fixme[9]: file has type `Optional[TextIO]`; used as + # `DisableErrorIOWrapper`. file = DisableErrorIOWrapper(file if file else sys.stderr) cast(TextIO, file) self.file = file @@ -108,28 +133,38 @@ def __init__( self.closed = False self._is_parent = False + # pyre-fixme[3]: Return type must be annotated. def __enter__(self): self._is_parent = True self._refresh() return self + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]: self.close() + # pyre-fixme[7]: Expected `Literal[]` but got `bool`. return False + # pyre-fixme[3]: Return type must be annotated. def __iter__(self): if self.closed or not self.iterable: return self._refresh() + # pyre-fixme[16]: `Optional` has no attribute `__iter__`. for it in self.iterable: yield it self.update() self.close() + # pyre-fixme[3]: Return type must be annotated. def _refresh(self): progress_str = self.desc + ": " if self.desc else "" if self.total: # e.g., progress: 60% 3/5 + # pyre-fixme[58]: `//` is not supported for operand types `int` and + # `Optional[int]`. progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}" else: # e.g., progress: ..... @@ -137,6 +172,7 @@ def _refresh(self): end = "\n" if self._is_parent else "" print("\r" + progress_str, end=end, file=self.file) + # pyre-fixme[3]: Return type must be annotated. def update(self, amount: int = 1): if self.closed: return @@ -147,6 +183,7 @@ def update(self, amount: int = 1): self._refresh() self.last_print_t = cur_t + # pyre-fixme[3]: Return type must be annotated. def close(self): if not self.closed and not self._is_parent: self._refresh() @@ -154,13 +191,17 @@ def close(self): self.closed = True +# pyre-fixme[3]: Return type must be annotated. def progress( + # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. iterable: Optional[Iterable] = None, desc: Optional[str] = None, total: Optional[int] = None, + # pyre-fixme[2]: Parameter must be annotated. 
use_tqdm=True, file: Optional[TextIO] = None, mininterval: float = 0.5, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ): # Try to use tqdm is possible. Fall back to simple progress print diff --git a/captum/_utils/sample_gradient.py b/captum/_utils/sample_gradient.py index 660c0030a7..7b868b9cd8 100644 --- a/captum/_utils/sample_gradient.py +++ b/captum/_utils/sample_gradient.py @@ -1,3 +1,4 @@ +# pyre-strict from collections import defaultdict from enum import Enum from typing import cast, DefaultDict, Iterable, List, Optional, Tuple, Union @@ -58,6 +59,7 @@ def conv2d_param_grads( if reset: _reset_sample_grads(module) + # pyre-fixme[22]: The cast is redundant. batch_size = cast(int, activation.shape[0]) unfolded_act = torch.nn.functional.unfold( activation, @@ -100,7 +102,9 @@ class SampleGradientWrapper: - https://github.com/pytorch/opacus/tree/main/opacus/grad_sample """ + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, model, layer_modules=None) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.model = model self.hooks_added = False self.activation_dict: DefaultDict[Module, List[Tensor]] = defaultdict(list) @@ -158,6 +162,7 @@ def _reset(self) -> None: self.activation_dict = defaultdict(list) self.gradient_dict = defaultdict(list) + # pyre-fixme[2]: Parameter must be annotated. def compute_param_sample_gradients(self, loss_blob, loss_mode="mean") -> None: assert ( loss_mode.upper() in LossMode.__members__ @@ -168,6 +173,8 @@ def compute_param_sample_gradients(self, loss_blob, loss_mode="mean") -> None: loss_blob.backward(gradient=torch.ones_like(loss_blob)) for module in self.gradient_dict: + # pyre-fixme[6]: For 1st argument expected `Type[Union[Conv2d, Linear]]` + # but got `Type[Module]`. sample_grad_fn = SUPPORTED_MODULES[type(module)] activations = self.activation_dict[module] gradients = self.gradient_dict[module] diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 6d92511e82..d9ac6304c8 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union from torch import Tensor @@ -13,6 +15,7 @@ TensorOrTupleOfTensorsGeneric = TypeVar( "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] ) +# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool) ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module]) TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]] diff --git a/captum/attr/__init__.py b/captum/attr/__init__.py index 612486e573..a33cd862dd 100644 --- a/captum/attr/__init__.py +++ b/captum/attr/__init__.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from captum.attr._core.dataloader_attr import DataLoaderAttribution from captum.attr._core.deep_lift import DeepLift, DeepLiftShap from captum.attr._core.feature_ablation import FeatureAblation diff --git a/captum/attr/_core/dataloader_attr.py b/captum/attr/_core/dataloader_attr.py index 1e2aa6f5d7..2bbacdec12 100644 --- a/captum/attr/_core/dataloader_attr.py +++ b/captum/attr/_core/dataloader_attr.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from collections import defaultdict from copy import copy from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -28,6 +30,8 @@ class InputRole: # default reducer wehn reduce is None. 
Simply concat the outputs by the batch dimension +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def _concat_tensors(accum, cur_output, _): return cur_output if accum is None else torch.cat([accum, cur_output]) @@ -58,7 +62,9 @@ def _create_perturbation_mask( return perturbation_mask +# pyre-fixme[3]: Return annotation cannot contain `Any`. def _perturb_inputs( + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. inputs: Iterable[Any], input_roles: Tuple[int], baselines: Tuple[Union[int, float, Tensor], ...], @@ -113,6 +119,8 @@ def _convert_output_shape( for inp, mask in zip(attr_inputs, feature_mask): # input in shape(batch_size, *inp_feature_dims) # attribute in shape(*output_dims, *inp_feature_dims) + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + # tuples: `*output_dims, *inp.shape[slice(1, None, None)]`. attr_shape = (*output_dims, *inp.shape[1:]) expanded_feature_indices = mask.expand(attr_shape) @@ -125,7 +133,11 @@ def _convert_output_shape( # (*output_dims, 1..., 1, n_features) # then broadcast to (*output_dims, *inp.shape[1:-1], n_features) n_extra_dims = len(extra_inp_dims) + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + # tuples: `*output_dims, *(1).__mul__(n_extra_dims)`. unsqueezed_shape = (*output_dims, *(1,) * n_extra_dims, n_features) + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + # tuples: `*output_dims, *extra_inp_dims`. expanded_shape = (*output_dims, *extra_inp_dims, n_features) expanded_unqiue_attr = unique_attr.reshape(unsqueezed_shape).expand( expanded_shape @@ -168,10 +180,12 @@ def __init__(self, attr_method: Attribution) -> None: super().__init__(attr_method.forward_func) # shallow copy is enough to avoid modifying original instance + # pyre-fixme[4]: Attribute must be annotated. self.attr_method = copy(attr_method) self.attr_method.forward_func = self._forward_with_dataloader + # pyre-fixme[3]: Return type must be annotated. def _forward_with_dataloader( self, batched_perturbed_feature_indices: Tensor, @@ -179,7 +193,9 @@ def _forward_with_dataloader( input_roles: Tuple[int], baselines: Tuple[Union[int, float, Tensor], ...], feature_mask: Tuple[Tensor, ...], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. reduce: Callable, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. to_metric: Optional[Callable], show_progress: bool, feature_idx_to_mask_idx: Dict[int, List[int]], @@ -250,7 +266,9 @@ def attribute( input_roles: Optional[Tuple[int, ...]] = None, baselines: BaselineType = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. reduce: Optional[Callable] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. to_metric: Optional[Callable] = None, perturbations_per_pass: int = 1, show_progress: bool = False, @@ -347,6 +365,9 @@ def attribute( inputs = _format_tensor_into_tuples(inputs) if input_roles: + # pyre-fixme[6]: For 1st argument expected + # `pyre_extensions.ReadOnly[Sized]` but got + # `Optional[typing.Tuple[typing.Any, ...]]`. assert len(input_roles) == len(inputs), ( "input_roles must have the same size as the return of the dataloader,", f"length of input_roles is {len(input_roles)} ", @@ -359,10 +380,15 @@ def attribute( ) else: # by default, assume every element in the dataloader needs attribution + # pyre-fixme[16]: `Optional` has no attribute `__iter__`. 
input_roles = tuple(InputRole.need_attr for _ in inputs) attr_inputs = tuple( - inp for role, inp in zip(input_roles, inputs) if role == InputRole.need_attr + inp + # pyre-fixme[6]: For 2nd argument expected `Iterable[Variable[_T2]]` but + # got `Optional[typing.Tuple[typing.Any, ...]]`. + for role, inp in zip(input_roles, inputs) + if role == InputRole.need_attr ) baselines = _format_baseline(baselines, attr_inputs) diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index eea8234eef..0669fe6dcf 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing import warnings from typing import Any, Callable, cast, List, Tuple, Union @@ -115,17 +117,25 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `120`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> TensorOrTupleOfTensorsGeneric: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `131`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -133,6 +143,8 @@ def attribute( target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @@ -289,13 +301,23 @@ def attribute( # type: ignore # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. baselines = _format_baseline(baselines, inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. _validate_input(inputs, baselines) # set hooks for baselines @@ -304,6 +326,8 @@ def attribute( # type: ignore activations. The hooks and attributes will be removed after the attribution is finished""" ) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. 
baselines = _tensorize_baseline(inputs, baselines) main_model_hooks = [] try: @@ -338,13 +362,21 @@ def attribute( # type: ignore attributions = gradients else: attributions = _call_custom_attribution_func( - custom_attribution_func, gradients, inputs, baselines + custom_attribution_func, + gradients, + # pyre-fixme[6]: For 3rd argument expected `Tuple[Tensor, ...]` + # but got `TensorOrTupleOfTensorsGeneric`. + inputs, + baselines, ) finally: # Even if any error is raised, remove all hooks before raising self._remove_hooks(main_model_hooks) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) + # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric... return _compute_conv_delta_and_format_attrs( self, return_convergence_delta, @@ -358,17 +390,26 @@ def attribute( # type: ignore def _construct_forward_func( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. inputs: Tuple, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. ) -> Callable: + # pyre-fixme[3]: Return type must be annotated. def forward_fn(): model_out = _run_forward( forward_func, inputs, None, additional_forward_args ) return _select_targets( - torch.cat((model_out[:, 0], model_out[:, 1])), target + # pyre-fixme[16]: Item `Future` of + # `Union[Future[torch._tensor.Tensor], Tensor]` has no attribute + # `__getitem__`. + torch.cat((model_out[:, 0], model_out[:, 1])), + target, ) if hasattr(forward_func, "device_ids"): @@ -489,6 +530,7 @@ def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None: backward_handle.remove() def _hook_main_model(self) -> List[RemovableHandle]: + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple: inputs = baseline_inputs_add_args[0] baselines = baseline_inputs_add_args[1] @@ -502,14 +544,20 @@ def pre_hook(module: Module, baseline_inputs_add_args: Tuple) -> Tuple: ) if additional_args is not None: expanded_additional_args = cast( + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type + # parameter. Tuple, _expand_additional_forward_args( additional_args, 2, ExpansionTypes.repeat ), ) + # pyre-fixme[60]: Concatenation not yet support for multiple + # variadic tuples: `*baseline_input_tsr, *expanded_additional_args`. return (*baseline_input_tsr, *expanded_additional_args) return baseline_input_tsr + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. def forward_hook(module: Module, inputs: Tuple, outputs: Tensor): return torch.stack(torch.chunk(outputs, 2), dim=1) @@ -530,6 +578,7 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. 
def multiplies_by_inputs(self): return self._multiply_by_inputs @@ -579,6 +628,8 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: # There's a mismatch between the signatures of DeepLift.attribute and # DeepLiftShap.attribute, so we ignore typing here @typing.overload # type: ignore + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `584`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -586,12 +637,18 @@ def attribute( TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> TensorOrTupleOfTensorsGeneric: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `597`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -601,6 +658,8 @@ def attribute( target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @@ -753,8 +812,15 @@ def attribute( # type: ignore >>> # Computes shap values using deeplift for class 3. >>> attribution = dl.attribute(input, target=3) """ + # pyre-fixme[9]: baselines has type `Union[typing.Callable[..., + # Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor, + # ...]]]], Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. baselines = _format_callable_baseline(baselines, inputs) + # pyre-fixme[16]: Item `Callable` of `Union[(...) -> + # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no + # attribute `__getitem__`. assert isinstance(baselines[0], torch.Tensor) and baselines[0].shape[0] > 1, ( "Baselines distribution has to be provided in form of a torch.Tensor" " with more than one example but found: {}." @@ -765,12 +831,19 @@ def attribute( # type: ignore # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) # batch sizes inp_bsz = inputs[0].shape[0] + # pyre-fixme[16]: Item `Callable` of `Union[(...) -> + # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no + # attribute `__getitem__`. base_bsz = baselines[0].shape[0] ( @@ -779,7 +852,15 @@ def attribute( # type: ignore exp_tgt, exp_addit_args, ) = self._expand_inputs_baselines_targets( - baselines, inputs, target, additional_forward_args + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got `... 
+ # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. + baselines, + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. + inputs, + target, + additional_forward_args, ) attributions = super().attribute.__wrapped__( # type: ignore self, @@ -788,7 +869,12 @@ def attribute( # type: ignore target=exp_tgt, additional_forward_args=exp_addit_args, return_convergence_delta=cast( - Literal[True, False], return_convergence_delta + # pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid + # type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take + # parameters. + Literal[True, False], + return_convergence_delta, ), custom_attribution_func=custom_attribution_func, ) @@ -803,15 +889,20 @@ def attribute( # type: ignore ) if return_convergence_delta: + # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... + # pyre-fixme[61]: `delta` is undefined, or not always defined. return _format_output(is_inputs_tuple, attributions), delta else: + # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... return _format_output(is_inputs_tuple, attributions) + # pyre-fixme[3]: Return annotation cannot contain `Any`. def _expand_inputs_baselines_targets( self, baselines: Tuple[Tensor, ...], inputs: Tuple[Tensor, ...], target: TargetType, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, Any]: inp_bsz = inputs[0].shape[0] @@ -854,8 +945,11 @@ def _compute_mean_across_baselines( self, inp_bsz: int, base_bsz: int, attribution: Tensor ) -> Tensor: # Average for multiple references + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. attr_shape: Tuple = (inp_bsz, base_bsz) if len(attribution.shape) > 1: + # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, + # int]` and `Size`. attr_shape += attribution.shape[1:] return torch.mean(attribution.view(attr_shape), dim=1, keepdim=False) @@ -882,6 +976,7 @@ def nonlinear( return new_grad_inp +# pyre-fixme[3]: Return type must be annotated. def softmax( module: Module, inputs: Tensor, @@ -903,6 +998,7 @@ def softmax( return new_grad_inp +# pyre-fixme[3]: Return type must be annotated. def maxpool1d( module: Module, inputs: Tensor, @@ -923,6 +1019,7 @@ def maxpool1d( ) +# pyre-fixme[3]: Return type must be annotated. def maxpool2d( module: Module, inputs: Tensor, @@ -943,8 +1040,18 @@ def maxpool2d( ) +# pyre-fixme[3]: Return type must be annotated. def maxpool3d( - module: Module, inputs, outputs, grad_input, grad_output, eps: float = 1e-10 + module: Module, + # pyre-fixme[2]: Parameter must be annotated. + inputs, + # pyre-fixme[2]: Parameter must be annotated. + outputs, + # pyre-fixme[2]: Parameter must be annotated. + grad_input, + # pyre-fixme[2]: Parameter must be annotated. + grad_output, + eps: float = 1e-10, ): return maxpool( module, @@ -958,13 +1065,20 @@ def maxpool3d( ) +# pyre-fixme[3]: Return type must be annotated. def maxpool( module: Module, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. pool_func: Callable, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. unpool_func: Callable, + # pyre-fixme[2]: Parameter must be annotated. inputs, + # pyre-fixme[2]: Parameter must be annotated. outputs, + # pyre-fixme[2]: Parameter must be annotated. grad_input, + # pyre-fixme[2]: Parameter must be annotated. 
grad_output, eps: float = 1e-10, ): @@ -1037,6 +1151,7 @@ def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]: return torch.cat(2 * [delta_in]), torch.cat(2 * [delta_out]) +# pyre-fixme[5]: Global expression must be annotated. SUPPORTED_NON_LINEAR = { nn.ReLU: nonlinear, nn.ELU: nonlinear, diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 891b893f0f..b94879ec92 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import math from typing import Any, Callable, cast, List, Optional, Tuple, Union @@ -44,6 +46,7 @@ class FeatureAblation(PerturbationAttribution): first dimension (i.e. a feature mask requires to be applied to all inputs). """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -71,6 +74,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, @@ -258,12 +262,18 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs, baselines = _format_input_baseline(inputs, baselines) additional_forward_args = _format_additional_forward_args( additional_forward_args ) num_examples = inputs[0].shape[0] + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. feature_mask = _format_feature_mask(feature_mask, inputs) assert ( @@ -336,6 +346,7 @@ def attribute( # The will be the same amount futures as modified_eval down there, # since we cannot add up the evaluation result adhoc under async mode. + # pyre-fixme[24]: Generic type `Future` expects 1 type parameter. all_futures: List[List[Future]] = [[] for _ in range(len(inputs))] # Iterate through each feature tensor for ablation for i in range(len(inputs)): @@ -443,15 +454,24 @@ def attribute( else: return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long + # pyre-fixme[3]: Return type must be annotated. def _ith_input_ablation_generator( self, + # pyre-fixme[2]: Parameter must be annotated. i, + # pyre-fixme[2]: Parameter must be annotated. inputs, + # pyre-fixme[2]: Parameter must be annotated. additional_args, + # pyre-fixme[2]: Parameter must be annotated. target, + # pyre-fixme[2]: Parameter must be annotated. baselines, + # pyre-fixme[2]: Parameter must be annotated. input_mask, + # pyre-fixme[2]: Parameter must be annotated. perturbations_per_eval, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ): """ @@ -477,6 +497,8 @@ def _ith_input_ablation_generator( perturbations_per_eval = min(perturbations_per_eval, num_features) baseline = baselines[i] if isinstance(baselines, tuple) else baselines if isinstance(baseline, torch.Tensor): + # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` + # and `Size`. 
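This `pyre-fixme[58]` covers adding a plain tuple to a `torch.Size` on the line that follows; at runtime `torch.Size` concatenates with tuples without issue, the checker simply lacks that overload. A small sketch of the behaviour, with an arbitrarily shaped example tensor, alongside an `unsqueeze` spelling that type-checks cleanly:

    import torch

    baseline = torch.zeros(3, 5)                    # arbitrary example tensor
    a = baseline.reshape((1,) + baseline.shape)     # the pattern flagged above
    b = baseline.unsqueeze(0)                       # same result, cleanly typed
    assert a.shape == b.shape == (1, 3, 5)
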
baseline = baseline.reshape((1,) + baseline.shape) if perturbations_per_eval > 1: @@ -556,8 +578,21 @@ def _ith_input_ablation_generator( current_features[i] = original_tensor num_features_processed += current_num_ablated_features + # pyre-fixme[3]: Return type must be annotated. def _construct_ablated_input( - self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs + self, + # pyre-fixme[2]: Parameter must be annotated. + expanded_input, + # pyre-fixme[2]: Parameter must be annotated. + input_mask, + # pyre-fixme[2]: Parameter must be annotated. + baseline, + # pyre-fixme[2]: Parameter must be annotated. + start_feature, + # pyre-fixme[2]: Parameter must be annotated. + end_feature, + # pyre-fixme[2]: Parameter must be annotated. + **kwargs, ): r""" Ablates given expanded_input tensor with given feature mask, feature range, @@ -584,6 +619,8 @@ def _construct_ablated_input( ) + (baseline * current_mask.to(expanded_input.dtype)) return ablated_tensor, current_mask + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_feature_range_and_mask(self, input, input_mask, **kwargs): if input_mask is None: # Obtain feature mask for selected input tensor, matches size of @@ -598,6 +635,8 @@ def _get_feature_range_and_mask(self, input, input_mask, **kwargs): input_mask, ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_feature_counts(self, inputs, feature_mask, **kwargs): """return the numbers of input features""" if not feature_mask: @@ -612,6 +651,7 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs): for inp, mask in zip(inputs, feature_mask) ) + # pyre-fixme[2]: Parameter must be annotated. def _parse_forward_out(self, forward_output) -> Tensor: """ A temp wrapper for global _run_forward util to force forward output @@ -678,18 +718,31 @@ def _process_initial_eval( def _process_ablated_out( self, + # pyre-fixme[2]: Parameter must be annotated. modified_eval, + # pyre-fixme[2]: Parameter must be annotated. current_inputs, + # pyre-fixme[2]: Parameter must be annotated. current_mask, + # pyre-fixme[2]: Parameter must be annotated. perturbations_per_eval, + # pyre-fixme[2]: Parameter must be annotated. num_examples, + # pyre-fixme[2]: Parameter must be annotated. initial_eval, + # pyre-fixme[2]: Parameter must be annotated. flattened_initial_eval, + # pyre-fixme[2]: Parameter must be annotated. inputs, + # pyre-fixme[2]: Parameter must be annotated. n_outputs, + # pyre-fixme[2]: Parameter must be annotated. total_attrib, + # pyre-fixme[2]: Parameter must be annotated. weights, + # pyre-fixme[2]: Parameter must be annotated. i, + # pyre-fixme[2]: Parameter must be annotated. attrib_type, ) -> Tuple[List[Tensor], List[Tensor]]: modified_eval = self._parse_forward_out(modified_eval) @@ -749,6 +802,7 @@ def _generate_async_result( ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]: # Each element of the 2d list contains evalutaion results for a feature # Need to add up all the results for each input + # pyre-fixme[24]: Generic type `Future` expects 1 type parameter. 
accumulate_fut_list: List[Future] = [] total_attrib: List[Tensor] = [] weights: List[Tensor] = [] diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index a3e7580780..315307bd11 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, Tuple, Union import torch @@ -71,7 +73,11 @@ class FeaturePermutation(FeatureAblation): """ def __init__( - self, forward_func: Callable, perm_func: Callable = _permute_feature + self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + forward_func: Callable, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + perm_func: Callable = _permute_feature, ) -> None: r""" Args: @@ -96,6 +102,7 @@ def attribute( # type: ignore self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, @@ -277,6 +284,7 @@ def attribute_future( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, diff --git a/captum/attr/_core/gradient_shap.py b/captum/attr/_core/gradient_shap.py index aa3b0a281e..c42e6e78e0 100644 --- a/captum/attr/_core/gradient_shap.py +++ b/captum/attr/_core/gradient_shap.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from typing import Any, Callable, Tuple, Union @@ -55,6 +57,7 @@ class GradientShap(GradientAttribution): samples and compute the expectation (smoothgrad). """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None: r""" Args: @@ -79,6 +82,8 @@ def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> N self._multiply_by_inputs = multiply_by_inputs @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `84`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -88,12 +93,17 @@ def attribute( n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `99`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -104,10 +114,15 @@ def attribute( stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, ) -> TensorOrTupleOfTensorsGeneric: ... 
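The overload pair above (repeated for the other attribution classes in this patch) switches the return type on a `Literal[True]` / `Literal[False]` value of `return_convergence_delta`, which is what the `pyre-fixme[9]`, `[31]`, and `[24]` markers attach to. A stripped-down sketch of the idiom, using a stand-in `Tensor` alias rather than the real types, showing why the decorated implementation needs the broader `bool` / `Union` signature:

    import typing
    from typing import Literal, Tuple, Union

    Tensor = float  # stand-in type, for illustration only

    @typing.overload
    def attribute(x: Tensor, return_convergence_delta: Literal[False] = False) -> Tensor: ...

    @typing.overload
    def attribute(x: Tensor, *, return_convergence_delta: Literal[True]) -> Tuple[Tensor, Tensor]: ...

    def attribute(
        x: Tensor, return_convergence_delta: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        attr = x * 2.0    # placeholder "attribution"
        delta = 0.0       # placeholder convergence delta
        return (attr, delta) if return_convergence_delta else attr
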
@log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -250,7 +265,14 @@ def attribute( """ # since `baselines` is a distribution, we can generate it using a function # rather than passing it as an input argument + # pyre-fixme[9]: baselines has type `Union[typing.Callable[..., + # Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor, + # ...]]]], Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. baselines = _format_callable_baseline(baselines, inputs) + # pyre-fixme[16]: Item `Callable` of `Union[(...) -> + # TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]` has no + # attribute `__getitem__`. assert isinstance(baselines[0], torch.Tensor), ( "Baselines distribution has to be provided in a form " "of a torch.Tensor {}.".format(baselines[0]) @@ -283,11 +305,14 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs class InputBaselineXGradient(GradientAttribution): + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, forward_func: Callable, multiply_by_inputs=True) -> None: r""" Args: @@ -310,26 +335,37 @@ def __init__(self, forward_func: Callable, multiply_by_inputs=True) -> None: """ GradientAttribution.__init__(self, forward_func) + # pyre-fixme[4]: Attribute must be annotated. self._multiply_by_inputs = multiply_by_inputs @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `318`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `329`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, ) -> TensorOrTupleOfTensorsGeneric: ... @@ -346,7 +382,11 @@ def attribute( # type: ignore ]: # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs, baselines = _format_input_baseline(inputs, baselines) rand_coefficient = torch.tensor( @@ -374,6 +414,7 @@ def attribute( # type: ignore else: attributions = grads + # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric... 
return _compute_conv_delta_and_format_attrs( self, return_convergence_delta, @@ -389,6 +430,7 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py index f7e19b3583..d8ba7ca353 100644 --- a/captum/attr/_core/guided_backprop_deconvnet.py +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import warnings from typing import Any, List, Tuple, Union @@ -43,6 +45,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -55,9 +58,15 @@ def attribute( # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) # set hooks for overriding ReLU gradients @@ -74,14 +83,20 @@ def attribute( finally: self._remove_hooks() + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, gradients) + # pyre-fixme[3]: Return type must be annotated. def _register_hooks(self, module: Module): if isinstance(module, torch.nn.ReLU): hooks = _register_backward_hook(module, self._backward_hook, self) self.backward_hooks.extend(hooks) + # pyre-fixme[3]: Return type must be annotated. def _backward_hook( self, module: Module, @@ -96,6 +111,7 @@ def _backward_hook( else: return F.relu(to_override_grads) + # pyre-fixme[3]: Return type must be annotated. def _remove_hooks(self): for hook in self.backward_hooks: hook.remove() @@ -132,6 +148,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -241,6 +258,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> TensorOrTupleOfTensorsGeneric: r""" diff --git a/captum/attr/_core/guided_grad_cam.py b/captum/attr/_core/guided_grad_cam.py index 32080b17a0..5a7424437c 100644 --- a/captum/attr/_core/guided_grad_cam.py +++ b/captum/attr/_core/guided_grad_cam.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import warnings from typing import Any, List, Union @@ -70,6 +72,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
additional_forward_args: Any = None, interpolate_mode: str = "nearest", attribute_to_layer_input: bool = False, @@ -178,7 +181,11 @@ def attribute( >>> # attribution size matches input size, Nx3x32x32 >>> attribution = guided_gc.attribute(input, 3) """ + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) grad_cam_attr = self.grad_cam.attribute.__wrapped__( self.grad_cam, # self @@ -208,6 +215,8 @@ def attribute( guided_backprop_attr[i] * LayerAttribution.interpolate( grad_cam_attr, + # pyre-fixme[6]: For 2nd argument expected `Union[int, + # typing.Tuple[int, ...]]` but got `Size`. inputs[i].shape[2:], interpolate_mode=interpolate_mode, ) @@ -220,4 +229,6 @@ def attribute( ) output_attr.append(torch.empty(0)) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, tuple(output_attr)) diff --git a/captum/attr/_core/input_x_gradient.py b/captum/attr/_core/input_x_gradient.py index 1fbcf2b045..b5ba57f4cb 100644 --- a/captum/attr/_core/input_x_gradient.py +++ b/captum/attr/_core/input_x_gradient.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple @@ -18,6 +20,7 @@ class InputXGradient(GradientAttribution): https://arxiv.org/abs/1605.01713 """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -32,6 +35,7 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -109,9 +113,15 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) gradients = self.gradient_func( @@ -122,9 +132,14 @@ def attribute( input * gradient for input, gradient in zip(inputs, gradients) ) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) @property + # pyre-fixme[3]: Return type must be annotated. 
def multiplies_by_inputs(self): return True diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py index 7f0aa28705..fb141310dd 100644 --- a/captum/attr/_core/integrated_gradients.py +++ b/captum/attr/_core/integrated_gradients.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from typing import Any, Callable, List, Tuple, Union @@ -47,6 +49,7 @@ class IntegratedGradients(GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, multiply_by_inputs: bool = True, ) -> None: @@ -77,19 +80,27 @@ def __init__( # and when return_convergence_delta is True, the return type is # a tuple with both attributions and deltas. @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `82`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, ) -> TensorOrTupleOfTensorsGeneric: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `95`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -100,6 +111,8 @@ def attribute( method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @@ -261,10 +274,16 @@ def attribute( # type: ignore """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs, baselines = _format_input_baseline(inputs, baselines) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. _validate_input(inputs, baselines, n_steps, method) if internal_batch_size is not None: @@ -282,6 +301,8 @@ def attribute( # type: ignore ) else: attributions = self._attribute( + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but + # got `TensorOrTupleOfTensorsGeneric`. inputs=inputs, baselines=baselines, target=target, @@ -300,7 +321,12 @@ def attribute( # type: ignore additional_forward_args=additional_forward_args, target=target, ) + # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... return _format_output(is_inputs_tuple, attributions), delta + # pyre-fixme[7]: Expected + # `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]], Tensor], Variable[TensorOrTupleOfTensorsGeneric + # <: [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Tuple[Tensor, ...]`. 
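The union spelled out in this `pyre-fixme[7]` is just the two user-facing return shapes of `IntegratedGradients.attribute`: attributions alone, or attributions plus a convergence delta. A minimal usage sketch with a throwaway linear model (model, shapes, and target chosen arbitrarily for illustration):

    import torch
    from captum.attr import IntegratedGradients

    model = torch.nn.Linear(3, 2)       # throwaway model
    inputs = torch.rand(4, 3)
    ig = IntegratedGradients(model)

    # Plain attributions (the bare `TensorOrTupleOfTensorsGeneric` branch).
    attr = ig.attribute(inputs, target=0)

    # Attributions plus convergence delta (the `Tuple[..., Tensor]` branch).
    attr, delta = ig.attribute(inputs, target=0, return_convergence_delta=True)
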
return _format_output(is_inputs_tuple, attributions) def _attribute( @@ -308,6 +334,7 @@ def _attribute( inputs: Tuple[Tensor, ...], baselines: Tuple[Union[Tensor, int, float], ...], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", @@ -385,5 +412,6 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py index 81d7b5947f..081d540456 100644 --- a/captum/attr/_core/kernel_shap.py +++ b/captum/attr/_core/kernel_shap.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + from typing import Any, Callable, Generator, Tuple, Union import torch @@ -25,6 +27,7 @@ class KernelShap(Lime): https://arxiv.org/abs/1705.07874 """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -47,6 +50,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, @@ -291,7 +295,12 @@ def attribute( # type: ignore ) def kernel_shap_similarity_kernel( - self, _, __, interpretable_sample: Tensor, **kwargs + self, + _, + __, + interpretable_sample: Tensor, + # pyre-fixme[2]: Parameter must be annotated. + **kwargs, ) -> Tensor: assert ( "num_interp_features" in kwargs @@ -311,7 +320,10 @@ def kernel_shap_similarity_kernel( return torch.tensor([similarities]) def kernel_shap_perturb_generator( - self, original_inp: Union[Tensor, Tuple[Tensor, ...]], **kwargs + self, + original_inp: Union[Tensor, Tuple[Tensor, ...]], + # pyre-fixme[2]: Parameter must be annotated. + **kwargs, ) -> Generator[Tensor, None, None]: r""" Perturbations are sampled by the following process: diff --git a/captum/attr/_core/layer/grad_cam.py b/captum/attr/_core/layer/grad_cam.py index 554ce8bcb6..01c14a405f 100644 --- a/captum/attr/_core/layer/grad_cam.py +++ b/captum/attr/_core/layer/grad_cam.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -52,6 +54,7 @@ class LayerGradCam(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -79,6 +82,7 @@ def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_layer_input: bool = False, relu_attributions: bool = False, @@ -210,7 +214,10 @@ def attribute( summed_grads = tuple( ( torch.mean( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Tuple[Tensor, ...]`. layer_grad, + # pyre-fixme[16]: `tuple` has no attribute `shape`. dim=tuple(x for x in range(2, len(layer_grad.shape))), keepdim=True, ) @@ -222,15 +229,27 @@ def attribute( if attr_dim_summation: scaled_acts = tuple( + # pyre-fixme[58]: `*` is not supported for operand types + # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and + # `Tuple[Tensor, ...]`. 
+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Tuple[Tensor, ...]`. torch.sum(summed_grad * layer_eval, dim=1, keepdim=True) for summed_grad, layer_eval in zip(summed_grads, layer_evals) ) else: scaled_acts = tuple( + # pyre-fixme[58]: `*` is not supported for operand types + # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and + # `Tuple[Tensor, ...]`. summed_grad * layer_eval for summed_grad, layer_eval in zip(summed_grads, layer_evals) ) if relu_attributions: + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[tuple[Tensor], Tensor]`. scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts) + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `Tuple[Union[tuple[Tensor], Tensor], ...]`. return _format_output(len(scaled_acts) > 1, scaled_acts) diff --git a/captum/attr/_core/layer/internal_influence.py b/captum/attr/_core/layer/internal_influence.py index e9594efebd..548d8d2282 100644 --- a/captum/attr/_core/layer/internal_influence.py +++ b/captum/attr/_core/layer/internal_influence.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -39,6 +41,7 @@ class InternalInfluence(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -69,6 +72,7 @@ def attribute( inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", @@ -250,6 +254,7 @@ def _attribute( inputs: Tuple[Tensor, ...], baselines: Tuple[Union[Tensor, int, float], ...], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", @@ -301,7 +306,9 @@ def _attribute( # flattening grads so that we can multiply it with step-size # calling contiguous to avoid `memory whole` problems scaled_grads = tuple( + # pyre-fixme[16]: `tuple` has no attribute `contiguous`. layer_grad.contiguous().view(n_steps, -1) + # pyre-fixme[16]: `tuple` has no attribute `device`. * torch.tensor(step_sizes).view(n_steps, 1).to(layer_grad.device) for layer_grad in layer_gradients ) @@ -309,7 +316,11 @@ def _attribute( # aggregates across all steps for each tensor in the input tuple attrs = tuple( _reshape_and_sum( - scaled_grad, n_steps, inputs[0].shape[0], layer_grad.shape[1:] + scaled_grad, + n_steps, + inputs[0].shape[0], + # pyre-fixme[16]: `tuple` has no attribute `shape`. + layer_grad.shape[1:], ) for scaled_grad, layer_grad in zip(scaled_grads, layer_gradients) ) diff --git a/captum/attr/_core/layer/layer_activation.py b/captum/attr/_core/layer/layer_activation.py index c29aa4e111..99f1e951d6 100644 --- a/captum/attr/_core/layer/layer_activation.py +++ b/captum/attr/_core/layer/layer_activation.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union import torch @@ -18,6 +20,7 @@ class LayerActivation(LayerAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
forward_func: Callable, layer: ModuleOrModuleList, device_ids: Union[None, List[int]] = None, @@ -48,6 +51,7 @@ def __init__( def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_layer_input: bool = False, ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: @@ -127,10 +131,13 @@ def attribute( return _format_output(len(layer_eval) > 1, layer_eval) else: return [ + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but + # got `Tensor`. _format_output(len(single_layer_eval) > 1, single_layer_eval) for single_layer_eval in layer_eval ] @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return True diff --git a/captum/attr/_core/layer/layer_conductance.py b/captum/attr/_core/layer/layer_conductance.py index 18b26d7514..f353dbe32c 100644 --- a/captum/attr/_core/layer/layer_conductance.py +++ b/captum/attr/_core/layer/layer_conductance.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -42,6 +44,7 @@ class LayerConductance(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -70,22 +73,29 @@ def has_convergence_delta(self) -> bool: return True @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `75`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `91`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -95,12 +105,17 @@ def attribute( n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -323,6 +338,7 @@ def _attribute( inputs: Tuple[Tensor, ...], baselines: Tuple[Union[Tensor, int, float], ...], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", @@ -380,6 +396,8 @@ def _attribute( # This approximates the total input gradient of each step multiplied # by the step size. grad_diffs = tuple( + # pyre-fixme[58]: `-` is not supported for operand types `Tuple[Tensor, + # ...]` and `Tuple[Tensor, ...]`. layer_eval[num_examples:] - layer_eval[:-num_examples] for layer_eval in layer_evals ) @@ -392,6 +410,7 @@ def _attribute( grad_diff * layer_gradient[:-num_examples], n_steps, num_examples, + # pyre-fixme[16]: `tuple` has no attribute `shape`. layer_eval.shape[1:], ) for layer_gradient, layer_eval, grad_diff in zip( @@ -401,5 +420,6 @@ def _attribute( return _format_output(len(attributions) > 1, attributions) @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return True diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py index a3f35c6d43..4128570abe 100644 --- a/captum/attr/_core/layer/layer_deep_lift.py +++ b/captum/attr/_core/layer/layer_deep_lift.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union @@ -99,12 +101,18 @@ def __init__( # Ignoring mypy error for inconsistent signature with DeepLift @typing.overload # type: ignore + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `104`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, @@ -112,6 +120,8 @@ def attribute( ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `117`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -119,6 +129,8 @@ def attribute( target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, @@ -126,6 +138,8 @@ def attribute( ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -322,6 +336,7 @@ def attribute( additional_forward_args, ) + # pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter. def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence: if isinstance(out, Tensor): return out.chunk(2) @@ -366,10 +381,13 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence: inputs, additional_forward_args, target, + # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type. 
+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. cast(Union[Literal[True], Literal[False]], len(attributions) > 1), ) @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs @@ -434,6 +452,8 @@ def __init__( # Ignoring mypy error for inconsistent signature with DeepLiftShap @typing.overload # type: ignore + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `439`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -441,13 +461,19 @@ def attribute( Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]] ], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `453`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -457,12 +483,16 @@ def attribute( target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -655,7 +685,12 @@ def attribute( target=exp_target, additional_forward_args=exp_addit_args, return_convergence_delta=cast( - Literal[True, False], return_convergence_delta + # pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid + # type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take + # parameters. + Literal[True, False], + return_convergence_delta, ), attribute_to_layer_input=attribute_to_layer_input, custom_attribution_func=custom_attribution_func, @@ -674,10 +709,15 @@ def attribute( self, inp_bsz, base_bsz, attributions ) if return_convergence_delta: + # pyre-fixme[61]: `delta` is undefined, or not always defined. return attributions, delta else: + # pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor, + # typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]` + # but got `Union[tuple[Tensor], Tensor]`. return attributions @property + # pyre-fixme[3]: Return type must be annotated. 
def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/layer/layer_feature_ablation.py b/captum/attr/_core/layer/layer_feature_ablation.py index 855c84faf8..6759a1b186 100644 --- a/captum/attr/_core/layer/layer_feature_ablation.py +++ b/captum/attr/_core/layer/layer_feature_ablation.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union import torch @@ -35,6 +37,7 @@ class LayerFeatureAblation(LayerAttribution, PerturbationAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -67,6 +70,7 @@ def attribute( inputs: Union[Tensor, Tuple[Tensor, ...]], layer_baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, layer_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, attribute_to_layer_input: bool = False, @@ -221,6 +225,8 @@ def attribute( >>> layer_mask=layer_mask) """ + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def layer_forward_func(*args): layer_length = args[-1] layer_input = args[:layer_length] @@ -238,6 +244,9 @@ def layer_forward_func(*args): else: all_layer_inputs[layer_input[0].device] = layer_input + # pyre-fixme[53]: Captured variable `all_layer_inputs` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward_hook(module, inp, out=None): device = _extract_device(module, inp, out) is_layer_tuple = ( @@ -302,5 +311,6 @@ def forward_hook(module, inp, out=None): return _attr @property + # pyre-fixme[3]: Return type must be annotated. def attributor(self): return FeatureAblation diff --git a/captum/attr/_core/layer/layer_feature_permutation.py b/captum/attr/_core/layer/layer_feature_permutation.py index 246967ab28..89bae65f83 100644 --- a/captum/attr/_core/layer/layer_feature_permutation.py +++ b/captum/attr/_core/layer/layer_feature_permutation.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, cast, List, Tuple, Union import torch @@ -30,6 +32,7 @@ class LayerFeaturePermutation(LayerAttribution, FeaturePermutation): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -61,6 +64,7 @@ def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, layer_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, @@ -155,6 +159,7 @@ def attribute( otherwise a single tensor is returned. """ + # pyre-fixme[2]: Parameter must be annotated. def layer_forward_func(*args) -> Tensor: layer_length = args[-1] layer_input = args[:layer_length] @@ -172,6 +177,9 @@ def layer_forward_func(*args) -> Tensor: else: all_layer_inputs[layer_input[0].device] = layer_input + # pyre-fixme[53]: Captured variable `all_layer_inputs` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
def forward_hook(module, inp, out=None): device = _extract_device(module, inp, out) is_layer_tuple = ( @@ -237,5 +245,6 @@ def forward_hook(module, inp, out=None): return _attr @property + # pyre-fixme[3]: Return type must be annotated. def attributor(self): return FeaturePermutation diff --git a/captum/attr/_core/layer/layer_gradient_shap.py b/captum/attr/_core/layer/layer_gradient_shap.py index 18c9aa176c..e02dd4cf63 100644 --- a/captum/attr/_core/layer/layer_gradient_shap.py +++ b/captum/attr/_core/layer/layer_gradient_shap.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import typing from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union @@ -59,6 +61,7 @@ class LayerGradientShap(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -101,28 +104,40 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `106`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `120`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, target: TargetType = None, additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @@ -131,6 +146,7 @@ def attribute( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. baselines: Union[TensorOrTupleOfTensorsGeneric, Callable], n_samples: int = 5, stdevs: Union[float, Tuple[float, ...]] = 0.0, @@ -284,7 +300,12 @@ def attribute( """ # since `baselines` is a distribution, we can generate it using a function # rather than passing it as an input argument + # pyre-fixme[9]: baselines has type `Union[typing.Callable[..., typing.Any], + # Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, typing.Tuple[Tensor, + # ...]]]]`; used as `Tuple[Tensor, ...]`. baselines = _format_callable_baseline(baselines, inputs) + # pyre-fixme[16]: Item `Callable` of `Union[(...) -> Any, + # TensorOrTupleOfTensorsGeneric]` has no attribute `__getitem__`. 
assert isinstance(baselines[0], torch.Tensor), ( "Baselines distribution has to be provided in a form " "of a torch.Tensor {}.".format(baselines[0]) @@ -319,6 +340,7 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs @@ -326,6 +348,7 @@ def multiplies_by_inputs(self): class LayerInputBaselineXGradient(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -368,18 +391,26 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `373`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: Union[Tensor, Tuple[Tensor, ...]], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `385`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -387,6 +418,8 @@ def attribute( target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, @@ -459,11 +492,15 @@ def attribute( # type: ignore return _compute_conv_delta_and_format_attrs( self, return_convergence_delta, + # pyre-fixme[6]: For 3rd argument expected `Tuple[Tensor, ...]` but got + # `Union[List[typing.Tuple[Tensor, ...]], tuple[Tensor]]`. attributions, baselines, inputs, additional_forward_args, target, + # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. cast(Union[Literal[True], Literal[False]], len(attributions) > 1), ) @@ -471,5 +508,6 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/layer/layer_gradient_x_activation.py b/captum/attr/_core/layer/layer_gradient_x_activation.py index 6f055fb2bf..f8683ad41a 100644 --- a/captum/attr/_core/layer/layer_gradient_x_activation.py +++ b/captum/attr/_core/layer/layer_gradient_x_activation.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, Dict, List, Optional, Tuple, Union from captum._utils.common import ( @@ -22,6 +24,7 @@ class LayerGradientXActivation(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
forward_func: Callable, layer: ModuleOrModuleList, device_ids: Union[None, List[int]] = None, @@ -66,6 +69,7 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs @@ -74,6 +78,7 @@ def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, @@ -183,6 +188,10 @@ def attribute( if isinstance(self.layer, Module): return _format_output( len(layer_evals) > 1, + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but + # got `List[typing.Tuple[Tensor, ...]]`. + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but + # got `List[typing.Tuple[Tensor, ...]]`. self.multiply_gradient_acts(layer_gradients, layer_evals), ) else: diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py index 2e8897b2ae..fb63df0ffb 100644 --- a/captum/attr/_core/layer/layer_integrated_gradients.py +++ b/captum/attr/_core/layer/layer_integrated_gradients.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import functools import warnings from typing import Any, Callable, cast, List, overload, Tuple, Union @@ -45,6 +47,7 @@ class LayerIntegratedGradients(LayerAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: ModuleOrModuleList, device_ids: Union[None, List[int]] = None, @@ -107,20 +110,27 @@ def __init__( ) @overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `112`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType, target: TargetType, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, n_steps: int, method: str, internal_batch_size: Union[None, int], + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False], attribute_to_layer_input: bool, ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ... @overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `126`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -130,6 +140,8 @@ def attribute( n_steps: int, method: str, internal_batch_size: Union[None, int], + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool, ) -> Tuple[ @@ -138,6 +150,8 @@ def attribute( ]: ... @overload + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -158,6 +172,8 @@ def attribute( ]: ... @log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -358,6 +374,8 @@ def attribute( additional_forward_args ) + # pyre-fixme[3]: Return type must be annotated. 
+ # pyre-fixme[2]: Parameter must be annotated. def flatten_tuple(tup): return tuple( sum((list(x) if isinstance(x, (tuple, list)) else [x] for x in tup), []) @@ -397,9 +415,11 @@ def flatten_tuple(tup): # inputs -> these inputs are scaled def gradient_func( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], target_ind: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> Tuple[Tensor, ...]: if self.device_ids is None or len(self.device_ids) == 0: @@ -408,7 +428,10 @@ def gradient_func( # scatter method does not have a precise enough return type in its # stub, so suppress the type warning. scattered_inputs = scatter( # type:ignore - inputs, target_gpus=self.device_ids + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `Union[Tensor, typing.Tuple[Tensor, ...]]`. + inputs, + target_gpus=self.device_ids, ) scattered_inputs_dict = { @@ -418,8 +441,20 @@ def gradient_func( with torch.autograd.set_grad_enabled(True): + # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not + # annotated. + # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not + # annotated. + # pyre-fixme[3]: Return type must be annotated. def layer_forward_hook( - module, hook_inputs, hook_outputs=None, layer_idx=0 + # pyre-fixme[2]: Parameter must be annotated. + module, + # pyre-fixme[2]: Parameter must be annotated. + hook_inputs, + # pyre-fixme[2]: Parameter must be annotated. + hook_outputs=None, + # pyre-fixme[2]: Parameter must be annotated. + layer_idx=0, ): device = _extract_device(module, hook_inputs, hook_outputs) is_layer_tuple = ( @@ -534,5 +569,6 @@ def has_convergence_delta(self) -> bool: return True @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self.ig.multiplies_by_inputs diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py index 20c0f19aba..7bd2721328 100644 --- a/captum/attr/_core/layer/layer_lrp.py +++ b/captum/attr/_core/layer/layer_lrp.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from typing import Any, cast, List, Tuple, Union @@ -58,26 +60,37 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None: LayerAttribution.__init__(self, model, layer) LRP.__init__(self, model) if hasattr(self.model, "device_ids"): + # pyre-fixme[4]: Attribute must be annotated. self.device_ids = cast(List[int], self.model.device_ids) @typing.overload # type: ignore + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `66`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, verbose: bool = False, ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `77`. 
def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, verbose: bool = False, @@ -206,22 +219,36 @@ def attribute( >>> attribution = layer_lrp.attribute(input, target=5) """ + # pyre-fixme[16]: `LayerLRP` has no attribute `verbose`. self.verbose = verbose + # pyre-fixme[16]: `LayerLRP` has no attribute `_original_state_dict`. self._original_state_dict = self.model.state_dict() + # pyre-fixme[16]: `LayerLRP` has no attribute `layers`. self.layers = [] self._get_layers(self.model) self._check_and_attach_rules() + # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`. self.attribute_to_layer_input = attribute_to_layer_input + # pyre-fixme[16]: `LayerLRP` has no attribute `backward_handles`. self.backward_handles = [] + # pyre-fixme[16]: `LayerLRP` has no attribute `forward_handles`. self.forward_handles = [] + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) try: # 1. Forward pass output = self._compute_output_and_change_weights( - inputs, target, additional_forward_args + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but + # got `TensorOrTupleOfTensorsGeneric`. + inputs, + target, + additional_forward_args, ) self._register_forward_hooks() # 2. Forward pass + backward pass @@ -231,6 +258,8 @@ def attribute( relevances = self._get_output_relevance(output) finally: self._restore_model() + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) if return_convergence_delta: @@ -249,7 +278,10 @@ def attribute( else: return relevances # type: ignore + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_single_output_relevance(self, layer, output): + # pyre-fixme[16]: `LayerLRP` has no attribute `attribute_to_layer_input`. if self.attribute_to_layer_input: normalized_relevances = layer.rule.relevance_input else: @@ -270,6 +302,8 @@ def _get_single_output_relevance(self, layer, output): (-1,) + (1,) * (normalized_relevances.dim() - 1) ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_output_relevance(self, output): if isinstance(self.layer, list): relevances = [] @@ -280,7 +314,9 @@ def _get_output_relevance(self, output): return self._get_single_output_relevance(self.layer, output) @staticmethod + # pyre-fixme[3]: Return annotation cannot contain `Any`. def _convert_list_to_tuple( + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 
relevances: Union[List[Any], Tuple[Any, ...]] ) -> Tuple[Any, ...]: if isinstance(relevances, list): diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 61dd9b33ee..cfe120774d 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import inspect import math import typing @@ -70,12 +72,17 @@ class LimeBase(PerturbationAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, interpretable_model: Model, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. similarity_func: Callable, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. perturb_func: Callable, perturb_interpretable_space: bool, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. from_interp_rep_transform: Optional[Callable], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. to_interp_rep_transform: Optional[Callable], ) -> None: r""" @@ -240,10 +247,12 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_samples: int = 50, perturbations_per_eval: int = 1, show_progress: bool = False, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ) -> Tensor: r""" @@ -530,15 +539,18 @@ def attribute( self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) return self.interpretable_model.representation() + # pyre-fixme[3]: Return type must be annotated. def _evaluate_batch( self, curr_model_inputs: List[TensorOrTupleOfTensorsGeneric], expanded_target: TargetType, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. expanded_additional_args: Any, device: torch.device, ): model_out = _run_forward( self.forward_func, + # pyre-fixme[6]: For 1st argument expected `Sequence[Variable[TupleOrTens... _reduce_list(curr_model_inputs), expanded_target, expanded_additional_args, @@ -556,6 +568,7 @@ def has_convergence_delta(self) -> bool: return False @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return False @@ -564,6 +577,8 @@ def multiplies_by_inputs(self): # for Lime child implementation. +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs): assert ( "feature_mask" in kwargs @@ -590,7 +605,9 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs): def get_exp_kernel_similarity_function( - distance_mode: str = "cosine", kernel_width: float = 1.0 + distance_mode: str = "cosine", + kernel_width: float = 1.0, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. ) -> Callable: r""" This method constructs an appropriate similarity function to compute @@ -623,6 +640,8 @@ def get_exp_kernel_similarity_function( similarity_fn for Lime or LimeBase. """ + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float() flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float() @@ -638,6 +657,8 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): return default_exp_kernel +# pyre-fixme[3]: Return type must be annotated. 
+# pyre-fixme[2]: Parameter must be annotated. def default_perturb_func(original_inp, **kwargs): assert ( "num_interp_features" in kwargs @@ -651,6 +672,8 @@ def default_perturb_func(original_inp, **kwargs): return torch.bernoulli(probs).to(device=device).long() +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def construct_feature_mask(feature_mask, formatted_inputs): if feature_mask is None: feature_mask, num_interp_features = _construct_default_feature_mask( @@ -661,6 +684,7 @@ def construct_feature_mask(feature_mask, formatted_inputs): min_interp_features = int( min( torch.min(single_mask).item() + # pyre-fixme[16]: `None` has no attribute `__iter__`. for single_mask in feature_mask if single_mask.numel() ) @@ -674,6 +698,8 @@ def construct_feature_mask(feature_mask, formatted_inputs): single_mask - min_interp_features for single_mask in feature_mask ) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `Optional[typing.Tuple[typing.Any, ...]]`. num_interp_features = _get_max_feature_index(feature_mask) + 1 return feature_mask, num_interp_features @@ -716,9 +742,12 @@ class Lime(LimeBase): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, interpretable_model: Optional[Model] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. similarity_func: Optional[Callable] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. perturb_func: Optional[Callable] = None, ) -> None: r""" @@ -834,6 +863,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, @@ -1075,14 +1105,18 @@ def _attribute_kwargs( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, n_samples: int = 25, perturbations_per_eval: int = 1, return_input_shape: bool = True, show_progress: bool = False, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ) -> TensorOrTupleOfTensorsGeneric: + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) bsz = formatted_inputs[0].shape[0] @@ -1185,6 +1219,8 @@ def _attribute_kwargs( # type: ignore **kwargs, ) if return_input_shape: + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return self._convert_output_shape( formatted_inputs, feature_mask, @@ -1196,22 +1232,30 @@ def _attribute_kwargs( # type: ignore return coefs @typing.overload + # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept + # all possible arguments of overload defined on line `1201`. def _convert_output_shape( self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], coefs: Tensor, num_interp_features: int, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[True], ) -> Tuple[Tensor, ...]: ... 
@typing.overload + # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept + # all possible arguments of overload defined on line `1211`. def _convert_output_shape( self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], coefs: Tensor, num_interp_features: int, + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[False], ) -> Tensor: ... diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index d04b640fe8..194a910765 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -1,3 +1,4 @@ +# pyre-strict from copy import copy from typing import Callable, cast, Dict, List, Optional, Union @@ -30,6 +31,7 @@ class LLMAttributionResult: It also provides utilities to help present and plot the result in different forms. """ + # pyre-fixme[3]: Return type must be annotated. def __init__( self, seq_attr: Tensor, @@ -43,9 +45,12 @@ def __init__( self.output_tokens = output_tokens @property + # pyre-fixme[3]: Return type must be annotated. def seq_attr_dict(self): return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)} + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def plot_token_attr(self, show=False): """ Generate a matplotlib plot for visualising the attribution @@ -56,6 +61,7 @@ def plot_token_attr(self, show=False): Default: False """ + # pyre-fixme[16]: `Optional` has no attribute `cpu`. token_attr = self.token_attr.cpu() # maximum absolute attribution value @@ -113,6 +119,8 @@ def plot_token_attr(self, show=False): else: return fig, ax + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def plot_seq_attr(self, show=False): """ Generate a matplotlib plot for visualising the attribution @@ -175,9 +183,11 @@ class LLMAttribution(Attribution): ) SUPPORTED_INPUTS = (TextTemplateInput, TextTokenInput) + # pyre-fixme[3]: Return type must be annotated. def __init__( self, attr_method: Attribution, + # pyre-fixme[2]: Parameter must be annotated. tokenizer, attr_target: str = "log_prob", # TODO: support callable attr_target ): @@ -203,7 +213,9 @@ class created with the llm model that follows huggingface style super().__init__(attr_method.forward_func) # shallow copy is enough to avoid modifying original instance + # pyre-fixme[4]: Attribute must be annotated. self.attr_method = copy(attr_method) + # pyre-fixme[4]: Attribute must be annotated. self.include_per_token_attr = isinstance( attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS ) @@ -212,9 +224,12 @@ class created with the llm model that follows huggingface style # alias, we really need a model and don't support wrapper functions # coz we need call model.forward, model.generate, etc. + # pyre-fixme[4]: Attribute must be annotated. self.model = cast(nn.Module, self.forward_func) + # pyre-fixme[4]: Attribute must be annotated. self.tokenizer = tokenizer + # pyre-fixme[4]: Attribute must be annotated. self.device = ( cast(torch.device, self.model.device) if hasattr(self.model, "device") @@ -227,11 +242,16 @@ class created with the llm model that follows huggingface style ), "attr_target should be either 'log_prob' or 'prob'" self.attr_target = attr_target + # pyre-fixme[3]: Return type must be annotated. def _forward_func( self, + # pyre-fixme[2]: Parameter must be annotated. 
perturbed_tensor, + # pyre-fixme[2]: Parameter must be annotated. inp, + # pyre-fixme[2]: Parameter must be annotated. target_tokens, + # pyre-fixme[2]: Parameter must be annotated. use_cached_outputs=False, _inspect_forward=None, ): @@ -275,6 +295,8 @@ def _forward_func( ).unsqueeze(0) else: target_log_probs = total_log_prob + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int, + # Tensor]`. target_probs = torch.exp(target_log_probs) if _inspect_forward: @@ -286,6 +308,7 @@ def _forward_func( return target_probs if self.attr_target != "log_prob" else target_log_probs + # pyre-fixme[3]: Return type must be annotated. def _format_model_input(self, model_input: Union[str, Tensor]): """ Convert str to tokenized tensor @@ -304,10 +327,15 @@ def attribute( inp: InterpretableInput, target: Union[str, torch.Tensor, None] = None, num_trials: int = 1, + # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use + # `typing.Dict[, ]` to avoid runtime subscripting + # errors. gen_args: Optional[Dict] = None, use_cached_outputs: bool = True, # internal callback hook can be used for logging + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. _inspect_forward: Optional[Callable] = None, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ) -> LLMAttributionResult: """ @@ -362,6 +390,7 @@ def attribute( attr = torch.zeros( [ + # pyre-fixme[61]: `target_tokens` is undefined, or not always defined. 1 + len(target_tokens) if self.include_per_token_attr else 1, inp.n_itp_features, ], @@ -376,6 +405,8 @@ def attribute( attr_input, additional_forward_args=( inp, + # pyre-fixme[61]: `target_tokens` is undefined, or not always + # defined. target_tokens, use_cached_outputs, _inspect_forward, @@ -400,6 +431,7 @@ def attribute( attr[1:] if self.include_per_token_attr else None ), # shape(n_output_token, n_input_features) inp.values, + # pyre-fixme[61]: `target_tokens` is undefined, or not always defined. self.tokenizer.convert_ids_to_tokens(target_tokens), ) @@ -420,9 +452,12 @@ class LLMGradientAttribution(Attribution): SUPPORTED_METHODS = (LayerIntegratedGradients,) SUPPORTED_INPUTS = (TextTokenInput,) + # pyre-fixme[3]: Return type must be annotated. def __init__( self, + # pyre-fixme[2]: Parameter must be annotated. attr_method, + # pyre-fixme[2]: Parameter must be annotated. tokenizer, ): """ @@ -439,20 +474,25 @@ class created with the llm model that follows huggingface style super().__init__(attr_method.forward_func) # shallow copy is enough to avoid modifying original instance + # pyre-fixme[4]: Attribute must be annotated. self.attr_method = copy(attr_method) self.attr_method.forward_func = self._forward_func # alias, we really need a model and don't support wrapper functions # coz we need call model.forward, model.generate, etc. + # pyre-fixme[4]: Attribute must be annotated. self.model = cast(nn.Module, self.forward_func) + # pyre-fixme[4]: Attribute must be annotated. self.tokenizer = tokenizer + # pyre-fixme[4]: Attribute must be annotated. self.device = ( cast(torch.device, self.model.device) if hasattr(self.model, "device") else next(self.model.parameters()).device ) + # pyre-fixme[3]: Return type must be annotated. def _forward_func( self, perturbed_tensor: Tensor, @@ -485,17 +525,24 @@ def _forward_func( # the attribution target is limited to the log probability return token_log_probs + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
def _format_model_input(self, model_input): """ Convert str to tokenized tensor """ return model_input.to(self.device) + # pyre-fixme[3]: Return type must be annotated. def attribute( self, inp: InterpretableInput, target: Union[str, torch.Tensor, None] = None, + # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use + # `typing.Dict[, ]` to avoid runtime subscripting + # errors. gen_args: Optional[Dict] = None, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ): """ @@ -547,12 +594,15 @@ def attribute( attr_inp = inp.to_tensor().to(self.device) attr_list = [] + # pyre-fixme[61]: `target_tokens` is undefined, or not always defined. for cur_target_idx, _ in enumerate(target_tokens): # attr in shape(batch_size, input+output_len, emb_dim) attr = self.attr_method.attribute( attr_inp, additional_forward_args=( inp, + # pyre-fixme[61]: `target_tokens` is undefined, or not always + # defined. target_tokens, cur_target_idx, ), @@ -592,5 +642,6 @@ def attribute( seq_attr, attr, # shape(n_output_token, n_input_features) inp.values, + # pyre-fixme[61]: `target_tokens` is undefined, or not always defined. self.tokenizer.convert_ids_to_tokens(target_tokens), ) diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 70235325d6..bec606863f 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import typing from collections import defaultdict from typing import Any, cast, List, Tuple, Union @@ -60,27 +62,39 @@ def multiplies_by_inputs(self) -> bool: return True @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `65`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[False] = False, verbose: bool = False, ) -> TensorOrTupleOfTensorsGeneric: ... @typing.overload + # pyre-fixme[43]: The implementation of `attribute` does not accept all possible + # arguments of overload defined on line `75`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, additional_forward_args: Any = None, *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. return_convergence_delta: Literal[True], verbose: bool = False, ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @log_usage() + # pyre-fixme[43]: This definition does not have the same decorators as the + # preceding overload(s). def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -184,22 +198,37 @@ def attribute( >>> attribution = lrp.attribute(input, target=5) """ + # pyre-fixme[16]: `LRP` has no attribute `verbose`. self.verbose = verbose + # pyre-fixme[16]: `LRP` has no attribute `_original_state_dict`. self._original_state_dict = self.model.state_dict() + # pyre-fixme[16]: `LRP` has no attribute `layers`. self.layers: List[Module] = [] self._get_layers(self.model) self._check_and_attach_rules() + # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. 
self.backward_handles: List[RemovableHandle] = [] + # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles: List[RemovableHandle] = [] + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) try: # 1. Forward pass: Change weights of layers according to selected rules. output = self._compute_output_and_change_weights( - inputs, target, additional_forward_args + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but + # got `TensorOrTupleOfTensorsGeneric`. + inputs, + target, + additional_forward_args, ) # 2. Forward pass + backward pass: Register hooks to configure relevance # propagation and execute back-propagation. @@ -215,9 +244,12 @@ def attribute( finally: self._restore_model() + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) if return_convergence_delta: + # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen... return ( _format_output(is_inputs_tuple, relevances), self.compute_convergence_delta(relevances, output), @@ -270,11 +302,13 @@ def compute_convergence_delta( def _get_layers(self, model: Module) -> None: for layer in model.children(): if len(list(layer.children())) == 0: + # pyre-fixme[16]: `LRP` has no attribute `layers`. self.layers.append(layer) else: self._get_layers(layer) def _check_and_attach_rules(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): layer.activations = {} # type: ignore @@ -313,40 +347,49 @@ def _check_rules(self) -> None: ) def _register_forward_hooks(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if type(layer) in SUPPORTED_NON_LINEAR_LAYERS: backward_handles = _register_backward_hook( layer, PropagationRule.backward_hook_activation, self ) + # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. self.backward_handles.extend(backward_handles) else: forward_handle = layer.register_forward_hook( layer.rule.forward_hook # type: ignore ) + # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) + # pyre-fixme[16]: `LRP` has no attribute `verbose`. if self.verbose: print(f"Applied {layer.rule} on layer {layer}") def _register_weight_hooks(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_hook( layer.rule.forward_hook_weights # type: ignore ) + # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. self.forward_handles.append(forward_handle) def _register_pre_hooks(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if layer.rule is not None: forward_handle = layer.register_forward_pre_hook( layer.rule.forward_pre_hook_activations # type: ignore ) + # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. 
self.forward_handles.append(forward_handle) def _compute_output_and_change_weights( self, inputs: Tuple[Tensor, ...], target: TargetType, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, ) -> Tensor: try: @@ -365,12 +408,15 @@ def _compute_output_and_change_weights( return cast(Tensor, output) def _remove_forward_hooks(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `forward_handles`. for forward_handle in self.forward_handles: forward_handle.remove() def _remove_backward_hooks(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `backward_handles`. for backward_handle in self.backward_handles: backward_handle.remove() + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer.rule, "_handle_input_hooks"): for handle in layer.rule._handle_input_hooks: # type: ignore @@ -379,11 +425,13 @@ def _remove_backward_hooks(self) -> None: layer.rule._handle_output_hook.remove() # type: ignore def _remove_rules(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "rule"): del layer.rule def _clear_properties(self) -> None: + # pyre-fixme[16]: `LRP` has no attribute `layers`. for layer in self.layers: if hasattr(layer, "activation"): del layer.activation diff --git a/captum/attr/_core/neuron/neuron_conductance.py b/captum/attr/_core/neuron/neuron_conductance.py index 844df3a28b..34e66c2223 100644 --- a/captum/attr/_core/neuron/neuron_conductance.py +++ b/captum/attr/_core/neuron/neuron_conductance.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -37,6 +39,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -91,9 +94,11 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[int, ...], Callable], baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "riemann_trapezoid", @@ -280,9 +285,15 @@ def attribute( " or performing other operations on the tensor may lead to inaccurate" " results." ) + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs, baselines = _format_input_baseline(inputs, baselines) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. _validate_input(inputs, baselines, n_steps, method) num_examples = inputs[0].shape[0] @@ -304,6 +315,8 @@ def attribute( ) else: attrs = self._attribute( + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but + # got `TensorOrTupleOfTensorsGeneric`. inputs=inputs, neuron_selector=neuron_selector, baselines=baselines, @@ -314,14 +327,18 @@ def attribute( attribute_to_neuron_input=attribute_to_neuron_input, grad_kwargs=grad_kwargs, ) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. 
return _format_output(is_inputs_tuple, attrs) def _attribute( self, inputs: Tuple[Tensor, ...], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[int, ...], Callable], baselines: Tuple[Union[Tensor, int, float], ...], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "riemann_trapezoid", @@ -393,6 +410,7 @@ def _attribute( # Aggregates across all steps for each tensor in the input tuple total_grads = tuple( + # pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`. _reshape_and_sum(scaled_grad, n_steps, num_examples, input_grad.shape[1:]) for (scaled_grad, input_grad) in zip(scaled_grads, input_grads) ) @@ -410,5 +428,6 @@ def _attribute( return attributions @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_deep_lift.py b/captum/attr/_core/neuron/neuron_deep_lift.py index a631dabf60..a89c1a0aab 100644 --- a/captum/attr/_core/neuron/neuron_deep_lift.py +++ b/captum/attr/_core/neuron/neuron_deep_lift.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, cast, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn @@ -77,8 +79,10 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], baselines: BaselineType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_neuron_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, @@ -243,6 +247,7 @@ def attribute( ) @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs @@ -306,10 +311,12 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], baselines: Union[ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_neuron_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, @@ -468,5 +475,6 @@ def attribute( ) @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_feature_ablation.py b/captum/attr/_core/neuron/neuron_feature_ablation.py index 5e4c7eaafd..c931015b5f 100644 --- a/captum/attr/_core/neuron/neuron_feature_ablation.py +++ b/captum/attr/_core/neuron/neuron_feature_ablation.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union import torch @@ -28,6 +30,7 @@ class NeuronFeatureAblation(NeuronAttribution, PerturbationAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -57,8 +60,10 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], baselines: BaselineType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, attribute_to_neuron_input: bool = False, @@ -245,6 +250,7 @@ def attribute( >>> feature_mask=feature_mask) """ + # pyre-fixme[3]: Return type must be annotated. def neuron_forward_func(*args: Any): with torch.no_grad(): layer_eval = _forward_layer_eval( diff --git a/captum/attr/_core/neuron/neuron_gradient.py b/captum/attr/_core/neuron/neuron_gradient.py index 0ce184e30a..fef1f2c190 100644 --- a/captum/attr/_core/neuron/neuron_gradient.py +++ b/captum/attr/_core/neuron/neuron_gradient.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union from captum._utils.common import ( @@ -26,6 +28,7 @@ class NeuronGradient(NeuronAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -57,7 +60,9 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -158,11 +163,17 @@ def attribute( >>> # index (4,1,2). >>> attribution = neuron_ig.attribute(input, (4,1,2)) """ + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) additional_forward_args = _format_additional_forward_args( additional_forward_args ) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) _, input_grads = _forward_layer_eval_with_neuron_grads( @@ -175,5 +186,9 @@ def attribute( attribute_to_layer_input=attribute_to_neuron_input, ) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, input_grads) diff --git a/captum/attr/_core/neuron/neuron_gradient_shap.py b/captum/attr/_core/neuron/neuron_gradient_shap.py index b9b5d73c99..816ecb69e3 100644 --- a/captum/attr/_core/neuron/neuron_gradient_shap.py +++ b/captum/attr/_core/neuron/neuron_gradient_shap.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn @@ -48,6 +50,7 @@ class NeuronGradientShap(NeuronAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -94,12 +97,14 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], baselines: Union[ TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], n_samples: int = 5, stdevs: float = 0.0, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -253,5 +258,6 @@ def attribute( ) @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py index 84dca50106..d95edf2e37 100644 --- a/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py +++ b/captum/attr/_core/neuron/neuron_guided_backprop_deconvnet.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn @@ -58,7 +60,9 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: @@ -212,7 +216,9 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, attribute_to_neuron_input: bool = False, ) -> TensorOrTupleOfTensorsGeneric: diff --git a/captum/attr/_core/neuron/neuron_integrated_gradients.py b/captum/attr/_core/neuron/neuron_integrated_gradients.py index cb70cebe09..aebf06abce 100644 --- a/captum/attr/_core/neuron/neuron_integrated_gradients.py +++ b/captum/attr/_core/neuron/neuron_integrated_gradients.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Tuple, Union from captum._utils.gradient import construct_neuron_grad_fn @@ -25,6 +27,7 @@ class NeuronIntegratedGradients(NeuronAttribution, GradientAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -73,8 +76,10 @@ def __init__( def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable], baselines: Union[None, Tensor, Tuple[Tensor, ...]] = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", @@ -248,5 +253,6 @@ def attribute( ) @property + # pyre-fixme[3]: Return type must be annotated. 
def multiplies_by_inputs(self): return self._multiply_by_inputs diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index d16cec47a1..f030cb9e2a 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from enum import Enum from typing import Any, cast, List, Optional, Tuple, Union @@ -25,6 +27,7 @@ class NoiseTunnelType(Enum): vargrad = 3 +# pyre-fixme[5]: Global expression must be annotated. SUPPORTED_NOISE_TUNNEL_TYPES = list(NoiseTunnelType.__members__.keys()) @@ -63,14 +66,18 @@ def __init__(self, attribution_method: Attribution) -> None: Conductance or Saliency. """ self.attribution_method = attribution_method + # pyre-fixme[4]: Attribute must be annotated. self.is_delta_supported = self.attribution_method.has_convergence_delta() + # pyre-fixme[4]: Attribute must be annotated. self._multiply_by_inputs = self.attribution_method.multiplies_by_inputs + # pyre-fixme[4]: Attribute must be annotated. self.is_gradient_method = isinstance( self.attribution_method, GradientAttribution ) Attribution.__init__(self, self.attribution_method.forward_func) @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return self._multiply_by_inputs @@ -195,6 +202,7 @@ def add_noise_to_inputs(nt_samples_partition: int) -> Tuple[Tensor, ...]: if self.is_gradient_method else add_noise_to_input(input, stdev, nt_samples_partition) ) + # pyre-fixme[61]: `stdevs_` is undefined, or not always defined. for (input, stdev) in zip(inputs, stdevs_) ) @@ -205,6 +213,8 @@ def add_noise_to_input( bsz = input.shape[0] # expand input size by the number of drawn samples + # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` + # and `Size`. input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:] # expand stdev for the shape of the input and number of drawn samples @@ -231,10 +241,13 @@ def update_sum_attribution_and_sq( Tuple[int, ...], (bsz, nt_samples_batch_size_inter) ) if len(attribution.shape) > 1: + # pyre-fixme[22]: The cast is redundant. attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:])) attribution = attribution.view(attribution_shape) current_attribution_sum = attribution.sum(dim=1, keepdim=False) + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False) sum_attribution[i] = ( @@ -249,7 +262,9 @@ def update_sum_attribution_and_sq( ) def compute_partial_attribution( - inputs_with_noise_partition: Tuple[Tensor, ...], kwargs_partition: Any + inputs_with_noise_partition: Tuple[Tensor, ...], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + kwargs_partition: Any, ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]: # smoothgrad_Attr(x) = 1 / n * sum(Attr(x + N(0, sigma^2)) # NOTE: using __wrapped__ such that it does not log the inner logs @@ -277,6 +292,9 @@ def compute_partial_attribution( delta, ) + # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use + # `typing.Dict[, ]` to avoid runtime subscripting + # errors. def expand_partial(nt_samples_partition: int, kwargs_partial: dict) -> None: # if the algorithm supports targets, baselines and/or # additional_forward_args they will be expanded based @@ -311,6 +329,7 @@ def compute_smoothing( ) ) + # pyre-fixme[22]: The cast is redundant. 
return cast(Tuple[Tensor, ...], vargrad) def update_partial_attribution_and_delta( @@ -371,10 +390,15 @@ def update_partial_attribution_and_delta( ) = compute_partial_attribution(inputs_with_noise, kwargs_copy) if len(sum_attributions) == 0: + # pyre-fixme[9]: sum_attributions has type + # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions = [None] * len(attributions_partial) + # pyre-fixme[9]: sum_attributions_sq has type + # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions_sq = [None] * len(attributions_partial) update_partial_attribution_and_delta( + # pyre-fixme[22]: The cast is redundant. cast(Tuple[Tensor, ...], attributions_partial), cast(Tensor, delta_partial), cast(List[Tensor], sum_attributions), @@ -396,6 +420,7 @@ def update_partial_attribution_and_delta( ) = compute_partial_attribution(inputs_with_noise, kwargs) update_partial_attribution_and_delta( + # pyre-fixme[22]: The cast is redundant. cast(Tuple[Tensor, ...], attributions_partial), cast(Tensor, delta_partial), cast(List[Tensor], sum_attributions), @@ -417,7 +442,9 @@ def update_partial_attribution_and_delta( ] ) attributions = compute_smoothing( + # pyre-fixme[22]: The cast is redundant. cast(Tuple[Tensor, ...], expected_attributions), + # pyre-fixme[22]: The cast is redundant. cast(Tuple[Tensor, ...], expected_attributions_sq), ) @@ -426,7 +453,11 @@ def update_partial_attribution_and_delta( delta = torch.cat(delta_partial_list, dim=0) return self._apply_checks_and_return_attributions( - attributions, is_attrib_tuple, return_convergence_delta, delta + attributions, + # pyre-fixme[61]: `is_attrib_tuple` is undefined, or not always defined. + is_attrib_tuple, + return_convergence_delta, + delta, ) def _apply_checks_and_return_attributions( @@ -435,9 +466,14 @@ def _apply_checks_and_return_attributions( is_attrib_tuple: bool, return_convergence_delta: bool, delta: Union[None, Tensor], + # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: + # [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` + # isn't present in the function's parameters. ) -> Union[ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor] ]: + # pyre-fixme[9]: Unable to unpack `Union[Tensor, typing.Tuple[Tensor, + # ...]]`, expected a tuple. attributions = _format_output(is_attrib_tuple, attributions) ret = ( @@ -446,6 +482,9 @@ def _apply_checks_and_return_attributions( else attributions ) ret = cast( + # pyre-fixme[34]: `Variable[TensorOrTupleOfTensorsGeneric <: + # [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` + # isn't present in the function's parameters. Union[ TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor], diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 6ca1355944..188794cf1f 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, Tuple, Union import numpy as np @@ -35,6 +37,7 @@ class Occlusion(FeatureAblation): /tensorflow/methods.py#L401 """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -55,6 +58,7 @@ def attribute( # type: ignore ] = None, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
additional_forward_args: Any = None, perturbations_per_eval: int = 1, show_progress: bool = False, @@ -366,6 +370,8 @@ def _occlusion_mask( padded_tensor = torch.nn.functional.pad( sliding_window_tsr, tuple(pad_values) # type: ignore ) + # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and + # `Size`. return padded_tensor.reshape((1,) + padded_tensor.shape) def _get_feature_range_and_mask( @@ -374,6 +380,8 @@ def _get_feature_range_and_mask( feature_max = np.prod(kwargs["shift_counts"]) return 0, feature_max, None + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_feature_counts(self, inputs, feature_mask, **kwargs): """return the numbers of possible input features""" return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"]) diff --git a/captum/attr/_core/saliency.py b/captum/attr/_core/saliency.py index f0afead501..3c53000366 100644 --- a/captum/attr/_core/saliency.py +++ b/captum/attr/_core/saliency.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + from typing import Any, Callable import torch @@ -23,6 +25,7 @@ class Saliency(GradientAttribution): https://arxiv.org/abs/1312.6034 """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -38,6 +41,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, abs: bool = True, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -120,9 +124,15 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. gradient_mask = apply_gradient_requirements(inputs) # No need to format additional_forward_args here. @@ -134,5 +144,9 @@ def attribute( attributions = tuple(torch.abs(gradient) for gradient in gradients) else: attributions = gradients + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. undo_gradient_requirements(inputs, gradient_mask) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index da6741d656..c137ee4356 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import itertools import math import warnings @@ -54,6 +56,8 @@ def _shape_feature_mask( f"input shape: {inp.shape}, feature mask shape {mask.shape}" ) if mask.dim() < inp.dim(): + # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int, + # ...]` and `Size`. mask = mask.reshape((1,) * (inp.dim() - mask.dim()) + mask.shape) mask_list.append(mask) @@ -85,6 +89,7 @@ class ShapleyValueSampling(PerturbationAttribution): https://pdfs.semanticscholar.org/7715/bb1070691455d1fcfc6346ff458dbca77b2c.pdf """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -106,6 +111,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, n_samples: int = 25, @@ -295,12 +301,29 @@ def attribute( """ # Keeps track whether original input is a tuple or not before # converting it into a tuple. + # pyre-fixme[6]: For 1st argument expected `Tensor` but got + # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs, baselines = _format_input_baseline(inputs, baselines) additional_forward_args = _format_additional_forward_args( additional_forward_args ) + # pyre-fixme[9]: feature_mask has type + # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. feature_mask = _format_feature_mask(feature_mask, inputs) + # pyre-fixme[9]: feature_mask has type + # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`; used as `Tuple[Tensor, ...]`. + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`. + # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. feature_mask = _shape_feature_mask(feature_mask, inputs) assert ( @@ -308,9 +331,14 @@ def attribute( ), "Ablations per evaluation must be at least 1." with torch.no_grad(): + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `TensorOrTupleOfTensorsGeneric`. baselines = _tensorize_baseline(inputs, baselines) num_examples = inputs[0].shape[0] + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got + # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`. total_features = _get_max_feature_index(feature_mask) + 1 if show_progress: @@ -365,10 +393,16 @@ def attribute( current_target, current_masks, ) in self._perturbation_generator( + # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` + # but got `TensorOrTupleOfTensorsGeneric`. inputs, additional_forward_args, target, baselines, + # pyre-fixme[6]: For 5th argument expected + # `TensorOrTupleOfTensorsGeneric` but got + # `Optional[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor, + # typing.Tuple[Tensor, ...]]]]`. feature_mask, feature_permutation, perturbations_per_eval, @@ -409,7 +443,11 @@ def attribute( # Append n_input_feature dim of 1 to make the tensor # have the same dim as the mask tensor. formatted_eval_diff = eval_diff.reshape( - (-1,) + output_shape + (len(inputs[j].shape) - 1) * (1,) + (-1,) + # pyre-fixme[58]: `+` is not supported for operand types + # `Tuple[int]` and `Size`. + + output_shape + + (len(inputs[j].shape) - 1) * (1,) ) # mask in shape (n_perturb, *mask_shape_broadcastable_to_input) @@ -423,6 +461,8 @@ def attribute( cur_mask = cur_mask.reshape( cur_mask.shape[:2] + (len(output_shape) - 1) * (1,) + # pyre-fixme[58]: `+` is not supported for operand types + # `Tuple[int, ...]` and `Size`. 
+ cur_mask.shape[2:] ) @@ -441,11 +481,15 @@ def attribute( tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib ) formatted_attr = _format_output(is_inputs_tuple, attrib) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. return formatted_attr + # pyre-fixme[3]: Return annotation cannot contain `Any`. def _perturbation_generator( self, inputs: Tuple[Tensor, ...], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_args: Any, target: TargetType, baselines: Tuple[Tensor, ...], @@ -524,10 +568,13 @@ def _perturbation_generator( combined_masks, ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): """return the total number of forward evaluations needed""" return math.ceil(total_features / perturbations_per_eval) * n_samples + # pyre-fixme[2]: Parameter must be annotated. def _strict_run_forward(self, *args, **kwargs) -> Tensor: """ A temp wrapper for global _run_forward util to force forward output @@ -585,6 +632,7 @@ class ShapleyValues(ShapleyValueSampling): evaluations, and we plan to add this approach in the future. """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -606,6 +654,7 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, @@ -814,6 +863,8 @@ def attribute( show_progress=show_progress, ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): """return the total number of forward evaluations needed""" return math.ceil(total_features / perturbations_per_eval) * math.factorial( diff --git a/captum/attr/_models/base.py b/captum/attr/_models/base.py index 0b9e406d73..cb4964f514 100644 --- a/captum/attr/_models/base.py +++ b/captum/attr/_models/base.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import warnings from functools import reduce @@ -19,14 +21,21 @@ class InterpretableEmbeddingBase(Module): precomputed embedding vectors to the layers below. """ + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, embedding, full_name) -> None: Module.__init__(self) + # pyre-fixme[4]: Attribute must be annotated. self.num_embeddings = getattr(embedding, "num_embeddings", None) + # pyre-fixme[4]: Attribute must be annotated. self.embedding_dim = getattr(embedding, "embedding_dim", None) + # pyre-fixme[4]: Attribute must be annotated. self.embedding = embedding + # pyre-fixme[4]: Attribute must be annotated. self.full_name = full_name + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, *inputs, **kwargs): r""" The forward function of a wrapper embedding layer that takes and returns @@ -70,6 +79,8 @@ def forward(self, *inputs, **kwargs): ) return inputs[0] if len(inputs) > 0 else list(kwargs.values())[0] + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def indices_to_embeddings(self, *input, **kwargs): r""" Maps indices to corresponding embedding vectors. E.g. 
word embeddings @@ -102,6 +113,7 @@ class TokenReferenceBase: def __init__(self, reference_token_idx: int = 0) -> None: self.reference_token_idx = reference_token_idx + # pyre-fixme[2]: Parameter must be annotated. def generate_reference(self, sequence_length, device: torch.device) -> torch.Tensor: r""" Generated reference tensor of given `sequence_length` using @@ -120,6 +132,8 @@ def generate_reference(self, sequence_length, device: torch.device) -> torch.Ten return torch.tensor([self.reference_token_idx] * sequence_length, device=device) +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def _get_deep_layer_name(obj, layer_names): r""" Traverses through the layer names that are separated by @@ -128,6 +142,8 @@ def _get_deep_layer_name(obj, layer_names): return reduce(getattr, layer_names.split("."), obj) +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def _set_deep_layer_value(obj, layer_names, value): r""" Traverses through the layer names that are separated by diff --git a/captum/attr/_models/pytext.py b/captum/attr/_models/pytext.py index 389c9ee206..f0f27fd2e7 100644 --- a/captum/attr/_models/pytext.py +++ b/captum/attr/_models/pytext.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from collections import defaultdict import torch @@ -17,11 +19,16 @@ class PyTextInterpretableEmbedding(EmbeddingBase): layer which passes precomputed embedding vectors to lower layers. """ + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, embeddings) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.embedding_dims = [embedding.embedding_dim for embedding in embeddings] super().__init__(sum(self.embedding_dims)) + # pyre-fixme[4]: Attribute must be annotated. self.embeddings = embeddings + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, input): r""" The forward pass of embedding layer. This can be for the text or any @@ -39,6 +46,8 @@ def forward(self, input): """ return input + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def get_attribution_map(self, attributions): r""" After attribution scores are computed for an input embedding vector @@ -84,21 +93,30 @@ class BaselineGenerator: PAD = "" + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, model, data_handler, device) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.model = model + # pyre-fixme[4]: Attribute must be annotated. self.data_handler = data_handler if "dict_feat" in data_handler.features: + # pyre-fixme[4]: Attribute must be annotated. self.vocab_dict = data_handler.features["dict_feat"].vocab if "word_feat" in data_handler.features: + # pyre-fixme[4]: Attribute must be annotated. self.vocab_word = data_handler.features["word_feat"].vocab + # pyre-fixme[4]: Attribute must be annotated. self.baseline_single_word_feature = self._generate_baseline_single_word_feature( device ) + # pyre-fixme[4]: Attribute must be annotated. self.baseline_single_dict_feature = self._generate_baseline_single_dict_feature( device ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def generate_baseline(self, integ_grads_embeddings, seq_length): r""" Generates baseline for input word and dict features. 
In the future we @@ -129,6 +147,8 @@ def generate_baseline(self, integ_grads_embeddings, seq_length): ) return tuple(baseline) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _generate_baseline_single_word_feature(self, device): return ( torch.tensor( @@ -138,6 +158,8 @@ def _generate_baseline_single_word_feature(self, device): .to(device) ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _generate_baseline_single_dict_feature(self, device): r"""Generate dict features based on Assistant's case study by using sia_transformer: @@ -184,9 +206,13 @@ def _generate_baseline_single_dict_feature(self, device): return (gazetteer_feat_id, gazetteer_feat_weights, gazetteer_feat_lengths) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _generate_word_baseline(self, seq_length): return self.baseline_single_word_feature.repeat(1, seq_length) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _generate_dict_baseline(self, seq_length): return ( self.baseline_single_dict_feature[0].repeat(1, seq_length), @@ -195,6 +221,8 @@ def _generate_dict_baseline(self, seq_length): ) +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def configure_task_integ_grads_embeddings(task): r""" Wraps Pytext's DocNN model embedding with `IntegratedGradientsEmbedding` for @@ -219,6 +247,8 @@ def configure_task_integ_grads_embeddings(task): return integrated_gradients_embedding_lst[0] +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def configure_model_integ_grads_embeddings(model): r""" Wraps Pytext's DocNN model embedding with `IntegratedGradientsEmbedding` @@ -240,6 +270,8 @@ def configure_model_integ_grads_embeddings(model): return EmbeddingList([integrated_gradients_embedding], False) +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def reshape_word_features(word_features): r""" Creates one-sample batch for word features for sanity check purposes @@ -256,8 +288,18 @@ def reshape_word_features(word_features): return word_features.unsqueeze(0) +# pyre-fixme[3]: Return type must be annotated. def reshape_dict_features( - dict_feature_id_batch, dict_weight_batch, dict_seq_len_batch, seq_length, idx + # pyre-fixme[2]: Parameter must be annotated. + dict_feature_id_batch, + # pyre-fixme[2]: Parameter must be annotated. + dict_weight_batch, + # pyre-fixme[2]: Parameter must be annotated. + dict_seq_len_batch, + # pyre-fixme[2]: Parameter must be annotated. + seq_length, + # pyre-fixme[2]: Parameter must be annotated. + idx, ): r""" Creates one-sample batch for dict features for sanity check purposes diff --git a/captum/attr/_utils/approximation_methods.py b/captum/attr/_utils/approximation_methods.py index 1c9a01ebfd..318578277f 100644 --- a/captum/attr/_utils/approximation_methods.py +++ b/captum/attr/_utils/approximation_methods.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from enum import Enum from typing import Callable, List, Tuple @@ -19,6 +21,7 @@ class Riemann(Enum): "riemann_trapezoid", ] +# pyre-fixme[5]: Global expression must be annotated. 
SUPPORTED_METHODS = SUPPORTED_RIEMANN_METHODS + ["gausslegendre"] @@ -123,11 +126,15 @@ def gauss_legendre_builders() -> ( def step_sizes(n: int) -> List[float]: assert n > 0, "The number of steps has to be larger than zero" # Scaling from 2 to 1 + # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got + # `float`. return list(0.5 * np.polynomial.legendre.leggauss(n)[1]) def alphas(n: int) -> List[float]: assert n > 0, "The number of steps has to be larger than zero" # Scaling from [-1, 1] to [0, 1] + # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got + # `float`. return list(0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])) return step_sizes, alphas diff --git a/captum/attr/_utils/attribution.py b/captum/attr/_utils/attribution.py index 6a777d34b8..e4f6ecbf04 100644 --- a/captum/attr/_utils/attribution.py +++ b/captum/attr/_utils/attribution.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, cast, Generic, List, Tuple, Type, Union import torch @@ -22,12 +24,15 @@ from torch.nn import Module +# pyre-fixme[13]: Attribute `attribute` is never initialized. +# pyre-fixme[13]: Attribute `compute_convergence_delta` is never initialized. class Attribution: r""" All attribution algorithms extend this class. It enforces its child classes to extend and override core `attribute` method. """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -37,6 +42,7 @@ def __init__(self, forward_func: Callable) -> None: """ self.forward_func = forward_func + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. attribute: Callable r""" This method computes and returns the attribution values for each input tensor. @@ -68,6 +74,7 @@ def __init__(self, forward_func: Callable) -> None: """ @property + # pyre-fixme[3]: Return type must be annotated. def multiplies_by_inputs(self): return False @@ -88,6 +95,7 @@ def has_convergence_delta(self) -> bool: """ return False + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. compute_convergence_delta: Callable r""" The attribution algorithms which derive `Attribution` class and provide @@ -146,6 +154,7 @@ class GradientAttribution(Attribution): that we want to interpret or the model itself. """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -155,6 +164,7 @@ def __init__(self, forward_func: Callable) -> None: function. """ Attribution.__init__(self, forward_func) + # pyre-fixme[4]: Attribute must be annotated. self.gradient_func = compute_gradients @log_usage() @@ -166,6 +176,7 @@ def compute_convergence_delta( ], end_point: Union[Tensor, Tuple[Tensor, ...]], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, ) -> Tensor: r""" @@ -307,6 +318,7 @@ class PerturbationAttribution(Attribution): that we want to interpret or the model itself. """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, forward_func: Callable) -> None: r""" Args: @@ -318,6 +330,7 @@ def __init__(self, forward_func: Callable) -> None: Attribution.__init__(self, forward_func) @property + # pyre-fixme[3]: Return type must be annotated. 
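# [Illustrative aside, not part of this patch.] `compute_convergence_delta`, whose
# signature gains annotations above, is the completeness check shared by the
# gradient-based attribution methods. A hedged usage sketch with IntegratedGradients
# on a toy model:
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

model = nn.Sequential(nn.Linear(4, 2))
ig = IntegratedGradients(model)
inputs = torch.randn(3, 4)
baselines = torch.zeros_like(inputs)
attrs = ig.attribute(inputs, baselines=baselines, target=0)
# one delta per example: sum of attributions minus (F(inputs) - F(baselines))
delta = ig.compute_convergence_delta(attrs, baselines, inputs, target=0)
print(delta.shape)  # torch.Size([3])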
def multiplies_by_inputs(self): return True @@ -332,6 +345,7 @@ class InternalAttribution(Attribution, Generic[ModuleOrModuleList]): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: ModuleOrModuleList, device_ids: Union[None, List[int]] = None, @@ -355,6 +369,7 @@ def __init__( self.device_ids = device_ids +# pyre-fixme[24]: Generic type `InternalAttribution` expects 1 type parameter. class LayerAttribution(InternalAttribution): r""" Layer attribution provides attribution values for the given layer, quantifying @@ -365,6 +380,7 @@ class LayerAttribution(InternalAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: ModuleOrModuleList, device_ids: Union[None, List[int]] = None, @@ -422,6 +438,8 @@ def interpolate( return F.interpolate(layer_attribution, interpolate_dims, mode=interpolate_mode) +# pyre-fixme[13]: Attribute `attribute` is never initialized. +# pyre-fixme[24]: Generic type `InternalAttribution` expects 1 type parameter. class NeuronAttribution(InternalAttribution): r""" Neuron attribution provides input attribution for a given neuron, quantifying @@ -435,6 +453,7 @@ class NeuronAttribution(InternalAttribution): def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_func: Callable, layer: Module, device_ids: Union[None, List[int]] = None, @@ -455,6 +474,7 @@ def __init__( """ InternalAttribution.__init__(self, forward_func, layer, device_ids) + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. attribute: Callable r""" This method computes and returns the neuron attribution values for each diff --git a/captum/attr/_utils/baselines.py b/captum/attr/_utils/baselines.py index f4b1f6d0c1..3f88efc5d1 100644 --- a/captum/attr/_utils/baselines.py +++ b/captum/attr/_utils/baselines.py @@ -1,4 +1,6 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict import random from typing import Any, Dict, List, Tuple, Union @@ -18,8 +20,10 @@ class ProductBaselines: the corresponding values. """ + # pyre-fixme[3]: Return type must be annotated. def __init__( self, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. baseline_values: Union[ List[List[Any]], Dict[Union[str, Tuple[str, ...]], List[Any]], @@ -31,9 +35,11 @@ def __init__( else: dict_keys = [] + # pyre-fixme[4]: Attribute must be annotated. self.dict_keys = dict_keys self.baseline_values = baseline_values + # pyre-fixme[3]: Return annotation cannot contain `Any`. def sample(self) -> Union[List[Any], Dict[str, Any]]: baselines = [ random.choice(baseline_list) for baseline_list in self.baseline_values @@ -52,6 +58,7 @@ def sample(self) -> Union[List[Any], Dict[str, Any]]: return dict_baselines + # pyre-fixme[3]: Return annotation cannot contain `Any`. def __call__(self) -> Union[List[Any], Dict[str, Any]]: """ Returns: diff --git a/captum/attr/_utils/batching.py b/captum/attr/_utils/batching.py index df3045525e..641314dc85 100644 --- a/captum/attr/_utils/batching.py +++ b/captum/attr/_utils/batching.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing import warnings from typing import Any, Callable, Iterator, Tuple, Union @@ -19,12 +21,19 @@ from torch import Tensor +# pyre-fixme[3]: Return type must be annotated. def _batch_attribution( + # pyre-fixme[2]: Parameter must be annotated. attr_method, + # pyre-fixme[2]: Parameter must be annotated. 
num_examples, + # pyre-fixme[2]: Parameter must be annotated. internal_batch_size, + # pyre-fixme[2]: Parameter must be annotated. n_steps, + # pyre-fixme[2]: Parameter must be annotated. include_endpoint=False, + # pyre-fixme[2]: Parameter must be annotated. **kwargs, ): """ @@ -101,11 +110,16 @@ def _tuple_splice_range(inputs: None, start: int, end: int) -> None: ... @typing.overload +# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. def _tuple_splice_range(inputs: Tuple, start: int, end: int) -> Tuple: ... def _tuple_splice_range( - inputs: Union[None, Tuple], start: int, end: int + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. + inputs: Union[None, Tuple], + start: int, + end: int, + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. ) -> Union[None, Tuple]: """ Splices each tensor element of given tuple (inputs) from range start @@ -123,8 +137,10 @@ def _tuple_splice_range( ) +# pyre-fixme[3]: Return annotation cannot contain `Any`. def _batched_generator( inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, target_ind: TargetType = None, internal_batch_size: Union[None, int] = None, @@ -137,6 +153,8 @@ def _batched_generator( assert internal_batch_size is None or ( isinstance(internal_batch_size, int) and internal_batch_size > 0 ), "Batch size must be greater than 0." + # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. inputs = _format_tensor_into_tuples(inputs) additional_forward_args = _format_additional_forward_args(additional_forward_args) num_examples = inputs[0].shape[0] @@ -149,13 +167,19 @@ def _batched_generator( requires_grad.""" ) if internal_batch_size is None: + # pyre-fixme[7]: Expected `Iterator[Tuple[typing.Tuple[Tensor, ...], typing.A... yield inputs, additional_forward_args, target_ind else: for current_total in range(0, num_examples, internal_batch_size): with torch.autograd.set_grad_enabled(True): inputs_splice = _tuple_splice_range( - inputs, current_total, current_total + internal_batch_size + # pyre-fixme[6]: For 1st argument expected `None` but got + # `TensorOrTupleOfTensorsGeneric`. + inputs, + current_total, + current_total + internal_batch_size, ) + # pyre-fixme[7]: Expected `Iterator[Tuple[typing.Tuple[Tensor, ...], typi... yield inputs_splice, _tuple_splice_range( additional_forward_args, current_total, @@ -171,6 +195,7 @@ def _batched_generator( def _batched_operator( operator: Callable[..., TupleOrTensorOrBoolGeneric], inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, target_ind: TargetType = None, internal_batch_size: Union[None, int] = None, @@ -195,6 +220,8 @@ def _batched_operator( return _reduce_list(all_outputs) +# pyre-fixme[3]: Return annotation cannot be `Any`. +# pyre-fixme[2]: Parameter annotation cannot be `Any`. def _select_example(curr_arg: Any, index: int, bsz: int) -> Any: if curr_arg is None: return None @@ -210,6 +237,8 @@ def _select_example(curr_arg: Any, index: int, bsz: int) -> Any: return _format_output(is_tuple, tuple(selected_arg)) +# pyre-fixme[2]: Parameter must be annotated. +# pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter. def _batch_example_iterator(bsz: int, *args) -> Iterator: """ Batches the provided argument. 
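The batching helpers above (`_batch_attribution`, `_tuple_splice_range`, `_batched_generator`) are what run when a user passes `internal_batch_size` to a step-based method. A minimal, hedged usage sketch on a toy model:

import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

model = nn.Sequential(nn.Linear(4, 2))
ig = IntegratedGradients(model)
inputs = torch.randn(8, 4)
# The expanded (n_steps x n_examples) batch is spliced into chunks of at most
# internal_batch_size rows per forward/backward pass.
attrs = ig.attribute(inputs, target=0, n_steps=50, internal_batch_size=16)
print(attrs.shape)  # torch.Size([8, 4])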
diff --git a/captum/attr/_utils/class_summarizer.py b/captum/attr/_utils/class_summarizer.py index 63fa7c6745..085ba76148 100644 --- a/captum/attr/_utils/class_summarizer.py +++ b/captum/attr/_utils/class_summarizer.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from collections import defaultdict from typing import Any, Dict, List, Optional, Union @@ -21,6 +23,7 @@ class ClassSummarizer(Summarizer): @log_usage() def __init__(self, stats: List[Stat]) -> None: Summarizer.__init__.__wrapped__(self, stats) + # pyre-fixme[4]: Attribute annotation cannot contain `Any`. self.summaries: Dict[Any, Summarizer] = defaultdict( lambda: Summarizer(stats=stats) ) @@ -50,10 +53,13 @@ def update( # type: ignore super().update(x) return + # pyre-fixme[9]: x has type `TensorOrTupleOfTensorsGeneric`; used as + # `Tuple[Tensor, ...]`. x = _format_tensor_into_tuples(x) num_labels = 1 + # pyre-fixme[33]: Given annotation cannot contain `Any`. labels_typed: Union[List[Any], Tensor] if isinstance(labels, list) or isinstance(labels, Tensor): labels_typed = labels @@ -82,6 +88,7 @@ def update( # type: ignore super().update(tensors_to_summarize_copy) @property + # pyre-fixme[3]: Return annotation cannot contain `Any`. def class_summaries( self, ) -> Dict[ diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 8f83a4e0c8..0687497487 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import typing from inspect import signature from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union @@ -165,6 +167,9 @@ def _format_and_verify_strides( i, strides[i], inputs[i].shape ) + # pyre-fixme[7]: Expected `Tuple[Union[int, typing.Tuple[int, ...]], ...]` but + # got `Union[Tuple[Union[int, typing.Tuple[Union[int, typing.Tuple[int, ...]], + # ...]]], typing.Tuple[Union[int, typing.Tuple[int, ...]], ...]]`. return strides @@ -176,6 +181,7 @@ def _format_and_verify_sliding_window_shapes( # Assumes inputs is already formatted (in tuple) if isinstance(sliding_window_shapes[0], int): sliding_window_shapes = (sliding_window_shapes,) # type: ignore + # pyre-fixme[35]: Target cannot be annotated. sliding_window_shapes: Tuple[Tuple[int, ...], ...] assert len(sliding_window_shapes) == len( inputs @@ -194,19 +200,27 @@ def _format_and_verify_sliding_window_shapes( @typing.overload +# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does +# not accept all possible arguments of overload defined on line `199`. def _compute_conv_delta_and_format_attrs( attr_algo: "GradientAttribution", return_convergence_delta: bool, attributions: Tuple[Tensor, ...], start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], end_point: Union[Tensor, Tuple[Tensor, ...]], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, target: TargetType, + # pyre-fixme[9]: is_inputs_tuple has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[False] = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... @typing.overload +# pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does +# not accept all possible arguments of overload defined on line `212`. 
def _compute_conv_delta_and_format_attrs( attr_algo: "GradientAttribution", return_convergence_delta: bool, @@ -215,6 +229,8 @@ def _compute_conv_delta_and_format_attrs( end_point: Union[Tensor, Tuple[Tensor, ...]], additional_forward_args: Any, target: TargetType, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. is_inputs_tuple: Literal[True], ) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: ... @@ -250,6 +266,8 @@ def _compute_conv_delta_and_format_attrs( def _tensorize_baseline( inputs: Tuple[Tensor, ...], baselines: Tuple[Union[int, float, Tensor], ...] ) -> Tuple[Tensor, ...]: + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _tensorize_single_baseline(baseline, input): if isinstance(baseline, (int, float)): return torch.full_like(input, baseline) diff --git a/captum/attr/_utils/custom_modules.py b/captum/attr/_utils/custom_modules.py index 8dea72054f..a666cfce6a 100644 --- a/captum/attr/_utils/custom_modules.py +++ b/captum/attr/_utils/custom_modules.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import torch.nn as nn @@ -10,5 +12,7 @@ class Addition_Module(nn.Module): def __init__(self) -> None: super().__init__() + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, x1, x2): return x1 + x2 diff --git a/captum/attr/_utils/input_layer_wrapper.py b/captum/attr/_utils/input_layer_wrapper.py index 402319fb43..3d2855b77d 100644 --- a/captum/attr/_utils/input_layer_wrapper.py +++ b/captum/attr/_utils/input_layer_wrapper.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + import inspect from typing import Any @@ -19,6 +21,8 @@ def __init__(self, input_name: str) -> None: super().__init__() self.input_name = input_name + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, x): return x @@ -60,11 +64,14 @@ def __init__(self, module_to_wrap: nn.Module) -> None: self.module = module_to_wrap # ignore self + # pyre-fixme[4]: Attribute must be annotated. self.arg_name_list = inspect.getfullargspec(module_to_wrap.forward).args[1:] self.input_maps = nn.ModuleDict( {arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list} ) + # pyre-fixme[3]: Return annotation cannot be `Any`. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, *args, **kwargs) -> Any: args = list(args) for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)): diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 6e7d8d8326..fe4e995fc6 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -1,3 +1,4 @@ +# pyre-strict from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -22,6 +23,8 @@ def _scatter_itp_attr_by_mask( # input_shape in shape(batch_size, *inp_feature_dims) # attribute in shape(*output_dims, *inp_feature_dims) + # pyre-fixme[60]: Concatenation not yet support for multiple variadic tuples: + # `*output_dims, *input_shape[slice(1, None, None)]`. 
attr_shape = (*output_dims, *input_shape[1:]) expanded_feature_indices = mask.expand(attr_shape) @@ -34,7 +37,11 @@ def _scatter_itp_attr_by_mask( # (*output_dims, 1..., 1, n_itp_features) # then broadcast to (*output_dims, *inp.shape[1:-1], n_itp_features) n_extra_dims = len(extra_inp_dims) + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + # tuples: `*output_dims, *(1).__mul__(n_extra_dims)`. unsqueezed_shape = (*output_dims, *(1,) * n_extra_dims, n_itp_features) + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + # tuples: `*output_dims, *extra_inp_dims`. expanded_shape = (*output_dims, *extra_inp_dims, n_itp_features) expanded_itp_attr = itp_attr.reshape(unsqueezed_shape).expand(expanded_shape) else: @@ -107,6 +114,7 @@ def to_tensor(self) -> Tensor: pass @abstractmethod + # pyre-fixme[3]: Return annotation cannot be `Any`. def to_model_input(self, itp_tensor: Optional[Tensor] = None) -> Any: """ Get the (perturbed) input in the format required by the model @@ -188,10 +196,13 @@ class TextTemplateInput(InterpretableInput): """ + # pyre-fixme[3]: Return type must be annotated. def __init__( self, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. template: Union[str, Callable], values: Union[List[str], Dict[str, str]], + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. baselines: Union[List[str], Dict[str, str], Callable, None] = None, mask: Union[List[int], Dict[str, int], None] = None, ): @@ -206,6 +217,7 @@ def __init__( dict_keys = [] self.values = values + # pyre-fixme[4]: Attribute must be annotated. self.dict_keys = dict_keys n_features = len(values) @@ -249,12 +261,15 @@ def __init__( # internal compressed mask of continuous interpretable indices from 0 # cannot replace original mask of ids for grouping across values externally + # pyre-fixme[4]: Attribute must be annotated. self.formatted_mask = [mask_id_to_idx[mid] for mid in mask] n_itp_features = len(mask_ids) # number of raw features and intepretable features + # pyre-fixme[4]: Attribute must be annotated. self.n_features = n_features + # pyre-fixme[4]: Attribute must be annotated. self.n_itp_features = n_itp_features if isinstance(template, str): @@ -265,6 +280,7 @@ def __init__( f"received: {type(template)}" ) template = template + # pyre-fixme[4]: Attribute annotation cannot contain `Any`. self.format_fn = template self.mask = mask @@ -273,6 +289,8 @@ def to_tensor(self) -> torch.Tensor: # Interpretable representation in shape(1, n_itp_features) return torch.tensor([[1.0] * self.n_itp_features]) + # pyre-fixme[14]: `to_model_input` overrides method defined in + # `InterpretableInput` inconsistently. def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str: values = list(self.values) # clone @@ -303,12 +321,18 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str: itp_val = perturbed_tensor[0][itp_idx] if not itp_val: + # pyre-fixme[16]: Item `None` of `Union[None, Dict[str, str], + # List[typing.Any]]` has no attribute `__getitem__`. values[i] = baselines[i] if self.dict_keys: dict_values = dict(zip(self.dict_keys, values)) + # pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not + # a function. input_str = self.format_fn(**dict_values) else: + # pyre-fixme[29]: `Union[typing.Callable[..., typing.Any], str]` is not + # a function. 
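# [Illustrative aside, not part of this patch.] Rough usage sketch of
# TextTemplateInput as annotated above: the values fill the template, and a 0/1
# perturbation vector (one entry per interpretable feature) swaps switched-off
# values for their baselines. Assumes the public re-export in captum.attr;
# otherwise import from captum.attr._utils.interpretable_input.
import torch
from captum.attr import TextTemplateInput

inp = TextTemplateInput(
    template="{} lives in {} and works as a {}.",
    values=["Dave", "Palm Coast", "lawyer"],
    baselines=["Someone", "a city", "a worker"],
)
print(inp.to_tensor())  # tensor([[1., 1., 1.]]) -- every feature switched on
# Turning the second feature off replaces "Palm Coast" with its baseline "a city".
print(inp.to_model_input(torch.tensor([[1.0, 0.0, 1.0]])))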
input_str = self.format_fn(*values) return input_str @@ -367,9 +391,11 @@ class TextTokenInput(InterpretableInput): """ + # pyre-fixme[3]: Return type must be annotated. def __init__( self, text: str, + # pyre-fixme[2]: Parameter must be annotated. tokenizer, baselines: Union[int, str] = 0, # usually UNK skip_tokens: Union[List[int], List[str], None] = None, @@ -377,10 +403,13 @@ def __init__( inp_tensor = tokenizer.encode(text, return_tensors="pt") # input tensor into the model of token ids + # pyre-fixme[4]: Attribute must be annotated. self.inp_tensor = inp_tensor # tensor of interpretable token ids + # pyre-fixme[4]: Attribute must be annotated. self.itp_tensor = inp_tensor # interpretable mask + # pyre-fixme[4]: Attribute must be annotated. self.itp_mask = None if skip_tokens: @@ -401,10 +430,14 @@ def __init__( self.skip_tokens = skip_tokens # features values, the tokens + # pyre-fixme[4]: Attribute must be annotated. self.values = tokenizer.convert_ids_to_tokens(self.itp_tensor[0].tolist()) + # pyre-fixme[4]: Attribute must be annotated. self.tokenizer = tokenizer + # pyre-fixme[4]: Attribute must be annotated. self.n_itp_features = len(self.values) + # pyre-fixme[4]: Attribute must be annotated. self.baselines = ( baselines if type(baselines) is int @@ -415,6 +448,9 @@ def to_tensor(self) -> torch.Tensor: # return the perturbation indicator as interpretable tensor instead of token ids return torch.ones_like(self.itp_tensor) + # pyre-fixme[14]: `to_model_input` overrides method defined in + # `InterpretableInput` inconsistently. + # pyre-fixme[2]: Parameter must be annotated. def to_model_input(self, perturbed_tensor=None) -> torch.Tensor: if perturbed_tensor is None: return self.inp_tensor @@ -440,5 +476,6 @@ def to_model_input(self, perturbed_tensor=None) -> torch.Tensor: return perturb_inp_tensor + # pyre-fixme[3]: Return type must be annotated. def format_attr(self, itp_attr: torch.Tensor): return itp_attr diff --git a/captum/attr/_utils/lrp_rules.py b/captum/attr/_utils/lrp_rules.py index b244ce124b..a638aba17b 100644 --- a/captum/attr/_utils/lrp_rules.py +++ b/captum/attr/_utils/lrp_rules.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + from abc import ABC, abstractmethod import torch @@ -15,22 +17,32 @@ class PropagationRule(ABC): STABILITY_FACTOR = 1e-9 + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward_hook(self, module, inputs, outputs): """Register backward hooks on input and output tensors of linear layers in the model.""" inputs = _format_tensor_into_tuples(inputs) + # pyre-fixme[16]: `PropagationRule` has no attribute `_has_single_input`. + # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]` + # but got `None`. self._has_single_input = len(inputs) == 1 + # pyre-fixme[16]: `PropagationRule` has no attribute `_handle_input_hooks`. self._handle_input_hooks = [] + # pyre-fixme[16]: `None` has no attribute `__iter__`. for input in inputs: if not hasattr(input, "hook_registered"): input_hook = self._create_backward_hook_input(input.data) self._handle_input_hooks.append(input.register_hook(input_hook)) input.hook_registered = True output_hook = self._create_backward_hook_output(outputs.data) + # pyre-fixme[16]: `PropagationRule` has no attribute `_handle_output_hook`. self._handle_output_hook = outputs.register_hook(output_hook) return outputs.clone() @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
def backward_hook_activation(module, grad_input, grad_output): """Backward hook to propagate relevance over non-linear activations.""" # replace_out is set in _backward_hook_input, this is necessary @@ -41,11 +53,18 @@ def backward_hook_activation(module, grad_input, grad_output): return hook_out return grad_output + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _create_backward_hook_input(self, inputs): + # pyre-fixme[53]: Captured variable `inputs` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _backward_hook_input(grad): relevance = grad * inputs device = grad.device + # pyre-fixme[16]: `PropagationRule` has no attribute `_has_single_input`. if self._has_single_input: + # pyre-fixme[16]: `PropagationRule` has no attribute `relevance_input`. self.relevance_input[device] = relevance.data else: self.relevance_input[device].append(relevance.data) @@ -57,16 +76,24 @@ def _backward_hook_input(grad): return _backward_hook_input + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _create_backward_hook_output(self, outputs): + # pyre-fixme[53]: Captured variable `outputs` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _backward_hook_output(grad): sign = torch.sign(outputs) sign[sign == 0] = 1 relevance = grad / (outputs + sign * self.STABILITY_FACTOR) + # pyre-fixme[16]: `PropagationRule` has no attribute `relevance_output`. self.relevance_output[grad.device] = grad.data return relevance return _backward_hook_output + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward_hook_weights(self, module, inputs, outputs): """Save initial activations a_j before modules are changed""" device = inputs[0].device if isinstance(inputs, tuple) else inputs.device @@ -81,9 +108,13 @@ def forward_hook_weights(self, module, inputs, outputs): self._manipulate_weights(module, inputs, outputs) @abstractmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _manipulate_weights(self, module, inputs, outputs): raise NotImplementedError + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward_pre_hook_activations(self, module, inputs): """Pass initial activations to graph generation pass""" device = inputs[0].device if isinstance(inputs, tuple) else inputs.device @@ -104,9 +135,13 @@ class EpsilonRule(PropagationRule): discriminator during propagation. """ + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, epsilon=1e-9) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.STABILITY_FACTOR = epsilon + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _manipulate_weights(self, module, inputs, outputs): pass @@ -123,10 +158,15 @@ class GammaRule(PropagationRule): the positive relevance is increased. """ + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, gamma=0.25, set_bias_to_zero=False) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.gamma = gamma + # pyre-fixme[4]: Attribute must be annotated. self.set_bias_to_zero = set_bias_to_zero + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
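# [Illustrative aside, not part of this patch.] A rough sketch of how the rules
# above are attached in practice: rule instances are assigned to a layer's
# `rule` attribute before calling LRP.attribute, and layers without one fall
# back to a default EpsilonRule.
import torch
import torch.nn as nn
from captum.attr import LRP
from captum.attr._utils.lrp_rules import EpsilonRule, GammaRule

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model[0].rule = GammaRule(gamma=0.25)   # lower layer: boost positive contributions
model[2].rule = EpsilonRule(epsilon=1e-9)

lrp = LRP(model)
attr = lrp.attribute(torch.randn(3, 4), target=0)
print(attr.shape)  # torch.Size([3, 4])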
def _manipulate_weights(self, module, inputs, outputs): if hasattr(module, "weight"): module.weight.data = ( @@ -149,9 +189,13 @@ class Alpha1_Beta0_Rule(PropagationRule): Use for lower layers. """ + # pyre-fixme[2]: Parameter must be annotated. def __init__(self, set_bias_to_zero=False) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.set_bias_to_zero = set_bias_to_zero + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _manipulate_weights(self, module, inputs, outputs): if hasattr(module, "weight"): module.weight.data = module.weight.data.clamp(min=0) @@ -169,8 +213,13 @@ class IdentityRule(EpsilonRule): Can be used for BatchNorm2D. """ + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _create_backward_hook_input(self, inputs): + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _backward_hook_input(grad): + # pyre-fixme[16]: `IdentityRule` has no attribute `relevance_output`. return self.relevance_output[grad.device] return _backward_hook_input diff --git a/captum/attr/_utils/stat.py b/captum/attr/_utils/stat.py index 803bbc7ab7..70bfe47c7c 100644 --- a/captum/attr/_utils/stat.py +++ b/captum/attr/_utils/stat.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict from typing import Any, Callable, List, Optional, TYPE_CHECKING import torch @@ -29,11 +31,13 @@ def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None: kwargs (Any): Additional arguments used to construct the statistic """ + # pyre-fixme[4]: Attribute must be annotated. self.params = kwargs self._name = name self._other_stats: Optional[SummarizerSingleTensor] = None + # pyre-fixme[3]: Return type must be annotated. def init(self): pass @@ -41,12 +45,14 @@ def _get_stat(self, stat: "Stat") -> Optional["Stat"]: assert self._other_stats is not None return self._other_stats.get(stat) + # pyre-fixme[3]: Return type must be annotated. def update(self, x: Tensor): raise NotImplementedError() def get(self) -> Optional[Tensor]: raise NotImplementedError() + # pyre-fixme[3]: Return type must be annotated. def __hash__(self): return hash((self.__class__, frozenset(self.params.items()))) @@ -62,6 +68,7 @@ def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property + # pyre-fixme[3]: Return type must be annotated. def name(self): """ The name of the statistic. i.e. it is the key in a .summary @@ -85,11 +92,15 @@ class Count(Stat): def __init__(self, name: Optional[str] = None) -> None: super().__init__(name=name) + # pyre-fixme[4]: Attribute must be annotated. self.n = None + # pyre-fixme[3]: Return type must be annotated. def get(self): return self.n + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def update(self, x): if self.n is None: self.n = 0 @@ -109,10 +120,15 @@ def __init__(self, name: Optional[str] = None) -> None: def get(self) -> Optional[Tensor]: return self.rolling_mean + # pyre-fixme[3]: Return type must be annotated. def init(self): + # pyre-fixme[8]: Attribute has type `Optional[Count]`; used as `Optional[Stat]`. self.n = self._get_stat(Count()) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def update(self, x): + # pyre-fixme[16]: `Optional` has no attribute `get`. 
n = self.n.get() if self.rolling_mean is None: @@ -120,6 +136,7 @@ def update(self, x): self.rolling_mean = x.clone() if x.is_floating_point() else x.double() else: delta = x - self.rolling_mean + # pyre-fixme[16]: `Optional` has no attribute `__iadd__`. self.rolling_mean += delta / n @@ -130,10 +147,14 @@ class MSE(Stat): def __init__(self, name: Optional[str] = None) -> None: super().__init__(name=name) + # pyre-fixme[4]: Attribute must be annotated. self.prev_mean = None + # pyre-fixme[4]: Attribute must be annotated. self.mse = None + # pyre-fixme[3]: Return type must be annotated. def init(self): + # pyre-fixme[16]: `MSE` has no attribute `mean`. self.mean = self._get_stat(Mean()) def get(self) -> Optional[Tensor]: @@ -141,7 +162,9 @@ def get(self) -> Optional[Tensor]: return torch.zeros_like(self.prev_mean) return self.mse + # pyre-fixme[3]: Return type must be annotated. def update(self, x: Tensor): + # pyre-fixme[16]: `MSE` has no attribute `mean`. mean = self.mean.get() if mean is not None and self.prev_mean is not None: @@ -175,15 +198,21 @@ def __init__(self, name: Optional[str] = None, order: int = 0) -> None: super().__init__(name=name, order=order) self.order = order + # pyre-fixme[3]: Return type must be annotated. def init(self): + # pyre-fixme[16]: `Var` has no attribute `mse`. self.mse = self._get_stat(MSE()) + # pyre-fixme[16]: `Var` has no attribute `n`. self.n = self._get_stat(Count()) + # pyre-fixme[3]: Return type must be annotated. def update(self, x: Tensor): pass def get(self) -> Optional[Tensor]: + # pyre-fixme[16]: `Var` has no attribute `mse`. mse = self.mse.get() + # pyre-fixme[16]: `Var` has no attribute `n`. n = self.n.get() if mse is None: @@ -215,13 +244,17 @@ def __init__(self, name: Optional[str] = None, order: int = 0) -> None: super().__init__(name=name, order=order) self.order = order + # pyre-fixme[3]: Return type must be annotated. def init(self): + # pyre-fixme[16]: `StdDev` has no attribute `var`. self.var = self._get_stat(Var(order=self.order)) + # pyre-fixme[3]: Return type must be annotated. def update(self, x: Tensor): pass def get(self) -> Optional[Tensor]: + # pyre-fixme[16]: `StdDev` has no attribute `var`. var = self.var.get() return var**0.5 if var is not None else None @@ -232,14 +265,18 @@ class GeneralAccumFn(Stat): where fn is a custom function """ + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, fn: Callable, name: Optional[str] = None) -> None: super().__init__(name=name) + # pyre-fixme[4]: Attribute must be annotated. self.result = None self.fn = fn def get(self) -> Optional[Tensor]: return self.result + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def update(self, x): if self.result is None: self.result = x @@ -249,21 +286,30 @@ def update(self, x): class Min(GeneralAccumFn): def __init__( - self, name: Optional[str] = None, min_fn: Callable = torch.min + self, + name: Optional[str] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + min_fn: Callable = torch.min, ) -> None: super().__init__(name=name, fn=min_fn) class Max(GeneralAccumFn): def __init__( - self, name: Optional[str] = None, max_fn: Callable = torch.max + self, + name: Optional[str] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 
+ max_fn: Callable = torch.max, ) -> None: super().__init__(name=name, fn=max_fn) class Sum(GeneralAccumFn): def __init__( - self, name: Optional[str] = None, add_fn: Callable = torch.add + self, + name: Optional[str] = None, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + add_fn: Callable = torch.add, ) -> None: super().__init__(name=name, fn=add_fn) diff --git a/captum/attr/_utils/summarizer.py b/captum/attr/_utils/summarizer.py index e4c5c860a0..8a0944f70b 100644 --- a/captum/attr/_utils/summarizer.py +++ b/captum/attr/_utils/summarizer.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +# pyre-strict + from typing import Dict, List, Optional, Tuple, Type, Union import torch @@ -35,13 +37,16 @@ def __init__(self, stats: List[Stat]) -> None: """ self._summarizers: List[SummarizerSingleTensor] = [] self._is_inputs_tuple: Optional[bool] = None + # pyre-fixme[4]: Attribute must be annotated. self._stats, self._summary_stats_indicies = _reorder_stats(stats) + # pyre-fixme[3]: Return type must be annotated. def _copy_stats(self): import copy return copy.deepcopy(self._stats) + # pyre-fixme[3]: Return type must be annotated. def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]): r""" Calls `update` on each `Stat` object within the summarizer @@ -121,11 +126,14 @@ def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]: dep_order = [StdDev, Var, MSE, Mean, Count] # remove dupe stats + # pyre-fixme[9]: stats has type `List[Stat]`; used as `Set[Stat]`. stats = set(stats) summary_stats = set(stats) from collections import defaultdict + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use + # `typing.Type[]` to avoid runtime subscripting errors. stats_by_module: Dict[Type, List[Stat]] = defaultdict(list) for stat in stats: stats_by_module[stat.__class__].append(stat) @@ -134,6 +142,7 @@ def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]: # for each StdDev(order) we must ensure there is an associated Var(order) for std_dev in stats_by_module[StdDev]: stat_to_add = Var(order=std_dev.order) # type: ignore + # pyre-fixme[16]: `List` has no attribute `add`. stats.add(stat_to_add) stats_by_module[stat_to_add.__class__].append(stat_to_add) @@ -141,14 +150,21 @@ def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]: # we want to ensure i...n-1 exists for i, dep in enumerate(dep_order[1:]): if dep in stats_by_module: + # pyre-fixme[16]: `List` has no attribute `update`. stats.update([mod() for mod in dep_order[i + 1 :]]) break # Step 2: get the correct order # NOTE: we are sorting via a given topological order sort_order = {mod: i for i, mod in enumerate(dep_order)} + # pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev, + # Var]]` but got `Type[Min]`. sort_order[Min] = -1 + # pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev, + # Var]]` but got `Type[Max]`. sort_order[Max] = -1 + # pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev, + # Var]]` but got `Type[Sum]`. sort_order[Sum] = -1 stats = list(stats) @@ -181,13 +197,16 @@ def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None: does not require any specific order. """ self._stats = stats + # pyre-fixme[4]: Attribute must be annotated. self._stat_to_stat = {stat: stat for stat in self._stats} + # pyre-fixme[4]: Attribute must be annotated. 
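# [Illustrative aside, not part of this patch.] Minimal usage sketch of the
# Stat/Summarizer machinery annotated above; the summary dict is keyed by each
# Stat's `name` property (e.g. "mean", "std_dev").
import torch
from captum.attr._utils.stat import Max, Mean, Min, StdDev
from captum.attr._utils.summarizer import Summarizer

summ = Summarizer([Mean(), StdDev(), Min(), Max()])
for row in torch.randn(5, 10):
    summ.update(row)          # running update, one tensor at a time
print(summ.summary)           # dict of aggregate tensors keyed by stat name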
self._summary_stats = [stats[i] for i in summary_stats_indices] for stat in stats: stat._other_stats = self stat.init() + # pyre-fixme[3]: Return type must be annotated. def update(self, x: Tensor): r""" Updates the summary of a given tensor `x` diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 9782dc63cb..0f67043bff 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 + +# pyre-strict import warnings from enum import Enum from typing import Any, Iterable, List, Optional, Tuple, Union @@ -43,10 +45,12 @@ class VisualizeSign(Enum): all = 4 +# pyre-fixme[3]: Return type must be annotated. def _prepare_image(attr_visual: ndarray): return np.clip(attr_visual.astype(int), 0, 255) +# pyre-fixme[3]: Return type must be annotated. def _normalize_scale(attr: ndarray, scale_factor: float): assert scale_factor != 0, "Cannot normalize by scale factor = 0" if abs(scale_factor) < 1e-5: @@ -60,6 +64,7 @@ def _normalize_scale(attr: ndarray, scale_factor: float): return np.clip(attr_norm, -1, 1) +# pyre-fixme[3]: Return type must be annotated. def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]): # given values should be non-negative assert percentile >= 0 and percentile <= 100, ( @@ -71,6 +76,7 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]): return sorted_vals[threshold_id] +# pyre-fixme[3]: Return type must be annotated. def _normalize_attr( attr: ndarray, sign: str, @@ -346,6 +352,7 @@ def visualize_image_attr( return plt_fig, plt_axis +# pyre-fixme[3]: Return type must be annotated. def visualize_image_attr_multiple( attr: ndarray, original_image: Union[None, ndarray], @@ -455,6 +462,7 @@ def visualize_image_attr_multiple( return plt_fig, plt_axis +# pyre-fixme[3]: Return type must be annotated. def visualize_timeseries_attr( attr: ndarray, data: ndarray, @@ -471,6 +479,7 @@ def visualize_timeseries_attr( title: Optional[str] = None, fig_size: Tuple[int, int] = (6, 6), use_pyplot: bool = True, + # pyre-fixme[2]: Parameter must be annotated. **pyplot_kwargs, ): r""" @@ -590,9 +599,11 @@ def visualize_timeseries_attr( # Check input dimensions assert len(attr.shape) == 2, "Expected attr of shape (N, C), got {}".format( + # pyre-fixme[16]: Module `attr` has no attribute `shape`. attr.shape ) assert len(data.shape) == 2, "Expected data of shape (N, C), got {}".format( + # pyre-fixme[16]: Module `attr` has no attribute `shape`. attr.shape ) @@ -667,7 +678,11 @@ def visualize_timeseries_attr( cmap = cm.get_cmap(cmap) # type: ignore cm_norm = colors.Normalize(vmin, vmax) + # pyre-fixme[53]: Captured variable `cm_norm` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): + # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. half_col_width = (x_values[1] - x_values[0]) / 2.0 for icol, col_center in enumerate(x_vals): left = col_center - half_col_width @@ -675,6 +690,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax): ax.axvspan( xmin=left, xmax=right, + # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function. facecolor=(cmap(cm_norm(attr_vals[icol]))), edgecolor=None, alpha=alpha_overlay, @@ -779,25 +795,43 @@ class VisualizationDataRecord: def __init__( self, + # pyre-fixme[2]: Parameter must be annotated. word_attributions, + # pyre-fixme[2]: Parameter must be annotated. 
pred_prob, + # pyre-fixme[2]: Parameter must be annotated. pred_class, + # pyre-fixme[2]: Parameter must be annotated. true_class, + # pyre-fixme[2]: Parameter must be annotated. attr_class, + # pyre-fixme[2]: Parameter must be annotated. attr_score, + # pyre-fixme[2]: Parameter must be annotated. raw_input_ids, + # pyre-fixme[2]: Parameter must be annotated. convergence_score, ) -> None: + # pyre-fixme[4]: Attribute must be annotated. self.word_attributions = word_attributions + # pyre-fixme[4]: Attribute must be annotated. self.pred_prob = pred_prob + # pyre-fixme[4]: Attribute must be annotated. self.pred_class = pred_class + # pyre-fixme[4]: Attribute must be annotated. self.true_class = true_class + # pyre-fixme[4]: Attribute must be annotated. self.attr_class = attr_class + # pyre-fixme[4]: Attribute must be annotated. self.attr_score = attr_score + # pyre-fixme[4]: Attribute must be annotated. self.raw_input_ids = raw_input_ids + # pyre-fixme[4]: Attribute must be annotated. self.convergence_score = convergence_score +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def _get_color(attr): # clip values to prevent CSS errors (Values should be from [-1,1]) attr = max(-1, min(1, attr)) @@ -812,16 +846,22 @@ def _get_color(attr): return "hsl({}, {}%, {}%)".format(hue, sat, lig) +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def format_classname(classname): return '{}'.format(classname) +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def format_special_tokens(token): if token.startswith("<") and token.endswith(">"): return "#" + token.strip("<>") return token +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. def format_tooltip(item, text): return '
<div class="tooltip">{item}\ <div class="tooltiptext">{text}</div>\ @@ -830,6 +870,8 @@ def format_word_importances(words, importances): if importances is None or len(importances) == 0: return "<td></td>"
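The record and formatting helpers at the end of this file feed Captum's HTML text visualization. A hedged usage sketch with toy values; `visualize_text` builds the word-importance table via `format_word_importances` and expects an IPython/notebook environment for display:

import torch
from captum.attr import visualization as viz

record = viz.VisualizationDataRecord(
    word_attributions=torch.tensor([0.1, -0.4, 0.7]),
    pred_prob=0.92,
    pred_class="positive",
    true_class="positive",
    attr_class="positive",
    attr_score=0.4,
    raw_input_ids=["the", "movie", "rocked"],
    convergence_score=0.01,
)
html = viz.visualize_text([record])  # colored word-importance table, one row per record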