diff --git a/captum/_utils/common.py b/captum/_utils/common.py index f834394b21..58968d747c 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -363,10 +363,9 @@ def _expand_target( return target -# pyre-fixme[3]: Return type must be annotated. def _expand_feature_mask( feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int -): +) -> Tuple[Tensor, ...]: # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, # typing.Tuple[Tensor, ...]]`. is_feature_mask_tuple = _is_tuple(feature_mask) @@ -379,10 +378,9 @@ def _expand_feature_mask( ) for feature_mask_elem in feature_mask ) - return _format_output(is_feature_mask_tuple, feature_mask_new) + return _format_output(is_feature_mask_tuple, feature_mask_new) # type: ignore -# pyre-fixme[3]: Return type must be annotated. def _expand_and_update_baselines( inputs: Tuple[Tensor, ...], n_samples: int, @@ -390,7 +388,7 @@ def _expand_and_update_baselines( # `typing.Dict[, ]` to avoid runtime subscripting errors. kwargs: dict, draw_baseline_from_distrib: bool = False, -): +) -> None: # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def get_random_baseline_indices(bsz, baseline): @@ -432,10 +430,9 @@ def get_random_baseline_indices(bsz, baseline): kwargs["baselines"] = baselines -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict[, ]` to avoid runtime subscripting errors. -def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict): +def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict) -> None: if "additional_forward_args" not in kwargs: return additional_forward_args = kwargs["additional_forward_args"] @@ -451,10 +448,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict): kwargs["additional_forward_args"] = additional_forward_args -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict[, ]` to avoid runtime subscripting errors. -def _expand_and_update_target(n_samples: int, kwargs: dict): +def _expand_and_update_target(n_samples: int, kwargs: dict) -> None: if "target" not in kwargs: return target = kwargs["target"] @@ -465,10 +461,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict): kwargs["target"] = target -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use # `typing.Dict[, ]` to avoid runtime subscripting errors. -def _expand_and_update_feature_mask(n_samples: int, kwargs: dict): +def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None: if "feature_mask" not in kwargs: return @@ -573,10 +568,9 @@ def _format_outputs( # pyre-fixme[24] Callable requires 2 arguments def _construct_future_forward(original_forward: Callable) -> Callable: # pyre-fixme[3] return type not specified - # pyre-ignore - def future_forward(*args, **kwargs): - # pyre-ignore - fut = torch.futures.Future() + def future_forward(*args: Any, **kwargs: Any): + # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function. + fut: torch.futures.Future[Tensor] = torch.futures.Future() fut.set_result(original_forward(*args, **kwargs)) return fut @@ -921,8 +915,7 @@ def input_tensor_hook(input_grad: Tensor): ] -# pyre-fixme[3]: Return type must be annotated. 
-def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]): +def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int: """ Returns the max feature mask index The feature mask should be formatted to tuple of tensors at first. diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index 7c9104d88c..d24df1fff6 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -986,7 +986,11 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick( out = loss sample_grad_wrapper.compute_param_sample_gradients( - out, loss_mode=reduction_type + out, + # pyre-fixme[6]: In call `SampleGradientWrapper. + # compute_param_sample_gradients`, for argument `loss_mode`, + # expected `str` but got `Optional[str]`. + loss_mode=reduction_type, # type: ignore ) if layer_modules is not None: layer_parameters = _extract_parameters_from_layers(layer_modules) diff --git a/captum/_utils/models/linear_model/model.py b/captum/_utils/models/linear_model/model.py index 6008fe983d..08ec2442f9 100644 --- a/captum/_utils/models/linear_model/model.py +++ b/captum/_utils/models/linear_model/model.py @@ -41,7 +41,6 @@ def __init__(self, train_fn: Callable, **kwargs) -> None: # pyre-fixme[4]: Attribute must be annotated. self.construct_kwargs = kwargs - # pyre-fixme[3]: Return type must be annotated. def _construct_model_params( self, in_features: Optional[int] = None, @@ -52,7 +51,7 @@ def _construct_model_params( weight_values: Optional[Tensor] = None, bias_value: Optional[Tensor] = None, classes: Optional[Tensor] = None, - ): + ) -> None: r""" Lazily initializes a linear model. This will be called for you in a train method. diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 64d79153f1..2ba70ea32f 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -9,9 +9,8 @@ from torch.utils.data import DataLoader -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def l2_loss(x1, x2, weights=None): +def l2_loss(x1, x2, weights=None) -> torch.Tensor: if weights is None: return torch.mean((x1 - x2) ** 2) / 2.0 else: @@ -236,7 +235,7 @@ def get_point(datapoint): class NormLayer(nn.Module): # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, mean, std, n=None, eps=1e-8) -> None: + def __init__(self, mean, std, n=None, eps: float = 1e-8) -> None: super().__init__() # pyre-fixme[4]: Attribute must be annotated. self.mean = mean @@ -251,7 +250,6 @@ def forward(self, x): return (x - self.mean) / (self.std + self.eps) -# pyre-fixme[3]: Return type must be annotated. def sklearn_train_linear_model( model: LinearModel, dataloader: DataLoader, @@ -260,7 +258,7 @@ def sklearn_train_linear_model( norm_input: bool = False, # pyre-fixme[2]: Parameter must be annotated. **fit_kwargs, -): +) -> Dict[str, float]: r""" Alternative method to train with sklearn. This does introduce some slight overhead as we convert the tensors to numpy and then convert the resulting diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py index 1e5891cc80..0e2a42d3a7 100644 --- a/captum/_utils/progress.py +++ b/captum/_utils/progress.py @@ -5,7 +5,7 @@ import sys import warnings from time import time -from typing import cast, Iterable, Optional, Sized, TextIO +from typing import Any, cast, Iterable, Optional, Sized, TextIO from captum._utils.typing import Literal @@ -61,15 +61,17 @@ class NullProgress: progress bars. 
""" - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. - # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, iterable: Optional[Iterable] = None, *args, **kwargs): + def __init__( + self, + # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. + iterable: Optional[Iterable] = None, + *args: Any, + **kwargs: Any, + ) -> None: del args, kwargs self.iterable = iterable - # pyre-fixme[3]: Return type must be annotated. - def __enter__(self): + def __enter__(self) -> "NullProgress": return self # pyre-fixme[2]: Parameter must be annotated. @@ -87,12 +89,10 @@ def __iter__(self): for it in self.iterable: yield it - # pyre-fixme[3]: Return type must be annotated. - def update(self, amount: int = 1): + def update(self, amount: int = 1) -> None: pass - # pyre-fixme[3]: Return type must be annotated. - def close(self): + def close(self) -> None: pass @@ -133,8 +133,7 @@ def __init__( self.closed = False self._is_parent = False - # pyre-fixme[3]: Return type must be annotated. - def __enter__(self): + def __enter__(self) -> "SimpleProgress": self._is_parent = True self._refresh() return self @@ -158,8 +157,7 @@ def __iter__(self): self.update() self.close() - # pyre-fixme[3]: Return type must be annotated. - def _refresh(self): + def _refresh(self) -> None: progress_str = self.desc + ": " if self.desc else "" if self.total: # e.g., progress: 60% 3/5 @@ -172,8 +170,7 @@ def _refresh(self): end = "\n" if self._is_parent else "" print("\r" + progress_str, end=end, file=self.file) - # pyre-fixme[3]: Return type must be annotated. - def update(self, amount: int = 1): + def update(self, amount: int = 1) -> None: if self.closed: return self.cur += amount @@ -183,8 +180,7 @@ def update(self, amount: int = 1): self._refresh() self.last_print_t = cur_t - # pyre-fixme[3]: Return type must be annotated. - def close(self): + def close(self) -> None: if not self.closed and not self._is_parent: self._refresh() print(file=self.file) # end with new line @@ -197,8 +193,7 @@ def progress( iterable: Optional[Iterable] = None, desc: Optional[str] = None, total: Optional[int] = None, - # pyre-fixme[2]: Parameter must be annotated. - use_tqdm=True, + use_tqdm: bool = True, file: Optional[TextIO] = None, mininterval: float = 0.5, # pyre-fixme[2]: Parameter must be annotated. diff --git a/captum/_utils/sample_gradient.py b/captum/_utils/sample_gradient.py index 7b868b9cd8..c5c15d867b 100644 --- a/captum/_utils/sample_gradient.py +++ b/captum/_utils/sample_gradient.py @@ -103,7 +103,7 @@ class SampleGradientWrapper: """ # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, model, layer_modules=None) -> None: + def __init__(self, model, layer_modules: Optional[List[Module]] = None) -> None: # pyre-fixme[4]: Attribute must be annotated. self.model = model self.hooks_added = False @@ -162,8 +162,9 @@ def _reset(self) -> None: self.activation_dict = defaultdict(list) self.gradient_dict = defaultdict(list) - # pyre-fixme[2]: Parameter must be annotated. 
- def compute_param_sample_gradients(self, loss_blob, loss_mode="mean") -> None: + def compute_param_sample_gradients( + self, loss_blob: Tensor, loss_mode: str = "mean" + ) -> None: assert ( loss_mode.upper() in LossMode.__members__ ), f"Provided loss mode {loss_mode} is not valid" diff --git a/captum/attr/_core/dataloader_attr.py b/captum/attr/_core/dataloader_attr.py index cc63c0af08..444db1ae3c 100644 --- a/captum/attr/_core/dataloader_attr.py +++ b/captum/attr/_core/dataloader_attr.py @@ -30,9 +30,8 @@ class InputRole: # default reducer wehn reduce is None. Simply concat the outputs by the batch dimension -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def _concat_tensors(accum, cur_output, _): +def _concat_tensors(accum, cur_output, _) -> Tensor: return cur_output if accum is None else torch.cat([accum, cur_output]) @@ -185,7 +184,6 @@ def __init__(self, attr_method: Attribution) -> None: self.attr_method.forward_func = self._forward_with_dataloader - # pyre-fixme[3]: Return type must be annotated. def _forward_with_dataloader( self, batched_perturbed_feature_indices: Tensor, @@ -199,7 +197,7 @@ def _forward_with_dataloader( to_metric: Optional[Callable], show_progress: bool, feature_idx_to_mask_idx: Dict[int, List[int]], - ): + ) -> Tensor: """ Wrapper of the original given forward_func to be used in the attribution method It iterates over the dataloader with the given forward_func @@ -468,10 +466,8 @@ def attribute( return _format_output(is_inputs_tuple, attr) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for DataLoaderAttribution. """ diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index e49a8988de..24f44788b2 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -388,10 +388,8 @@ def attribute( # type: ignore is_inputs_tuple, ) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for DeepLift. """ @@ -587,8 +585,7 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @@ -985,7 +982,6 @@ def nonlinear( return new_grad_inp -# pyre-fixme[3]: Return type must be annotated. def softmax( module: Module, inputs: Tensor, @@ -993,7 +989,7 @@ def softmax( grad_input: Tensor, grad_output: Tensor, eps: float = 1e-10, -): +) -> Tensor: delta_in, delta_out = _compute_diffs(inputs, outputs) grad_input_unnorm = torch.where( @@ -1007,7 +1003,6 @@ def softmax( return new_grad_inp -# pyre-fixme[3]: Return type must be annotated. def maxpool1d( module: Module, inputs: Tensor, @@ -1015,7 +1010,7 @@ def maxpool1d( grad_input: Tensor, grad_output: Tensor, eps: float = 1e-10, -): +) -> Tensor: return maxpool( module, F.max_pool1d, @@ -1028,7 +1023,6 @@ def maxpool1d( ) -# pyre-fixme[3]: Return type must be annotated. 
def maxpool2d( module: Module, inputs: Tensor, @@ -1036,7 +1030,7 @@ def maxpool2d( grad_input: Tensor, grad_output: Tensor, eps: float = 1e-10, -): +) -> Tensor: return maxpool( module, F.max_pool2d, @@ -1049,7 +1043,6 @@ def maxpool2d( ) -# pyre-fixme[3]: Return type must be annotated. def maxpool3d( module: Module, # pyre-fixme[2]: Parameter must be annotated. @@ -1061,7 +1054,7 @@ def maxpool3d( # pyre-fixme[2]: Parameter must be annotated. grad_output, eps: float = 1e-10, -): +) -> Tensor: return maxpool( module, F.max_pool3d, @@ -1074,7 +1067,6 @@ def maxpool3d( ) -# pyre-fixme[3]: Return type must be annotated. def maxpool( module: Module, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. @@ -1090,7 +1082,7 @@ def maxpool( # pyre-fixme[2]: Parameter must be annotated. grad_output, eps: float = 1e-10, -): +) -> Tensor: with torch.no_grad(): input, input_ref = inputs.chunk(2) output, output_ref = outputs.chunk(2) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 2380adbce7..e5ef63f2b2 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -3,7 +3,7 @@ # pyre-strict import math -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -265,13 +265,13 @@ def attribute( is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) - additional_forward_args = _format_additional_forward_args( + formatted_additional_forward_args = _format_additional_forward_args( additional_forward_args ) num_examples = formatted_inputs[0].shape[0] # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got # `TensorOrTupleOfTensorsGeneric`. - feature_mask = _format_feature_mask(feature_mask, formatted_inputs) + formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs) assert ( isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 @@ -280,7 +280,7 @@ def attribute( if show_progress: attr_progress = self._attribute_progress_setup( formatted_inputs, - feature_mask, + formatted_feature_mask, **kwargs, perturbations_per_eval=perturbations_per_eval, ) @@ -289,7 +289,10 @@ def attribute( # Computes initial evaluation with all features, which is compared # to each ablated result. initial_eval: Union[Tensor, Future[Tensor]] = _run_forward( - self.forward_func, formatted_inputs, target, additional_forward_args + self.forward_func, + formatted_inputs, + target, + formatted_additional_forward_args, ) if show_progress: attr_progress.update() @@ -332,10 +335,10 @@ def attribute( ) in self._ith_input_ablation_generator( i, formatted_inputs, - additional_forward_args, + formatted_additional_forward_args, target, baselines, - feature_mask, + formatted_feature_mask, perturbations_per_eval, **kwargs, ): @@ -354,6 +357,10 @@ def attribute( if show_progress: attr_progress.update() + assert not isinstance(modified_eval, torch.Future), ( + "when use_futures is True, modified_eval should have " + f"non-Future type rather than {type(modified_eval)}" + ) total_attrib, weights = self._process_ablated_out( modified_eval, current_inputs, @@ -373,6 +380,11 @@ def attribute( if show_progress: attr_progress.close() + # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <: + # [Tensor, typing.Tuple[Tensor, ...]]]` + # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`. 
+ # pyre-fixme[6]: In call `FeatureAblation._generate_result`, + # for 3rd positional argument, expected `bool` but got `Literal[]`. return self._generate_result(total_attrib, weights, is_inputs_tuple) # type: ignore # noqa: E501 line too long @log_usage() @@ -399,11 +411,11 @@ def attribute_future( # `TensorOrTupleOfTensorsGeneric`. is_inputs_tuple = _is_tuple(inputs) formatted_inputs, baselines = _format_input_baseline(inputs, baselines) - additional_forward_args = _format_additional_forward_args( + formatted_additional_forward_args = _format_additional_forward_args( additional_forward_args ) num_examples = formatted_inputs[0].shape[0] - feature_mask = _format_feature_mask(feature_mask, formatted_inputs) + formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs) assert ( isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 @@ -412,7 +424,7 @@ def attribute_future( if show_progress: attr_progress = self._attribute_progress_setup( formatted_inputs, - feature_mask, + formatted_feature_mask, **kwargs, perturbations_per_eval=perturbations_per_eval, ) @@ -421,7 +433,10 @@ def attribute_future( # Computes initial evaluation with all features, which is compared # to each ablated result. initial_eval: Union[Tensor, Future[Tensor]] = _run_forward( - self.forward_func, formatted_inputs, target, additional_forward_args + self.forward_func, + formatted_inputs, + target, + formatted_additional_forward_args, ) if show_progress: @@ -438,16 +453,17 @@ def attribute_future( ) processed_initial_eval_fut = initial_eval.then( - lambda x: self._process_initial_eval( - x.value(), + lambda initial_eval: self._process_initial_eval( + initial_eval.value(), formatted_inputs, ) ) # The will be the same amount futures as modified_eval down there, # since we cannot add up the evaluation result adhoc under async mode. - # pyre-fixme[24]: Generic type `Future` expects 1 type parameter. - all_futures: List[List[Future]] = [[] for _ in range(len(inputs))] + all_modified_eval_futures: List[ + List[Future[Tuple[List[Tensor], List[Tensor]]]] + ] = [[] for _ in range(len(inputs))] # Iterate through each feature tensor for ablation for i in range(len(formatted_inputs)): # Skip any empty input tensors @@ -462,10 +478,10 @@ def attribute_future( ) in self._ith_input_ablation_generator( i, formatted_inputs, - additional_forward_args, + formatted_additional_forward_args, target, baselines, - feature_mask, + formatted_feature_mask, perturbations_per_eval, **kwargs, ): @@ -495,7 +511,23 @@ def attribute_future( ) # Need to collect both initial eval and modified_eval - eval_futs: Future[List[Future[Tensor]]] = collect_all( + eval_futs: Future[ + List[ + Future[ + Union[ + Tuple[ + List[Tensor], + List[Tensor], + Tensor, + Tensor, + int, + dtype, + ], + Tensor, + ] + ] + ] + ] = collect_all( [ processed_initial_eval_fut, modified_eval, @@ -528,12 +560,12 @@ def attribute_future( ) ) - all_futures[i].append(ablated_out_fut) + all_modified_eval_futures[i].append(ablated_out_fut) if show_progress: attr_progress.close() - return self._generate_async_result(all_futures, is_inputs_tuple) # type: ignore # noqa: E501 line too long + return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple) # type: ignore # noqa: E501 line too long # pyre-fixme[3] return type must be annotated def _attribute_progress_setup( @@ -555,26 +587,28 @@ def _attribute_progress_setup( ) return attr_progress - # pyre-fixme[3]: Return type must be annotated. 
+ # pyre-fixme[3]: Return type must be specified as type that does not contain `Any` def _ith_input_ablation_generator( self, - # pyre-fixme[2]: Parameter must be annotated. - i, - # pyre-fixme[2]: Parameter must be annotated. - inputs, - # pyre-fixme[2]: Parameter must be annotated. - additional_args, - # pyre-fixme[2]: Parameter must be annotated. - target, - # pyre-fixme[2]: Parameter must be annotated. - baselines, - # pyre-fixme[2]: Parameter must be annotated. - input_mask, - # pyre-fixme[2]: Parameter must be annotated. - perturbations_per_eval, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, - ): + i: int, + inputs: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + additional_args: Any, + target: TargetType, + baselines: BaselineType, + input_mask: Union[None, Tensor, Tuple[Tensor, ...]], + perturbations_per_eval: int, + **kwargs: Any, + ) -> Generator[ + Tuple[ + Tuple[Tensor, ...], + Any, + TargetType, + Tensor, + ], + None, + None, + ]: """ This method returns a generator of ablation perturbations of the i-th input @@ -590,9 +624,9 @@ def _ith_input_ablation_generator( else: extra_args[key] = value - input_mask = input_mask[i] if input_mask is not None else None - min_feature, num_features, input_mask = self._get_feature_range_and_mask( - inputs[i], input_mask, **extra_args + cur_input_mask = input_mask[i] if input_mask is not None else None + min_feature, num_features, cur_input_mask = self._get_feature_range_and_mask( + inputs[i], cur_input_mask, **extra_args ) num_examples = inputs[0].shape[0] perturbations_per_eval = min(perturbations_per_eval, num_features) @@ -654,12 +688,15 @@ def _ith_input_ablation_generator( # may not necessarilly be num_examples and will match the first # dimension of this tensor. current_reshaped = current_features[i].reshape( - (current_num_ablated_features, -1) + current_features[i].shape[1:] + (current_num_ablated_features, -1) + # pyre-fixme[58]: `+` is not supported for operand types + # `Tuple[int, int]` and `Size`. + + current_features[i].shape[1:] ) ablated_features, current_mask = self._construct_ablated_input( current_reshaped, - input_mask, + cur_input_mask, baseline, num_features_processed, num_features_processed + current_num_ablated_features, @@ -670,7 +707,10 @@ def _ith_input_ablation_generator( # (current_num_ablated_features * num_examples, inputs[i].shape[1:]), # which can be provided to the model as input. current_features[i] = ablated_features.reshape( - (-1,) + ablated_features.shape[2:] + (-1,) + # pyre-fixme[58]: `+` is not supported for operand types + # `Tuple[int]` and `Size`. + + ablated_features.shape[2:] ) yield tuple( current_features @@ -679,22 +719,15 @@ def _ith_input_ablation_generator( current_features[i] = original_tensor num_features_processed += current_num_ablated_features - # pyre-fixme[3]: Return type must be annotated. def _construct_ablated_input( self, - # pyre-fixme[2]: Parameter must be annotated. - expanded_input, - # pyre-fixme[2]: Parameter must be annotated. - input_mask, - # pyre-fixme[2]: Parameter must be annotated. - baseline, - # pyre-fixme[2]: Parameter must be annotated. - start_feature, - # pyre-fixme[2]: Parameter must be annotated. - end_feature, - # pyre-fixme[2]: Parameter must be annotated. 
- **kwargs, - ): + expanded_input: Tensor, + input_mask: Union[None, Tensor, Tuple[Tensor, ...]], + baseline: Union[None, float, Tensor], + start_feature: int, + end_feature: int, + **kwargs: Any, + ) -> Tuple[Tensor, Tensor]: r""" Ablates given expanded_input tensor with given feature mask, feature range, and baselines. expanded_input shape is (`num_features`, `num_examples`, ...) @@ -712,7 +745,10 @@ def _construct_ablated_input( thus counted towards ablations for that feature) and 0s otherwise. """ current_mask = torch.stack( - [input_mask == j for j in range(start_feature, end_feature)], dim=0 + # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor], + # Tuple[Tensor, ...]]` but got `List[Union[bool, Tensor]]`. + [input_mask == j for j in range(start_feature, end_feature)], # type: ignore # noqa: E501 line too long + dim=0, ).long() current_mask = current_mask.to(expanded_input.device) ablated_tensor = ( @@ -720,9 +756,12 @@ def _construct_ablated_input( ) + (baseline * current_mask.to(expanded_input.dtype)) return ablated_tensor, current_mask - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _get_feature_range_and_mask(self, input, input_mask, **kwargs): + def _get_feature_range_and_mask( + self, + input: Tensor, + input_mask: Optional[Tensor], + **kwargs: Any, + ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]: if input_mask is None: # Obtain feature mask for selected input tensor, matches size of # 1 input example, (1 x inputs[i].shape[1:]) @@ -731,14 +770,17 @@ def _get_feature_range_and_mask(self, input, input_mask, **kwargs): input[0:1].shape, ).long() return ( - torch.min(input_mask).item(), - torch.max(input_mask).item() + 1, + int(torch.min(input_mask).item()), + int(torch.max(input_mask).item() + 1), input_mask, ) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _get_feature_counts(self, inputs, feature_mask, **kwargs): + def _get_feature_counts( + self, + inputs: TensorOrTupleOfTensorsGeneric, + feature_mask: Tuple[Tensor, ...], + **kwargs: Any, + ) -> Tuple[float, ...]: """return the numbers of input features""" if not feature_mask: return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs) @@ -752,8 +794,7 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs): for inp, mask in zip(inputs, feature_mask) ) - # pyre-fixme[2]: Parameter must be annotated. - def _parse_forward_out(self, forward_output) -> Tensor: + def _parse_forward_out(self, forward_output: Tensor) -> Tensor: """ A temp wrapper for global _run_forward util to force forward output type assertion & conversion. @@ -819,32 +860,19 @@ def _process_initial_eval( def _process_ablated_out( self, - # pyre-fixme[2]: Parameter must be annotated. - modified_eval, - # pyre-fixme[2]: Parameter must be annotated. - current_inputs, - # pyre-fixme[2]: Parameter must be annotated. - current_mask, - # pyre-fixme[2]: Parameter must be annotated. - perturbations_per_eval, - # pyre-fixme[2]: Parameter must be annotated. - num_examples, - # pyre-fixme[2]: Parameter must be annotated. - initial_eval, - # pyre-fixme[2]: Parameter must be annotated. - flattened_initial_eval, - # pyre-fixme[2]: Parameter must be annotated. - inputs, - # pyre-fixme[2]: Parameter must be annotated. - n_outputs, - # pyre-fixme[2]: Parameter must be annotated. - total_attrib, - # pyre-fixme[2]: Parameter must be annotated. - weights, - # pyre-fixme[2]: Parameter must be annotated. 
- i, - # pyre-fixme[2]: Parameter must be annotated. - attrib_type, + modified_eval: Tensor, + current_inputs: Tuple[Tensor, ...], + current_mask: Tensor, + perturbations_per_eval: int, + num_examples: int, + initial_eval: Tensor, + flattened_initial_eval: Tensor, + inputs: TensorOrTupleOfTensorsGeneric, + n_outputs: int, + total_attrib: List[Tensor], + weights: List[Tensor], + i: int, + attrib_type: dtype, ) -> Tuple[List[Tensor], List[Tensor]]: modified_eval = self._parse_forward_out(modified_eval) @@ -903,16 +931,19 @@ def _generate_async_result( ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]: # Each element of the 2d list contains evalutaion results for a feature # Need to add up all the results for each input - # pyre-fixme[24]: Generic type `Future` expects 1 type parameter. - accumulate_fut_list: List[Future] = [] + accumulate_fut_list: List[Future[None]] = [] total_attrib: List[Tensor] = [] weights: List[Tensor] = [] for i, fut_tuples in enumerate(futs): for fut_tuple in fut_tuples: accumulate_fut_list.append( fut_tuple.then( - lambda x, i=i: self._accumulate_for_single_input( # type: ignore # noqa: E501 line too long - total_attrib, weights, i, x.value()[0], x.value()[1] + lambda fut_tuple, i=i: self._accumulate_for_single_input( # type: ignore # noqa: E501 line too long + total_attrib, + weights, + i, + fut_tuple.value()[0], # attrib + fut_tuple.value()[1], # weight ) ) ) diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 7db6ee27a0..0f5be93814 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -307,8 +307,8 @@ def attribute_future( def _construct_ablated_input( self, expanded_input: Tensor, - input_mask: Tensor, - baseline: Union[int, float, Tensor], + input_mask: Union[None, Tensor, Tuple[Tensor, ...]], + baseline: Union[None, float, Tensor], start_feature: int, end_feature: int, **kwargs: Any, @@ -327,7 +327,11 @@ def _construct_ablated_input( Since `baselines` is set to None for `FeatureAblation.attribute, this will be the zero tensor, however, it is not used. """ - assert input_mask.shape[0] == 1, ( + assert ( + input_mask is not None + and not isinstance(input_mask, tuple) + and input_mask.shape[0] == 1 + ), ( "input_mask.shape[0] != 1: pass in one mask in order to permute" "the same features for each input" ) diff --git a/captum/attr/_core/gradient_shap.py b/captum/attr/_core/gradient_shap.py index ba6335e4f8..feb1621730 100644 --- a/captum/attr/_core/gradient_shap.py +++ b/captum/attr/_core/gradient_shap.py @@ -301,10 +301,8 @@ def attribute( return attributions - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for GradientShap. """ @@ -316,15 +314,13 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs class InputBaselineXGradient(GradientAttribution): # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - # pyre-fixme[2]: Parameter must be annotated. 
- def __init__(self, forward_func: Callable, multiply_by_inputs=True) -> None: + def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None: r""" Args: @@ -437,10 +433,8 @@ def attribute( # type: ignore is_inputs_tuple, ) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for InputBaseLineXGradient. """ @@ -452,8 +446,7 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/guided_backprop_deconvnet.py b/captum/attr/_core/guided_backprop_deconvnet.py index f6ceb916e6..071aea3981 100644 --- a/captum/attr/_core/guided_backprop_deconvnet.py +++ b/captum/attr/_core/guided_backprop_deconvnet.py @@ -2,7 +2,7 @@ # pyre-strict import warnings -from typing import Any, List, Tuple, Union +from typing import Any, Callable, List, Tuple, Union import torch import torch.nn.functional as F @@ -90,9 +90,8 @@ def attribute( # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, gradients) - def attribute_future( - self, - ) -> None: + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for ModifiedReluGradientAttribution. """ @@ -100,29 +99,26 @@ def attribute_future( "attribute_future is not implemented for ModifiedReluGradientAttribution" ) - # pyre-fixme[3]: Return type must be annotated. - def _register_hooks(self, module: Module): + def _register_hooks(self, module: Module) -> None: if isinstance(module, torch.nn.ReLU): hooks = _register_backward_hook(module, self._backward_hook, self) self.backward_hooks.extend(hooks) - # pyre-fixme[3]: Return type must be annotated. def _backward_hook( self, module: Module, grad_input: Union[Tensor, Tuple[Tensor, ...]], grad_output: Union[Tensor, Tuple[Tensor, ...]], - ): + ) -> Union[Tuple[Tensor], Tensor]: to_override_grads = grad_output if self.use_relu_grad_output else grad_input if isinstance(to_override_grads, tuple): return tuple( - F.relu(to_override_grad) for to_override_grad in to_override_grads + F.relu(to_override_grad) for to_override_grad in to_override_grads # type: ignore # noqa: E501 line too long ) else: return F.relu(to_override_grads) - # pyre-fixme[3]: Return type must be annotated. - def _remove_hooks(self): + def _remove_hooks(self) -> None: for hook in self.backward_hooks: hook.remove() diff --git a/captum/attr/_core/input_x_gradient.py b/captum/attr/_core/input_x_gradient.py index 0e2ab02892..86115bb03b 100644 --- a/captum/attr/_core/input_x_gradient.py +++ b/captum/attr/_core/input_x_gradient.py @@ -139,10 +139,8 @@ def attribute( # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for InputXGradient. """ @@ -151,6 +149,5 @@ def attribute_future( ) @property - # pyre-fixme[3]: Return type must be annotated. 
- def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return True diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py index 2badbeb257..730cfd48b4 100644 --- a/captum/attr/_core/integrated_gradients.py +++ b/captum/attr/_core/integrated_gradients.py @@ -329,10 +329,8 @@ def attribute( # type: ignore # <: [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for IntegratedGradients. """ @@ -423,6 +421,5 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/kernel_shap.py b/captum/attr/_core/kernel_shap.py index 3ce3011d25..8b6fb44cbf 100644 --- a/captum/attr/_core/kernel_shap.py +++ b/captum/attr/_core/kernel_shap.py @@ -279,7 +279,10 @@ def attribute( # type: ignore ) num_features_list = torch.arange(num_interp_features, dtype=torch.float) denom = num_features_list * (num_interp_features - num_features_list) + # pyre-fixme[58]: `/` is not supported for operand types + # `int` and `torch._tensor.Tensor`. probs = (num_interp_features - 1) / denom + # pyre-fixme[16]: `float` has no attribute `__setitem__`. probs[0] = 0.0 return self._attribute_kwargs( inputs=inputs, @@ -294,10 +297,8 @@ def attribute( # type: ignore show_progress=show_progress, ) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for KernelShap. """ diff --git a/captum/attr/_core/layer/layer_activation.py b/captum/attr/_core/layer/layer_activation.py index 99f1e951d6..bb4b8056ba 100644 --- a/captum/attr/_core/layer/layer_activation.py +++ b/captum/attr/_core/layer/layer_activation.py @@ -138,6 +138,5 @@ def attribute( ] @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return True diff --git a/captum/attr/_core/layer/layer_conductance.py b/captum/attr/_core/layer/layer_conductance.py index f353dbe32c..dc74a76c93 100644 --- a/captum/attr/_core/layer/layer_conductance.py +++ b/captum/attr/_core/layer/layer_conductance.py @@ -420,6 +420,5 @@ def _attribute( return _format_output(len(attributions) > 1, attributions) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return True diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py index 4128570abe..50b8dc0b33 100644 --- a/captum/attr/_core/layer/layer_deep_lift.py +++ b/captum/attr/_core/layer/layer_deep_lift.py @@ -387,8 +387,7 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence: ) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @@ -719,5 +718,5 @@ def attribute( @property # pyre-fixme[3]: Return type must be annotated. 
- def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/layer/layer_feature_ablation.py b/captum/attr/_core/layer/layer_feature_ablation.py index 6759a1b186..3adfc17687 100644 --- a/captum/attr/_core/layer/layer_feature_ablation.py +++ b/captum/attr/_core/layer/layer_feature_ablation.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Tuple, Type, Union import torch from captum._utils.common import ( @@ -311,6 +311,5 @@ def forward_hook(module, inp, out=None): return _attr @property - # pyre-fixme[3]: Return type must be annotated. - def attributor(self): + def attributor(self) -> Type[FeatureAblation]: return FeatureAblation diff --git a/captum/attr/_core/layer/layer_feature_permutation.py b/captum/attr/_core/layer/layer_feature_permutation.py index 89bae65f83..f3e185cef3 100644 --- a/captum/attr/_core/layer/layer_feature_permutation.py +++ b/captum/attr/_core/layer/layer_feature_permutation.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, cast, List, Tuple, Union +from typing import Any, Callable, cast, List, Tuple, Type, Union import torch from captum._utils.common import ( @@ -245,6 +245,7 @@ def forward_hook(module, inp, out=None): return _attr @property - # pyre-fixme[3]: Return type must be annotated. - def attributor(self): + def attributor( + self, + ) -> Type[FeaturePermutation]: return FeaturePermutation diff --git a/captum/attr/_core/layer/layer_gradient_shap.py b/captum/attr/_core/layer/layer_gradient_shap.py index e02dd4cf63..171f1e9ad4 100644 --- a/captum/attr/_core/layer/layer_gradient_shap.py +++ b/captum/attr/_core/layer/layer_gradient_shap.py @@ -340,8 +340,7 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @@ -508,6 +507,5 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/layer/layer_gradient_x_activation.py b/captum/attr/_core/layer/layer_gradient_x_activation.py index f8683ad41a..ebbb836555 100644 --- a/captum/attr/_core/layer/layer_gradient_x_activation.py +++ b/captum/attr/_core/layer/layer_gradient_x_activation.py @@ -69,8 +69,7 @@ def __init__( self._multiply_by_inputs = multiply_by_inputs @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @log_usage() diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py index fb63df0ffb..1b10404f2d 100644 --- a/captum/attr/_core/layer/layer_integrated_gradients.py +++ b/captum/attr/_core/layer/layer_integrated_gradients.py @@ -569,6 +569,5 @@ def has_convergence_delta(self) -> bool: return True @property - # pyre-fixme[3]: Return type must be annotated. 
- def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self.ig.multiplies_by_inputs diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 11c4ca25f4..f579a531dc 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -539,10 +539,8 @@ def attribute( self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) return self.interpretable_model.representation() - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for LimeBase. """ @@ -550,7 +548,6 @@ def attribute_future( "LimeBase does not support attribution of future samples." ) - # pyre-fixme[3]: Return type must be annotated. def _evaluate_batch( self, curr_model_inputs: List[TensorOrTupleOfTensorsGeneric], @@ -558,7 +555,7 @@ def _evaluate_batch( # pyre-fixme[2]: Parameter annotation cannot be `Any`. expanded_additional_args: Any, device: torch.device, - ): + ) -> Tensor: model_out = _run_forward( self.forward_func, # pyre-fixme[6]: For 1st argument expected `Sequence[Variable[TupleOrTens... @@ -579,8 +576,7 @@ def has_convergence_delta(self) -> bool: return False @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return False @@ -668,9 +664,8 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs): return default_exp_kernel -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def default_perturb_func(original_inp, **kwargs): +def default_perturb_func(original_inp, **kwargs) -> Tensor: assert ( "num_interp_features" in kwargs ), "Must provide num_interp_features to use default interpretable sampling function" @@ -683,9 +678,10 @@ def default_perturb_func(original_inp, **kwargs): return torch.bernoulli(probs).to(device=device).long() -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def construct_feature_mask(feature_mask, formatted_inputs): +def construct_feature_mask( + feature_mask: Union[None, Tensor, Tuple[Tensor, ...]], + formatted_inputs: Tuple[Tensor, ...], +) -> Tuple[Tuple[Tensor, ...], int]: if feature_mask is None: feature_mask, num_interp_features = _construct_default_feature_mask( formatted_inputs @@ -1111,8 +1107,8 @@ def attribute( # type: ignore show_progress=show_progress, ) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future(self): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: return super().attribute_future() def _attribute_kwargs( # type: ignore diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index ce904469c1..d01f1a661f 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -1,7 +1,7 @@ # pyre-strict from copy import copy -from typing import Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast, Dict, List, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -31,27 +31,24 @@ class LLMAttributionResult: It also provides utilities to help present and plot the result in different forms. """ - # pyre-fixme[3]: Return type must be annotated. 
def __init__( self, seq_attr: Tensor, - token_attr: Union[Tensor, None], + token_attr: Optional[Tensor], input_tokens: List[str], output_tokens: List[str], - ): + ) -> None: self.seq_attr = seq_attr self.token_attr = token_attr self.input_tokens = input_tokens self.output_tokens = output_tokens @property - # pyre-fixme[3]: Return type must be annotated. - def seq_attr_dict(self): + def seq_attr_dict(self) -> Dict[str, Any]: return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)} # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def plot_token_attr(self, show=False): + def plot_token_attr(self, show: bool = False): """ Generate a matplotlib plot for visualising the attribution of the output tokens. @@ -62,7 +59,7 @@ def plot_token_attr(self, show=False): """ # pyre-fixme[16]: `Optional` has no attribute `cpu`. - token_attr = self.token_attr.cpu() + token_attr = self.token_attr.cpu() # type: ignore # maximum absolute attribution value # used as the boundary of normalization @@ -86,7 +83,7 @@ def plot_token_attr(self, show=False): ) # Create colorbar - cbar = ax.figure.colorbar(im, ax=ax) + cbar = ax.figure.colorbar(im, ax=ax) # type: ignore cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom") # Show all ticks and label them with the respective list entries. @@ -120,8 +117,7 @@ def plot_token_attr(self, show=False): return fig, ax # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def plot_seq_attr(self, show=False): + def plot_seq_attr(self, show: bool = False): """ Generate a matplotlib plot for visualising the attribution of the output sequence. @@ -183,14 +179,13 @@ class LLMAttribution(Attribution): ) SUPPORTED_INPUTS = (TextTemplateInput, TextTokenInput) - # pyre-fixme[3]: Return type must be annotated. def __init__( self, attr_method: Attribution, # pyre-fixme[2]: Parameter must be annotated. tokenizer, attr_target: str = "log_prob", # TODO: support callable attr_target - ): + ) -> None: """ Args: attr_method (Attribution): Instance of a supported perturbation attribution @@ -242,7 +237,6 @@ class created with the llm model that follows huggingface style ), "attr_target should be either 'log_prob' or 'prob'" self.attr_target = attr_target - # pyre-fixme[3]: Return type must be annotated. def _forward_func( self, # pyre-fixme[2]: Parameter must be annotated. @@ -251,10 +245,9 @@ def _forward_func( inp, # pyre-fixme[2]: Parameter must be annotated. target_tokens, - # pyre-fixme[2]: Parameter must be annotated. - use_cached_outputs=False, + use_cached_outputs: bool = False, _inspect_forward=None, - ): + ) -> Union[int, Tensor]: perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor)) init_model_inp = perturbed_input @@ -291,10 +284,10 @@ def _forward_func( # add a leading dim for batch even we only support single instance for now if self.include_per_token_attr: target_log_probs = torch.stack( - [total_log_prob, *log_prob_list], dim=0 + [total_log_prob, *log_prob_list], dim=0 # type: ignore ).unsqueeze(0) else: - target_log_probs = total_log_prob + target_log_probs = total_log_prob # type: ignore # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int, # Tensor]`. target_probs = torch.exp(target_log_probs) @@ -435,10 +428,8 @@ def attribute( self.tokenizer.convert_ids_to_tokens(target_tokens), ) - # pyre-fixme[3]: Return type must be annotated. 
- def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for LLMAttribution. """ @@ -503,14 +494,13 @@ class created with the llm model that follows huggingface style else next(self.model.parameters()).device ) - # pyre-fixme[3]: Return type must be annotated. def _forward_func( self, perturbed_tensor: Tensor, inp: InterpretableInput, target_tokens: Tensor, # 1D tensor of target token ids cur_target_idx: int, # current target index - ): + ) -> Tensor: perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor)) if cur_target_idx: @@ -544,7 +534,6 @@ def _format_model_input(self, model_input): """ return model_input.to(self.device) - # pyre-fixme[3]: Return type must be annotated. def attribute( self, inp: InterpretableInput, @@ -555,7 +544,7 @@ def attribute( gen_args: Optional[Dict] = None, # pyre-fixme[2]: Parameter must be annotated. **kwargs, - ): + ) -> LLMAttributionResult: """ Args: inp (InterpretableInput): input prompt for which attributions are computed @@ -657,10 +646,8 @@ def attribute( self.tokenizer.convert_ids_to_tokens(target_tokens), ) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for LLMGradientAttribution. """ diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 8f7d9b6319..03772d7aae 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -4,7 +4,7 @@ import typing from collections import defaultdict -from typing import Any, cast, List, Tuple, Union +from typing import Any, Callable, cast, List, Tuple, Union import torch.nn as nn from captum._utils.common import ( @@ -257,10 +257,8 @@ def attribute( else: return _format_output(is_inputs_tuple, relevances) # type: ignore - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for LRP. """ diff --git a/captum/attr/_core/neuron/neuron_conductance.py b/captum/attr/_core/neuron/neuron_conductance.py index 34e66c2223..dcbc3ecebd 100644 --- a/captum/attr/_core/neuron/neuron_conductance.py +++ b/captum/attr/_core/neuron/neuron_conductance.py @@ -428,6 +428,5 @@ def _attribute( return attributions @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_deep_lift.py b/captum/attr/_core/neuron/neuron_deep_lift.py index a89c1a0aab..da70083727 100644 --- a/captum/attr/_core/neuron/neuron_deep_lift.py +++ b/captum/attr/_core/neuron/neuron_deep_lift.py @@ -247,8 +247,7 @@ def attribute( ) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs @@ -475,6 +474,5 @@ def attribute( ) @property - # pyre-fixme[3]: Return type must be annotated. 
- def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_gradient_shap.py b/captum/attr/_core/neuron/neuron_gradient_shap.py index 816ecb69e3..897dea197e 100644 --- a/captum/attr/_core/neuron/neuron_gradient_shap.py +++ b/captum/attr/_core/neuron/neuron_gradient_shap.py @@ -258,6 +258,5 @@ def attribute( ) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/neuron/neuron_integrated_gradients.py b/captum/attr/_core/neuron/neuron_integrated_gradients.py index aebf06abce..3bef9caa0f 100644 --- a/captum/attr/_core/neuron/neuron_integrated_gradients.py +++ b/captum/attr/_core/neuron/neuron_integrated_gradients.py @@ -253,6 +253,5 @@ def attribute( ) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return self._multiply_by_inputs diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index 3a8b3fb3d7..eb34eda850 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Any, cast, List, Optional, Tuple, Union +from typing import Any, Callable, cast, List, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -460,9 +460,8 @@ def update_partial_attribution_and_delta( delta, ) - def attribute_future( - self, - ) -> None: + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for NoiseTunnel. """ diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index dc36340899..298a95f8d6 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import numpy as np import torch @@ -271,10 +271,8 @@ def attribute( # type: ignore show_progress=show_progress, ) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for Occlusion. """ @@ -283,8 +281,8 @@ def attribute_future( def _construct_ablated_input( self, expanded_input: Tensor, - input_mask: Union[None, Tensor], - baseline: Union[Tensor, int, float], + input_mask: Union[None, Tensor, Tuple[Tensor, ...]], + baseline: Union[None, float, Tensor], start_feature: int, end_feature: int, **kwargs: Any, @@ -384,13 +382,17 @@ def _occlusion_mask( return padded_tensor.reshape((1,) + padded_tensor.shape) def _get_feature_range_and_mask( - self, input: Tensor, input_mask: Tensor, **kwargs: Any - ) -> Tuple[int, int, None]: + self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any + ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]: feature_max = np.prod(kwargs["shift_counts"]) return 0, feature_max, None - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _get_feature_counts(self, inputs, feature_mask, **kwargs): + def _get_feature_counts( + self, + # pyre-fixme[2]: Parameter must be annotated. 
+ inputs, + feature_mask: Tuple[Tensor, ...], + **kwargs: Any, + ) -> Tuple[int, ...]: """return the numbers of possible input features""" return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"]) diff --git a/captum/attr/_core/saliency.py b/captum/attr/_core/saliency.py index afacfdbdef..29205725c0 100644 --- a/captum/attr/_core/saliency.py +++ b/captum/attr/_core/saliency.py @@ -151,10 +151,8 @@ def attribute( # `Tuple[Tensor, ...]`. return _format_output(is_inputs_tuple, attributions) - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for Saliency. """ diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 5a045e8307..734d58fa9d 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -485,10 +485,8 @@ def attribute( # `Tuple[Tensor, ...]`. return formatted_attr - # pyre-fixme[3]: Return type must be annotated. - def attribute_future( - self, - ): + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. + def attribute_future(self) -> Callable: r""" This method is not implemented for ShapleyValueSampling. """ @@ -874,9 +872,9 @@ def attribute( show_progress=show_progress, ) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval): + def _get_n_evaluations( + self, total_features: int, n_samples: int, perturbations_per_eval: int + ) -> int: """return the total number of forward evaluations needed""" return math.ceil(total_features / perturbations_per_eval) * math.factorial( total_features diff --git a/captum/attr/_models/base.py b/captum/attr/_models/base.py index cb4964f514..ac3eecbd51 100644 --- a/captum/attr/_models/base.py +++ b/captum/attr/_models/base.py @@ -142,9 +142,8 @@ def _get_deep_layer_name(obj, layer_names): return reduce(getattr, layer_names.split("."), obj) -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def _set_deep_layer_value(obj, layer_names, value): +def _set_deep_layer_value(obj, layer_names, value) -> None: r""" Traverses through the layer names that are separated by dot in order to access the embedding layer and update its value. diff --git a/captum/attr/_models/pytext.py b/captum/attr/_models/pytext.py index f0f27fd2e7..ff94f86f5d 100644 --- a/captum/attr/_models/pytext.py +++ b/captum/attr/_models/pytext.py @@ -2,6 +2,7 @@ # pyre-strict from collections import defaultdict +from typing import Tuple import torch from pytext.models.embeddings.dict_embedding import DictEmbedding @@ -147,9 +148,8 @@ def generate_baseline(self, integ_grads_embeddings, seq_length): ) return tuple(baseline) - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def _generate_baseline_single_word_feature(self, device): + def _generate_baseline_single_word_feature(self, device) -> torch.Tensor: return ( torch.tensor( [self.vocab_word.stoi[self.PAD] if hasattr(self, "vocab_word") else 0] @@ -158,9 +158,11 @@ def _generate_baseline_single_word_feature(self, device): .to(device) ) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def _generate_baseline_single_dict_feature(self, device): + def _generate_baseline_single_dict_feature( + self, + # pyre-fixme[2]: Parameter `device` has no type specified. + device, + ) -> Tuple[torch.Tensor, ...]: r"""Generate dict features based on Assistant's case study by using sia_transformer: fbcode/assistant/sia/transformer/sia_transformer.py @@ -247,9 +249,8 @@ def configure_task_integ_grads_embeddings(task): return integrated_gradients_embedding_lst[0] -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def configure_model_integ_grads_embeddings(model): +def configure_model_integ_grads_embeddings(model) -> EmbeddingList: r""" Wraps Pytext's DocNN model embedding with `IntegratedGradientsEmbedding` IntegratedGradientsEmbedding allows to perform baseline related operations diff --git a/captum/attr/_utils/approximation_methods.py b/captum/attr/_utils/approximation_methods.py index 318578277f..8debc95540 100644 --- a/captum/attr/_utils/approximation_methods.py +++ b/captum/attr/_utils/approximation_methods.py @@ -21,8 +21,7 @@ class Riemann(Enum): "riemann_trapezoid", ] -# pyre-fixme[5]: Global expression must be annotated. -SUPPORTED_METHODS = SUPPORTED_RIEMANN_METHODS + ["gausslegendre"] +SUPPORTED_METHODS: List[str] = SUPPORTED_RIEMANN_METHODS + ["gausslegendre"] def approximation_parameters( diff --git a/captum/attr/_utils/attribution.py b/captum/attr/_utils/attribution.py index 12d36e799b..c8a628cdc3 100644 --- a/captum/attr/_utils/attribution.py +++ b/captum/attr/_utils/attribution.py @@ -73,7 +73,7 @@ def __init__(self, forward_func: Callable) -> None: """ - # pyre-fixme[24] Callable needs 2 type parameters + # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. attribute_future: Callable r""" @@ -105,8 +105,7 @@ def __init__(self, forward_func: Callable) -> None: """ @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return False def has_convergence_delta(self) -> bool: @@ -361,8 +360,7 @@ def __init__(self, forward_func: Callable) -> None: Attribution.__init__(self, forward_func) @property - # pyre-fixme[3]: Return type must be annotated. - def multiplies_by_inputs(self): + def multiplies_by_inputs(self) -> bool: return True diff --git a/captum/attr/_utils/baselines.py b/captum/attr/_utils/baselines.py index 3f88efc5d1..5b347cb31c 100644 --- a/captum/attr/_utils/baselines.py +++ b/captum/attr/_utils/baselines.py @@ -20,7 +20,6 @@ class ProductBaselines: the corresponding values. """ - # pyre-fixme[3]: Return type must be annotated. def __init__( self, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. @@ -28,7 +27,7 @@ def __init__( List[List[Any]], Dict[Union[str, Tuple[str, ...]], List[Any]], ], - ): + ) -> None: if isinstance(baseline_values, dict): dict_keys = list(baseline_values.keys()) baseline_values = [baseline_values[k] for k in dict_keys] diff --git a/captum/attr/_utils/batching.py b/captum/attr/_utils/batching.py index 641314dc85..7ea76a0254 100644 --- a/captum/attr/_utils/batching.py +++ b/captum/attr/_utils/batching.py @@ -29,12 +29,9 @@ def _batch_attribution( num_examples, # pyre-fixme[2]: Parameter must be annotated. internal_batch_size, - # pyre-fixme[2]: Parameter must be annotated. - n_steps, - # pyre-fixme[2]: Parameter must be annotated. - include_endpoint=False, - # pyre-fixme[2]: Parameter must be annotated. 
- **kwargs, + n_steps: int, + include_endpoint: bool = False, + **kwargs: Any, ): """ This method applies internal batching to given attribution method, dividing diff --git a/captum/attr/_utils/class_summarizer.py b/captum/attr/_utils/class_summarizer.py index 085ba76148..5fe1deab35 100644 --- a/captum/attr/_utils/class_summarizer.py +++ b/captum/attr/_utils/class_summarizer.py @@ -32,7 +32,7 @@ def update( # type: ignore self, x: TensorOrTupleOfTensorsGeneric, labels: TargetType = None, - ): + ) -> None: r""" Updates the stats of the summarizer, optionally associated to classes. diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index fe4e995fc6..727781c171 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -196,7 +196,6 @@ class TextTemplateInput(InterpretableInput): """ - # pyre-fixme[3]: Return type must be annotated. def __init__( self, # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. @@ -205,7 +204,7 @@ def __init__( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. baselines: Union[List[str], Dict[str, str], Callable, None] = None, mask: Union[List[int], Dict[str, int], None] = None, - ): + ) -> None: # convert values dict to list if isinstance(values, dict): dict_keys = list(values.keys()) @@ -391,7 +390,6 @@ class TextTokenInput(InterpretableInput): """ - # pyre-fixme[3]: Return type must be annotated. def __init__( self, text: str, @@ -399,7 +397,7 @@ def __init__( tokenizer, baselines: Union[int, str] = 0, # usually UNK skip_tokens: Union[List[int], List[str], None] = None, - ): + ) -> None: inp_tensor = tokenizer.encode(text, return_tensors="pt") # input tensor into the model of token ids diff --git a/captum/attr/_utils/lrp_rules.py b/captum/attr/_utils/lrp_rules.py index a638aba17b..2dd8dc4fe8 100644 --- a/captum/attr/_utils/lrp_rules.py +++ b/captum/attr/_utils/lrp_rules.py @@ -77,8 +77,7 @@ def _backward_hook_input(grad): return _backward_hook_input # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _create_backward_hook_output(self, outputs): + def _create_backward_hook_output(self, outputs: torch.Tensor): # pyre-fixme[53]: Captured variable `outputs` is not annotated. # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. @@ -92,9 +91,8 @@ def _backward_hook_output(grad): return _backward_hook_output - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def forward_hook_weights(self, module, inputs, outputs): + def forward_hook_weights(self, module, inputs, outputs) -> None: """Save initial activations a_j before modules are changed""" device = inputs[0].device if isinstance(inputs, tuple) else inputs.device if hasattr(module, "activations") and device in module.activations: @@ -135,14 +133,12 @@ class EpsilonRule(PropagationRule): discriminator during propagation. """ - # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, epsilon=1e-9) -> None: + def __init__(self, epsilon: float = 1e-9) -> None: # pyre-fixme[4]: Attribute must be annotated. self.STABILITY_FACTOR = epsilon - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def _manipulate_weights(self, module, inputs, outputs): + def _manipulate_weights(self, module, inputs, outputs) -> None: pass @@ -158,16 +154,14 @@ class GammaRule(PropagationRule): the positive relevance is increased. 
""" - # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, gamma=0.25, set_bias_to_zero=False) -> None: + def __init__(self, gamma: float = 0.25, set_bias_to_zero: bool = False) -> None: # pyre-fixme[4]: Attribute must be annotated. self.gamma = gamma # pyre-fixme[4]: Attribute must be annotated. self.set_bias_to_zero = set_bias_to_zero - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def _manipulate_weights(self, module, inputs, outputs): + def _manipulate_weights(self, module, inputs, outputs) -> None: if hasattr(module, "weight"): module.weight.data = ( module.weight.data + self.gamma * module.weight.data.clamp(min=0) @@ -189,14 +183,12 @@ class Alpha1_Beta0_Rule(PropagationRule): Use for lower layers. """ - # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, set_bias_to_zero=False) -> None: + def __init__(self, set_bias_to_zero: bool = False) -> None: # pyre-fixme[4]: Attribute must be annotated. self.set_bias_to_zero = set_bias_to_zero - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def _manipulate_weights(self, module, inputs, outputs): + def _manipulate_weights(self, module, inputs, outputs) -> None: if hasattr(module, "weight"): module.weight.data = module.weight.data.clamp(min=0) if self.set_bias_to_zero and hasattr(module, "bias"): diff --git a/captum/attr/_utils/stat.py b/captum/attr/_utils/stat.py index 70bfe47c7c..f369e61a25 100644 --- a/captum/attr/_utils/stat.py +++ b/captum/attr/_utils/stat.py @@ -37,8 +37,7 @@ def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None: self._other_stats: Optional[SummarizerSingleTensor] = None - # pyre-fixme[3]: Return type must be annotated. - def init(self): + def init(self) -> None: pass def _get_stat(self, stat: "Stat") -> Optional["Stat"]: @@ -52,8 +51,7 @@ def update(self, x: Tensor): def get(self) -> Optional[Tensor]: raise NotImplementedError() - # pyre-fixme[3]: Return type must be annotated. - def __hash__(self): + def __hash__(self) -> int: return hash((self.__class__, frozenset(self.params.items()))) def __eq__(self, other: object) -> bool: @@ -68,8 +66,7 @@ def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property - # pyre-fixme[3]: Return type must be annotated. - def name(self): + def name(self) -> str: """ The name of the statistic. i.e. it is the key in a .summary @@ -92,16 +89,15 @@ class Count(Stat): def __init__(self, name: Optional[str] = None) -> None: super().__init__(name=name) - # pyre-fixme[4]: Attribute must be annotated. - self.n = None + self.n: Optional[int] = None - # pyre-fixme[3]: Return type must be annotated. - def get(self): + # pyre-fixme[15]: `captum.attr._utils.stat.Count.get` overrides method defined + # in `Stat` inconsistently. Returned type `Optional[int]` is not a subtype of + # the overridden return `Optional[torch._tensor.Tensor]`. + def get(self) -> Optional[int]: # type: ignore return self.n - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def update(self, x): + def update(self, x: Tensor) -> None: if self.n is None: self.n = 0 self.n += 1 @@ -120,16 +116,13 @@ def __init__(self, name: Optional[str] = None) -> None: def get(self) -> Optional[Tensor]: return self.rolling_mean - # pyre-fixme[3]: Return type must be annotated. - def init(self): + def init(self) -> None: # pyre-fixme[8]: Attribute has type `Optional[Count]`; used as `Optional[Stat]`. 
- self.n = self._get_stat(Count()) + self.n = self._get_stat(Count()) # type: ignore - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def update(self, x): + def update(self, x: Tensor) -> None: # pyre-fixme[16]: `Optional` has no attribute `get`. - n = self.n.get() + n = self.n.get() # type: ignore if self.rolling_mean is None: # Ensures rolling_mean is a float tensor @@ -152,8 +145,7 @@ def __init__(self, name: Optional[str] = None) -> None: # pyre-fixme[4]: Attribute must be annotated. self.mse = None - # pyre-fixme[3]: Return type must be annotated. - def init(self): + def init(self) -> None: # pyre-fixme[16]: `MSE` has no attribute `mean`. self.mean = self._get_stat(Mean()) @@ -162,10 +154,9 @@ def get(self) -> Optional[Tensor]: return torch.zeros_like(self.prev_mean) return self.mse - # pyre-fixme[3]: Return type must be annotated. - def update(self, x: Tensor): + def update(self, x: Tensor) -> None: # pyre-fixme[16]: `MSE` has no attribute `mean`. - mean = self.mean.get() + mean = self.mean.get() # type: ignore if mean is not None and self.prev_mean is not None: rhs = (x - self.prev_mean) * (x - mean) @@ -175,7 +166,7 @@ def update(self, x: Tensor): self.mse += rhs # do not not clone - self.prev_mean = mean.clone() + self.prev_mean = mean.clone() # type: ignore class Var(Stat): @@ -198,33 +189,31 @@ def __init__(self, name: Optional[str] = None, order: int = 0) -> None: super().__init__(name=name, order=order) self.order = order - # pyre-fixme[3]: Return type must be annotated. - def init(self): + def init(self) -> None: # pyre-fixme[16]: `Var` has no attribute `mse`. self.mse = self._get_stat(MSE()) # pyre-fixme[16]: `Var` has no attribute `n`. self.n = self._get_stat(Count()) - # pyre-fixme[3]: Return type must be annotated. - def update(self, x: Tensor): + def update(self, x: Tensor) -> None: pass def get(self) -> Optional[Tensor]: # pyre-fixme[16]: `Var` has no attribute `mse`. - mse = self.mse.get() + mse = self.mse.get() # type: ignore # pyre-fixme[16]: `Var` has no attribute `n`. - n = self.n.get() + n = self.n.get() # type: ignore if mse is None: return None - if n <= self.order: + if n <= self.order: # type: ignore return torch.zeros_like(mse) # NOTE: The following ensures mse is a float tensor. # torch.true_divide is available in PyTorch 1.5 and later. # This is for compatibility with 1.4. - return mse.to(torch.float64) / (n - self.order) + return mse.to(torch.float64) / (n - self.order) # type: ignore class StdDev(Stat): @@ -244,18 +233,16 @@ def __init__(self, name: Optional[str] = None, order: int = 0) -> None: super().__init__(name=name, order=order) self.order = order - # pyre-fixme[3]: Return type must be annotated. - def init(self): + def init(self) -> None: # pyre-fixme[16]: `StdDev` has no attribute `var`. self.var = self._get_stat(Var(order=self.order)) - # pyre-fixme[3]: Return type must be annotated. - def update(self, x: Tensor): + def update(self, x: Tensor) -> None: pass def get(self) -> Optional[Tensor]: # pyre-fixme[16]: `StdDev` has no attribute `var`. - var = self.var.get() + var = self.var.get() # type: ignore return var**0.5 if var is not None else None @@ -268,16 +255,13 @@ class GeneralAccumFn(Stat): # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. def __init__(self, fn: Callable, name: Optional[str] = None) -> None: super().__init__(name=name) - # pyre-fixme[4]: Attribute must be annotated. 
- self.result = None + self.result: Optional[Tensor] = None self.fn = fn def get(self) -> Optional[Tensor]: return self.result - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def update(self, x): + def update(self, x: Tensor) -> None: if self.result is None: self.result = x else: diff --git a/captum/attr/_utils/summarizer.py b/captum/attr/_utils/summarizer.py index 8a0944f70b..148b19787a 100644 --- a/captum/attr/_utils/summarizer.py +++ b/captum/attr/_utils/summarizer.py @@ -46,8 +46,7 @@ def _copy_stats(self): return copy.deepcopy(self._stats) - # pyre-fixme[3]: Return type must be annotated. - def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]): + def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]) -> None: r""" Calls `update` on each `Stat` object within the summarizer @@ -206,8 +205,7 @@ def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None: stat._other_stats = self stat.init() - # pyre-fixme[3]: Return type must be annotated. - def update(self, x: Tensor): + def update(self, x: Tensor) -> None: r""" Updates the summary of a given tensor `x` diff --git a/captum/attr/_utils/visualization.py b/captum/attr/_utils/visualization.py index 0f67043bff..e566bbb7b8 100644 --- a/captum/attr/_utils/visualization.py +++ b/captum/attr/_utils/visualization.py @@ -830,9 +830,7 @@ def __init__( self.convergence_score = convergence_score -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def _get_color(attr): +def _get_color(attr: int) -> str: # clip values to prevent CSS errors (Values should be from [-1,1]) attr = max(-1, min(1, attr)) if attr > 0: @@ -846,23 +844,19 @@ def _get_color(attr): return "hsl({}, {}%, {}%)".format(hue, sat, lig) -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def format_classname(classname): +def format_classname(classname) -> str: return '{}'.format(classname) -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def format_special_tokens(token): +def format_special_tokens(token: str) -> str: if token.startswith("<") and token.endswith(">"): return "#" + token.strip("<>") return token -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def format_tooltip(item, text): +def format_tooltip(item, text) -> str: return '
<div class="tooltip">{item}\
         <div class="tooltiptext">{text}</div>\
     </div>\
'.format( @@ -870,9 +864,8 @@ def format_tooltip(item, text): ) -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def format_word_importances(words, importances): +def format_word_importances(words, importances) -> str: if importances is None or len(importances) == 0: return "" assert len(words) <= len(importances) diff --git a/captum/concept/_core/cav.py b/captum/concept/_core/cav.py index 771bce2dc0..97f6908eb4 100644 --- a/captum/concept/_core/cav.py +++ b/captum/concept/_core/cav.py @@ -89,8 +89,7 @@ def assemble_save_path( file_name = concepts_to_str(concepts) + "-" + layer + ".pkl" return os.path.join(path, model_id, file_name) - # pyre-fixme[3]: Return type must be annotated. - def save(self): + def save(self) -> None: r""" Saves a dictionary of the CAV computed values into a pickle file in the location returned by the "assemble_save_path" static methods. The @@ -137,8 +136,9 @@ def create_cav_dir_if_missing(save_path: str, model_id: str) -> None: os.makedirs(cav_model_id_path) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - def load(cavs_path: str, model_id: str, concepts: List[Concept], layer: str): + def load( + cavs_path: str, model_id: str, concepts: List[Concept], layer: str + ) -> Optional["CAV"]: r""" Loads CAV dictionary from a pickle file for given input `layer` and `concepts`. diff --git a/captum/concept/_core/tcav.py b/captum/concept/_core/tcav.py index bdd8cf56b0..a81c500794 100644 --- a/captum/concept/_core/tcav.py +++ b/captum/concept/_core/tcav.py @@ -55,18 +55,13 @@ def __init__(self, datasets: List[AV.AVDataset], labels: List[int]) -> None: from itertools import accumulate offsets = [0] + list(accumulate(map(len, datasets), (lambda x, y: x + y))) - # pyre-fixme[4]: Attribute must be annotated. - self.length = offsets[-1] + self.length: int = offsets[-1] self.datasets = datasets self.labels = labels - # pyre-fixme[4]: Attribute must be annotated. - self.lowers = offsets[:-1] - # pyre-fixme[4]: Attribute must be annotated. - self.uppers = offsets[1:] + self.lowers: List[int] = offsets[:-1] + self.uppers: List[int] = offsets[1:] - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _i_to_k(self, i): + def _i_to_k(self, i: int) -> int: left, right = 0, len(self.uppers) while left < right: @@ -77,9 +72,9 @@ def _i_to_k(self, i): left = mid else: right = mid + return -1 - # pyre-fixme[3]: Return type must be annotated. - def __getitem__(self, i: int): + def __getitem__(self, i: int) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: """ Returns a batch of activation vectors, as well as a batch of labels indicating which concept the batch of activation vectors is associated @@ -99,13 +94,13 @@ def __getitem__(self, i: int): inputs = self.datasets[k][i - self.lowers[k]] # pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no # attribute `shape`. - assert len(inputs.shape) == 2 + assert len(inputs.shape) == 2 # type: ignore # pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no # attribute `size`. # pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no # attribute `device`. 
- labels = torch.tensor([self.labels[k]] * inputs.size(0), device=inputs.device) + labels = torch.tensor([self.labels[k]] * inputs.size(0), device=inputs.device) # type: ignore # noqa: E501 line too long return inputs, labels def __len__(self) -> int: @@ -116,16 +111,13 @@ def __len__(self) -> int: def train_cav( - # pyre-fixme[2]: Parameter must be annotated. - model_id, + model_id: str, concepts: List[Concept], layers: Union[str, List[str]], classifier: Classifier, save_path: str, - # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use - # `typing.Dict[, ]` to avoid runtime subscripting errors. - classifier_kwargs: Dict, -) -> Dict[str, Dict[str, CAV]]: + classifier_kwargs: Dict[str, Any], +) -> Dict[str, Dict[str, Optional[CAV]]]: r""" A helper function for parallel CAV computations that can be called from a python process. @@ -163,7 +155,7 @@ def train_cav( """ concepts_key = concepts_to_str(concepts) - cavs: Dict[str, Dict[str, CAV]] = defaultdict() + cavs: Dict[str, Dict[str, Optional[CAV]]] = defaultdict() cavs[concepts_key] = defaultdict() layers = [layers] if isinstance(layers, str) else layers for layer in layers: @@ -176,12 +168,10 @@ def train_cav( labels = [concept.id for concept in concepts] - # pyre-fixme[22]: The cast is redundant. - labelled_dataset = LabelledDataset(cast(List[AV.AVDataset], datasets), labels) + labelled_dataset = LabelledDataset(datasets, labels) - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def batch_collate(batch): + def batch_collate(batch) -> Tuple[Tensor, Tensor]: inputs, labels = zip(*batch) return torch.cat(inputs), torch.cat(labels) @@ -217,7 +207,8 @@ def batch_collate(batch): model_id, ) # Saving cavs on the disk - cavs[concepts_key][layer].save() + # pyre-fixme[16]: `Optional` has no attribute `save`. + cavs[concepts_key][layer].save() # type: ignore return cavs @@ -269,8 +260,7 @@ def __init__( model_id: str = "default_model_id", classifier: Optional[Classifier] = None, layer_attr_method: Optional[LayerAttribution] = None, - # pyre-fixme[2]: Parameter must be annotated. - attribute_to_layer_input=False, + attribute_to_layer_input: bool = False, save_path: str = "./cav/", **classifier_kwargs: Any, ) -> None: @@ -321,21 +311,20 @@ def __init__( For more thorough examples, please check out TCAV tutorial and test cases. """ ConceptInterpreter.__init__(self, model) - # pyre-fixme[4]: Attribute must be annotated. - self.layers = [layers] if isinstance(layers, str) else layers + self.layers: List[str] = [layers] if isinstance(layers, str) else layers self.model_id = model_id self.concepts: Set[Concept] = set() self.classifier = classifier - # pyre-fixme[4]: Attribute must be annotated. - self.classifier_kwargs = classifier_kwargs + # pyre-fixme[4]: Attribute `classifier_kwargs` of class `TCAV` + # must have a type other than `Any`. + self.classifier_kwargs: Any = classifier_kwargs # pyre-fixme[8]: Attribute has type `Dict[str, Dict[str, CAV]]`; used as # `DefaultDict[Variable[_KT], DefaultDict[Variable[_KT], Variable[_VT]]]`. self.cavs: Dict[str, Dict[str, CAV]] = defaultdict(lambda: defaultdict()) if self.classifier is None: self.classifier = DefaultClassifier() if layer_attr_method is None: - # pyre-fixme[4]: Attribute must be annotated. - self.layer_attr_method = cast( + self.layer_attr_method: LayerAttribution = cast( LayerAttribution, LayerGradientXActivation( # type: ignore model, @@ -354,8 +343,7 @@ def __init__( "will use `default_model_id` as its default value." 
) - # pyre-fixme[4]: Attribute must be annotated. - self.attribute_to_layer_input = attribute_to_layer_input + self.attribute_to_layer_input: bool = attribute_to_layer_input self.save_path = save_path # Creates CAV save directory if it doesn't exist. It is created once in the @@ -372,9 +360,9 @@ def generate_all_activations(self) -> None: for concept in self.concepts: self.generate_activation(self.layers, concept) - # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use - # `typing.List[]` to avoid runtime subscripting errors. - def generate_activation(self, layers: Union[str, List], concept: Concept) -> None: + def generate_activation( + self, layers: Union[str, List[str]], concept: Concept + ) -> None: r""" Computes layer activations for the specified `concept` and the list of layer(s) `layers`. @@ -390,13 +378,12 @@ def generate_activation(self, layers: Union[str, List], concept: Concept) -> Non layer_modules = [_get_module_from_name(self.model, layer) for layer in layers] layer_act = LayerActivation(self.model, layer_modules) - assert concept.data_iter is not None, ( + data_iter = concept.data_iter + assert data_iter is not None, ( "Data iterator for concept id:", "{} must be specified".format(concept.id), ) - # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got - # `Optional[DataLoader[typing.Any]]`. - for i, examples in enumerate(concept.data_iter): + for i, examples in enumerate(data_iter): activations = layer_act.attribute.__wrapped__( # type: ignore layer_act, examples, @@ -461,9 +448,10 @@ def load_cavs( concept_layers = defaultdict(list) for layer in self.layers: - self.cavs[concepts_key][layer] = CAV.load( - self.save_path, self.model_id, concepts, layer - ) + cav = CAV.load(self.save_path, self.model_id, concepts, layer) + + if cav is not None: + self.cavs[concepts_key][layer] = cav # If CAV aren't loaded if ( @@ -482,13 +470,12 @@ def load_cavs( concept_layers[concept].append(layer) return layers, concept_layers - # pyre-fixme[3]: Return type must be annotated. def compute_cavs( self, experimental_sets: List[List[Concept]], force_train: bool = False, processes: Optional[int] = None, - ): + ) -> Dict[str, Dict[str, CAV]]: r""" This method computes CAVs for given `experiments_sets` and layers specified in `self.layers` instance variable. Internally, it @@ -592,7 +579,7 @@ def compute_cavs( # list[Dict[concept, Dict[layer, list]]] => Dict[concept, Dict[layer, list]] for cavs in cavs_list: for c_key in cavs: - self.cavs[c_key].update(cavs[c_key]) + self.cavs[c_key].update(cavs[c_key]) # type: ignore return self.cavs diff --git a/captum/insights/attr_vis/config.py b/captum/insights/attr_vis/config.py index cf02c1fdf6..5acb916b27 100644 --- a/captum/insights/attr_vis/config.py +++ b/captum/insights/attr_vis/config.py @@ -74,9 +74,7 @@ class ConfigParameters(NamedTuple): } -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def _str_to_tuple(s): +def _str_to_tuple(s: Tuple[int, ...]) -> Tuple[int, ...]: if isinstance(s, tuple): return s return tuple([int(i) for i in s.split()]) diff --git a/captum/log/__init__.py b/captum/log/__init__.py index 35b9a20e59..82e851c14e 100644 --- a/captum/log/__init__.py +++ b/captum/log/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # pyre-strict +from typing import Any try: from captum.log.fb.internal_log import ( @@ -23,9 +24,7 @@ except ImportError: from functools import wraps - # pyre-fixme[3]: Return type must be annotated. 
- # pyre-fixme[2]: Parameter must be annotated. - def log(*args, **kwargs): + def log(*args: Any, **kwargs: Any) -> None: pass # bug with mypy: https://github.com/python/mypy/issues/1153 @@ -34,42 +33,35 @@ class TimedLog: # type: ignore def __init__(self, *args, **kwargs) -> None: pass - # pyre-fixme[3]: Return type must be annotated. - def __enter__(self): + def __enter__(self) -> "TimedLog": return self - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def __exit__(self, exception_type, exception_value, traceback): + def __exit__(self, exception_type, exception_value, traceback) -> bool: return exception_value is not None # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def log_usage(*log_args, **log_kwargs): + def log_usage(*log_args: Any, **log_kwargs: Any): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def _log_usage(func): @wraps(func) # pyre-fixme[53]: Captured variable `func` is not annotated. # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any): return func(*args, **kwargs) return wrapper return _log_usage - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def set_environment(env): + def set_environment(env) -> None: pass - # pyre-fixme[3]: Return type must be annotated. - def disable_detailed_logging(): + def disable_detailed_logging() -> None: pass - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def patch_methods(tester, patch_log=True): + def patch_methods(tester, patch_log: bool = True) -> None: pass diff --git a/captum/module/binary_concrete_stochastic_gates.py b/captum/module/binary_concrete_stochastic_gates.py index 7167026f6c..0d7e7f759a 100644 --- a/captum/module/binary_concrete_stochastic_gates.py +++ b/captum/module/binary_concrete_stochastic_gates.py @@ -2,7 +2,7 @@ # pyre-strict import math -from typing import Optional +from typing import Any, Optional import torch from captum.module.stochastic_gates_base import StochasticGatesBase @@ -18,16 +18,6 @@ def _torch_empty(batch_size: int, n_gates: int, device: torch.device) -> Tensor: torch.fx.wrap(_torch_empty) -# pyre-fixme[3]: Return type must be annotated. -# pyre-fixme[2]: Parameter must be annotated. -def _logit(inp): - # torch.logit is introduced in 1.7.0 - if hasattr(torch, "logit"): - return torch.logit(inp) - else: - return torch.log(inp) - torch.log(1 - inp) - - class BinaryConcreteStochasticGates(StochasticGatesBase): """ Stochastic Gates with binary concrete distribution. @@ -63,7 +53,6 @@ class BinaryConcreteStochasticGates(StochasticGatesBase): """ - # pyre-fixme[3]: Return type must be annotated. def __init__( self, n_gates: int, @@ -74,7 +63,7 @@ def __init__( upper_bound: float = 1.1, eps: float = 1e-8, reg_reduction: str = "sum", - ): + ) -> None: """ Args: n_gates (int): number of gates. 
@@ -163,7 +152,9 @@ def _sample_gate_values(self, batch_size: int) -> Tensor: batch_size, self.n_gates, device=self.log_alpha_param.device ) u.uniform_(self.eps, 1 - self.eps) - s = torch.sigmoid((_logit(u) + self.log_alpha_param) / self.temperature) + s = torch.sigmoid( + (torch.logit(u) + self.log_alpha_param) / self.temperature + ) else: s = torch.sigmoid(self.log_alpha_param) @@ -199,9 +190,9 @@ def _get_gate_active_probs(self) -> Tensor: return torch.sigmoid(self.log_alpha_param - self.active_prob_offset) @classmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs): + def _from_pretrained( + cls, log_alpha_param: Tensor, *args: Any, **kwargs: Any + ) -> "BinaryConcreteStochasticGates": """ Private factory method to create an instance with pretrained parameters diff --git a/captum/module/gaussian_stochastic_gates.py b/captum/module/gaussian_stochastic_gates.py index 0335f7be1e..18bffe732d 100644 --- a/captum/module/gaussian_stochastic_gates.py +++ b/captum/module/gaussian_stochastic_gates.py @@ -2,7 +2,7 @@ # pyre-strict import math -from typing import Optional +from typing import Any, Optional import torch from captum.module.stochastic_gates_base import StochasticGatesBase @@ -41,7 +41,6 @@ class GaussianStochasticGates(StochasticGatesBase): >>> gated_inputs, reg = stg(mock_inputs) # gate the inputs """ - # pyre-fixme[3]: Return type must be annotated. def __init__( self, n_gates: int, @@ -49,7 +48,7 @@ def __init__( reg_weight: Optional[float] = 1.0, std: Optional[float] = 0.5, reg_reduction: str = "sum", - ): + ) -> None: """ Args: n_gates (int): number of gates. @@ -138,9 +137,9 @@ def _get_gate_active_probs(self) -> Tensor: return 0.5 * (1 + torch.erf(x / math.sqrt(2))) @classmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _from_pretrained(cls, mu: Tensor, *args, **kwargs): + def _from_pretrained( + cls, mu: Tensor, *args: Any, **kwargs: Any + ) -> "GaussianStochasticGates": """ Private factory method to create an instance with pretrained parameters diff --git a/captum/module/stochastic_gates_base.py b/captum/module/stochastic_gates_base.py index aee8782704..b34a4d5f4d 100644 --- a/captum/module/stochastic_gates_base.py +++ b/captum/module/stochastic_gates_base.py @@ -30,14 +30,13 @@ class StochasticGatesBase(Module, ABC): extend this class and implement the distribution specific functions. """ - # pyre-fixme[3]: Return type must be annotated. def __init__( self, n_gates: int, mask: Optional[Tensor] = None, reg_weight: float = 1.0, reg_reduction: str = "sum", - ): + ) -> None: """ Args: n_gates (int): number of gates. 
diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 0e4676aea1..892c6733bb 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -396,12 +396,14 @@ def test_llm_attr(self) -> None: # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - self.assertEqual(res.token_attr.shape, (5, 4)) + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (5, 4)) # type: ignore self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - self.assertEqual(res.token_attr.device.type, self.device) + self.assertEqual(token_attr.device.type, self.device) # type: ignore def test_llm_attr_without_target(self) -> None: llm = DummyLLM() @@ -414,12 +416,14 @@ def test_llm_attr_without_target(self) -> None: res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}) self.assertEqual(res.seq_attr.shape, (4,)) - self.assertEqual(res.token_attr.shape, (3, 4)) + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (3, 4)) # type: ignore self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["x", "y", "z"]) self.assertEqual(res.seq_attr.device.type, self.device) - self.assertEqual(res.token_attr.device.type, self.device) + self.assertEqual(token_attr.device.type, self.device) # type: ignore def test_llm_attr_with_skip_tokens(self) -> None: llm = DummyLLM() @@ -433,9 +437,11 @@ def test_llm_attr_with_skip_tokens(self) -> None: # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (3,)) - self.assertEqual(res.token_attr.shape, (5, 3)) + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (5, 3)) # type: ignore self.assertEqual(res.input_tokens, ["a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - self.assertEqual(res.token_attr.device.type, self.device) + self.assertEqual(token_attr.device.type, self.device) # type: ignore diff --git a/tests/attr/test_llm_attr_gpu.py b/tests/attr/test_llm_attr_gpu.py index 790e896d55..51b5935b85 100644 --- a/tests/attr/test_llm_attr_gpu.py +++ b/tests/attr/test_llm_attr_gpu.py @@ -247,12 +247,14 @@ def test_llm_attr(self) -> None: res = llm_attr.attribute(inp, "m n o p q") # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - self.assertEqual(res.token_attr.shape, (5, 4)) + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (5, 4)) # type: ignore self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - self.assertEqual(res.token_attr.device.type, self.device) + self.assertEqual(token_attr.device.type, self.device) # type: ignore def test_llm_attr_without_target(self) -> None: llm = DummyLLM() @@ -265,12 +267,14 @@ def test_llm_attr_without_target(self) -> None: res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}) self.assertEqual(res.seq_attr.shape, (4,)) - self.assertEqual(res.token_attr.shape, (3, 4)) + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (3, 4)) # type: ignore self.assertEqual(res.input_tokens, 
["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["x", "y", "z"]) - self.assertEqual(res.seq_attr.device.type, self.device) - self.assertEqual(res.token_attr.device.type, self.device) + self.assertEqual(res.seq_attr.device.type, self.device) # type: ignore + self.assertEqual(token_attr.device.type, self.device) # type: ignore def test_llm_attr_with_skip_tokens(self) -> None: llm = DummyLLM() @@ -284,9 +288,11 @@ def test_llm_attr_with_skip_tokens(self) -> None: # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (3,)) - self.assertEqual(res.token_attr.shape, (5, 3)) + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (5, 3)) # type: ignore self.assertEqual(res.input_tokens, ["a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - self.assertEqual(res.token_attr.device.type, self.device) + self.assertEqual(token_attr.device.type, self.device) # type: ignore diff --git a/tests/concept/test_tcav.py b/tests/concept/test_tcav.py index 66df5fdde4..679dcd10dd 100644 --- a/tests/concept/test_tcav.py +++ b/tests/concept/test_tcav.py @@ -735,26 +735,22 @@ def test_compute_cav_repeating_concept_names(self) -> None: self.assertEqual(cavs["0-1"]["conv1"].concepts[1].id, 1) self.assertEqual(cavs["0-1"]["conv1"].concepts[1].name, "random") - self.assertEqual(cavs["0-1"]["conv1"].stats["classes"], [0, 1]) - self.assertAlmostEqual( - cavs["0-1"]["conv1"].stats["accs"].item(), 0.4848, delta=0.001 - ) - self.assertEqual( - list(cavs["0-1"]["conv1"].stats["weights"].shape), [2, 128] - ) + stats = cavs["0-1"]["conv1"].stats + self.assertIsNotNone(stats) + self.assertEqual(stats["classes"], [0, 1]) # type: ignore + self.assertAlmostEqual(stats["accs"].item(), 0.4848, delta=0.001) # type: ignore # noqa: E501 line too long + self.assertEqual(list(stats["weights"].shape), [2, 128]) # type: ignore self.assertEqual(cavs["2-3"]["conv1"].concepts[0].id, 2) self.assertEqual(cavs["2-3"]["conv1"].concepts[0].name, "ceo") self.assertEqual(cavs["2-3"]["conv1"].concepts[1].id, 3) self.assertEqual(cavs["2-3"]["conv1"].concepts[1].name, "striped") - self.assertEqual(cavs["2-3"]["conv1"].stats["classes"], [2, 3]) - self.assertAlmostEqual( - cavs["2-3"]["conv1"].stats["accs"].item(), 0.4848, delta=0.001 - ) - self.assertEqual( - list(cavs["2-3"]["conv1"].stats["weights"].shape), [2, 128] - ) + stats = cavs["2-3"]["conv1"].stats + self.assertIsNotNone(stats) + self.assertEqual(stats["classes"], [2, 3]) # type: ignore + self.assertAlmostEqual(stats["accs"].item(), 0.4848, delta=0.001) # type: ignore # noqa: E501 line too long + self.assertEqual(list(stats["weights"].shape), [2, 128]) # type: ignore def compute_cavs_interpret( self, diff --git a/tests/helpers/basic_models.py b/tests/helpers/basic_models.py index 784444b761..5d1a1132b8 100644 --- a/tests/helpers/basic_models.py +++ b/tests/helpers/basic_models.py @@ -198,9 +198,8 @@ def __init__(self) -> None: self.relu1 = nn.ReLU() self.relu2 = nn.ReLU() - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. - def forward(self, x1, x2, x3: int = 2): + def forward(self, x1, x2, x3: int = 2) -> int: return 2 * self.relu1(x1) + x3 * self.relu2(x2 - 1.5) @@ -284,9 +283,8 @@ def __init__(self) -> None: self.tanh1 = nn.Tanh() self.tanh2 = nn.Tanh() - # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. 
- def forward(self, x1, x2): + def forward(self, x1, x2) -> int: return 2 * self.tanh1(x1) + 2 * self.tanh2(x2 - 1.5)