22
33# pyre-strict
44import typing
5- from typing import Any , Callable , List , Tuple , Union
5+ from typing import Any , Callable , List , Literal , Tuple , Union
66
77import torch
88from captum ._utils .common import (
1212 _format_output ,
1313 _is_tuple ,
1414)
15- from captum ._utils .typing import (
16- BaselineType ,
17- Literal ,
18- TargetType ,
19- TensorOrTupleOfTensorsGeneric ,
20- )
15+ from captum ._utils .typing import BaselineType , TargetType , TensorOrTupleOfTensorsGeneric
2116from captum .attr ._utils .approximation_methods import approximation_parameters
2217from captum .attr ._utils .attribution import GradientAttribution
2318from captum .attr ._utils .batching import _batch_attribution
@@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution):
4944
5045 def __init__ (
5146 self ,
52- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
53- forward_func : Callable ,
47+ forward_func : Callable [..., Tensor ],
5448 multiply_by_inputs : bool = True ,
5549 ) -> None :
5650 r"""
@@ -80,21 +74,16 @@ def __init__(
8074 # and when return_convergence_delta is True, the return type is
8175 # a tuple with both attributions and deltas.
8276 @typing .overload
83- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
84- # arguments of overload defined on line `95`.
8577 def attribute (
8678 self ,
8779 inputs : TensorOrTupleOfTensorsGeneric ,
8880 baselines : BaselineType = None ,
8981 target : TargetType = None ,
90- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
91- additional_forward_args : Any = None ,
82+ additional_forward_args : object = None ,
9283 n_steps : int = 50 ,
9384 method : str = "gausslegendre" ,
9485 internal_batch_size : Union [None , int ] = None ,
9586 * ,
96- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
97- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
9887 return_convergence_delta : Literal [True ],
9988 ) -> Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]: ...
10089
@@ -111,9 +100,6 @@ def attribute(
111100 n_steps : int = 50 ,
112101 method : str = "gausslegendre" ,
113102 internal_batch_size : Union [None , int ] = None ,
114- # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
115- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
116- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
117103 return_convergence_delta : Literal [False ] = False ,
118104 ) -> TensorOrTupleOfTensorsGeneric : ...
119105
@@ -275,37 +261,35 @@ def attribute( # type: ignore
275261 """
276262 # Keeps track whether original input is a tuple or not before
277263 # converting it into a tuple.
278- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
279- # `TensorOrTupleOfTensorsGeneric`.
280264 is_inputs_tuple = _is_tuple (inputs )
281265
282266 # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
283267 # `Tuple[Tensor, ...]`.
284- inputs , baselines = _format_input_baseline (inputs , baselines )
268+ formatted_inputs , formatted_baselines = _format_input_baseline (
269+ inputs , baselines
270+ )
285271
286272 # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
287273 # `TensorOrTupleOfTensorsGeneric`.
288- _validate_input (inputs , baselines , n_steps , method )
274+ _validate_input (formatted_inputs , formatted_baselines , n_steps , method )
289275
290276 if internal_batch_size is not None :
291- num_examples = inputs [0 ].shape [0 ]
277+ num_examples = formatted_inputs [0 ].shape [0 ]
292278 attributions = _batch_attribution (
293279 self ,
294280 num_examples ,
295281 internal_batch_size ,
296282 n_steps ,
297- inputs = inputs ,
298- baselines = baselines ,
283+ inputs = formatted_inputs ,
284+ baselines = formatted_baselines ,
299285 target = target ,
300286 additional_forward_args = additional_forward_args ,
301287 method = method ,
302288 )
303289 else :
304290 attributions = self ._attribute (
305- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
306- # got `TensorOrTupleOfTensorsGeneric`.
307- inputs = inputs ,
308- baselines = baselines ,
291+ inputs = formatted_inputs ,
292+ baselines = formatted_baselines ,
309293 target = target ,
310294 additional_forward_args = additional_forward_args ,
311295 n_steps = n_steps ,
@@ -344,8 +328,7 @@ def _attribute(
344328 inputs : Tuple [Tensor , ...],
345329 baselines : Tuple [Union [Tensor , int , float ], ...],
346330 target : TargetType = None ,
347- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
348- additional_forward_args : Any = None ,
331+ additional_forward_args : object = None ,
349332 n_steps : int = 50 ,
350333 method : str = "gausslegendre" ,
351334 step_sizes_and_alphas : Union [None , Tuple [List [float ], List [float ]]] = None ,
0 commit comments