66import typing
77import warnings
88from collections .abc import Iterator
9- from typing import Any , Callable , cast , List , Optional , Tuple , Union
9+ from typing import Any , Callable , cast , Generator , List , Literal , Optional , Tuple , Union
1010
1111import torch
1212from captum ._utils .common import (
2323from captum ._utils .models .linear_model import SkLearnLasso
2424from captum ._utils .models .model import Model
2525from captum ._utils .progress import progress
26- from captum ._utils .typing import (
27- BaselineType ,
28- Literal ,
29- TargetType ,
30- TensorOrTupleOfTensorsGeneric ,
31- )
26+ from captum ._utils .typing import BaselineType , TargetType , TensorOrTupleOfTensorsGeneric
3227from captum .attr ._utils .attribution import PerturbationAttribution
3328from captum .attr ._utils .batching import _batch_example_iterator
3429from captum .attr ._utils .common import (
@@ -73,18 +68,18 @@ class LimeBase(PerturbationAttribution):
7368
7469 def __init__ (
7570 self ,
76- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
77- forward_func : Callable ,
71+ forward_func : Callable [..., Tensor ],
7872 interpretable_model : Model ,
79- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
80- similarity_func : Callable ,
81- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
82- perturb_func : Callable ,
73+ similarity_func : Callable [
74+ ...,
75+ Union [float , Tensor ],
76+ ],
77+ perturb_func : Callable [..., object ],
8378 perturb_interpretable_space : bool ,
84- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
85- from_interp_rep_transform : Optional [ Callable ],
86- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
87- to_interp_rep_transform : Optional [Callable ],
79+ from_interp_rep_transform : Optional [
80+ Callable [..., Union [ Tensor , Tuple [ Tensor , ...]]]
81+ ],
82+ to_interp_rep_transform : Optional [Callable [..., Tensor ] ],
8883 ) -> None :
8984 r"""
9085
@@ -249,13 +244,11 @@ def attribute(
249244 self ,
250245 inputs : TensorOrTupleOfTensorsGeneric ,
251246 target : TargetType = None ,
252- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
253- additional_forward_args : Any = None ,
247+ additional_forward_args : object = None ,
254248 n_samples : int = 50 ,
255249 perturbations_per_eval : int = 1 ,
256250 show_progress : bool = False ,
257- # pyre-fixme[2]: Parameter must be annotated.
258- ** kwargs ,
251+ ** kwargs : object ,
259252 ) -> Tensor :
260253 r"""
261254 This method attributes the output of the model with given target index
@@ -551,7 +544,7 @@ def generate_perturbation() -> (
551544 curr_sample , inputs , ** kwargs
552545 )
553546
554- return interpretable_inp , curr_model_input
547+ return interpretable_inp , curr_model_input # type: ignore
555548
556549 return generate_perturbation
557550
@@ -568,8 +561,7 @@ def _evaluate_batch(
568561 self ,
569562 curr_model_inputs : List [TensorOrTupleOfTensorsGeneric ],
570563 expanded_target : TargetType ,
571- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
572- expanded_additional_args : Any ,
564+ expanded_additional_args : object ,
573565 device : torch .device ,
574566 ) -> Tensor :
575567 model_out = _run_forward (
@@ -630,8 +622,7 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
630622def get_exp_kernel_similarity_function (
631623 distance_mode : str = "cosine" ,
632624 kernel_width : float = 1.0 ,
633- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
634- ) -> Callable :
625+ ) -> Callable [..., float ]:
635626 r"""
636627 This method constructs an appropriate similarity function to compute
637628 weights for perturbed sample in LIME. Distance between the original
@@ -680,8 +671,9 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
680671 return default_exp_kernel
681672
682673
683- # pyre-fixme[2]: Parameter must be annotated.
684- def default_perturb_func (original_inp , ** kwargs ) -> Tensor :
674+ def default_perturb_func (
675+ original_inp : TensorOrTupleOfTensorsGeneric , ** kwargs : object
676+ ) -> Tensor :
685677 assert (
686678 "num_interp_features" in kwargs
687679 ), "Must provide num_interp_features to use default interpretable sampling function"
@@ -690,25 +682,25 @@ def default_perturb_func(original_inp, **kwargs) -> Tensor:
690682 else :
691683 device = original_inp [0 ].device
692684
693- probs = torch .ones (1 , kwargs ["num_interp_features" ]) * 0.5
685+ probs = torch .ones (1 , cast ( int , kwargs ["num_interp_features" ]) ) * 0.5
694686 return torch .bernoulli (probs ).to (device = device ).long ()
695687
696688
697689def construct_feature_mask (
698690 feature_mask : Union [None , Tensor , Tuple [Tensor , ...]],
699691 formatted_inputs : Tuple [Tensor , ...],
700692) -> Tuple [Tuple [Tensor , ...], int ]:
693+ feature_mask_tuple : Tuple [Tensor , ...]
701694 if feature_mask is None :
702- feature_mask , num_interp_features = _construct_default_feature_mask (
695+ feature_mask_tuple , num_interp_features = _construct_default_feature_mask (
703696 formatted_inputs
704697 )
705698 else :
706- feature_mask = _format_tensor_into_tuples (feature_mask )
699+ feature_mask_tuple = _format_tensor_into_tuples (feature_mask )
707700 min_interp_features = int (
708701 min (
709702 torch .min (single_mask ).item ()
710- # pyre-fixme[16]: `None` has no attribute `__iter__`.
711- for single_mask in feature_mask
703+ for single_mask in feature_mask_tuple
712704 if single_mask .numel ()
713705 )
714706 )
@@ -718,14 +710,12 @@ def construct_feature_mask(
718710 " start at 0." ,
719711 stacklevel = 2 ,
720712 )
721- feature_mask = tuple (
722- single_mask - min_interp_features for single_mask in feature_mask
713+ feature_mask_tuple = tuple (
714+ single_mask - min_interp_features for single_mask in feature_mask_tuple
723715 )
724716
725- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
726- # `Optional[typing.Tuple[typing.Any, ...]]`.
727- num_interp_features = _get_max_feature_index (feature_mask ) + 1
728- return feature_mask , num_interp_features
717+ num_interp_features = _get_max_feature_index (feature_mask_tuple ) + 1
718+ return feature_mask_tuple , num_interp_features
729719
730720
731721class Lime (LimeBase ):
@@ -766,8 +756,7 @@ class Lime(LimeBase):
766756
767757 def __init__ (
768758 self ,
769- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
770- forward_func : Callable ,
759+ forward_func : Callable [..., Tensor ],
771760 interpretable_model : Optional [Model ] = None ,
772761 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
773762 similarity_func : Optional [Callable ] = None ,
@@ -887,8 +876,7 @@ def attribute( # type: ignore
887876 inputs : TensorOrTupleOfTensorsGeneric ,
888877 baselines : BaselineType = None ,
889878 target : TargetType = None ,
890- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
891- additional_forward_args : Any = None ,
879+ additional_forward_args : object = None ,
892880 feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
893881 n_samples : int = 25 ,
894882 perturbations_per_eval : int = 1 ,
@@ -1133,18 +1121,14 @@ def _attribute_kwargs( # type: ignore
11331121 inputs : TensorOrTupleOfTensorsGeneric ,
11341122 baselines : BaselineType = None ,
11351123 target : TargetType = None ,
1136- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
1137- additional_forward_args : Any = None ,
1124+ additional_forward_args : object = None ,
11381125 feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
11391126 n_samples : int = 25 ,
11401127 perturbations_per_eval : int = 1 ,
11411128 return_input_shape : bool = True ,
11421129 show_progress : bool = False ,
1143- # pyre-fixme[2]: Parameter must be annotated.
1144- ** kwargs ,
1130+ ** kwargs : object ,
11451131 ) -> TensorOrTupleOfTensorsGeneric :
1146- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
1147- # `TensorOrTupleOfTensorsGeneric`.
11481132 is_inputs_tuple = _is_tuple (inputs )
11491133 formatted_inputs , baselines = _format_input_baseline (inputs , baselines )
11501134 bsz = formatted_inputs [0 ].shape [0 ]
@@ -1263,33 +1247,35 @@ def _attribute_kwargs( # type: ignore
12631247 return coefs
12641248
12651249 @typing .overload
1266- # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
1267- # all possible arguments of overload defined on line `1201`.
12681250 def _convert_output_shape (
12691251 self ,
12701252 formatted_inp : Tuple [Tensor , ...],
12711253 feature_mask : Tuple [Tensor , ...],
12721254 coefs : Tensor ,
12731255 num_interp_features : int ,
1274- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
1275- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
12761256 is_inputs_tuple : Literal [True ],
12771257 ) -> Tuple [Tensor , ...]: ...
12781258
12791259 @typing .overload
1280- # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
1281- # all possible arguments of overload defined on line `1211`.
12821260 def _convert_output_shape ( # type: ignore
12831261 self ,
12841262 formatted_inp : Tuple [Tensor , ...],
12851263 feature_mask : Tuple [Tensor , ...],
12861264 coefs : Tensor ,
12871265 num_interp_features : int ,
1288- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
1289- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
12901266 is_inputs_tuple : Literal [False ],
12911267 ) -> Tensor : ...
12921268
1269+ @typing .overload
1270+ def _convert_output_shape (
1271+ self ,
1272+ formatted_inp : Tuple [Tensor , ...],
1273+ feature_mask : Tuple [Tensor , ...],
1274+ coefs : Tensor ,
1275+ num_interp_features : int ,
1276+ is_inputs_tuple : bool ,
1277+ ) -> Union [Tensor , Tuple [Tensor , ...]]: ...
1278+
12931279 def _convert_output_shape (
12941280 self ,
12951281 formatted_inp : Tuple [Tensor , ...],
0 commit comments