33# pyre-strict
44
55import typing
6- from typing import Callable , cast , List , Optional , Tuple , Union
6+ from typing import Callable , List , Optional , Tuple , Union
77
88import torch
99from captum ._utils .typing import BaselineType , TargetType , TensorOrTupleOfTensorsGeneric
2828
2929
3030@typing .overload
31- # pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
32- # arguments of overload defined on line `32`.
3331def _perturb_func (inputs : Tuple [Tensor , ...]) -> Tuple [Tensor , ...]: ...
3432
3533
3634@typing .overload
37- # pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible
38- # arguments of overload defined on line `28`.
3935def _perturb_func (inputs : Tensor ) -> Tensor : ...
4036
4137
4238def _perturb_func (
43- inputs : TensorOrTupleOfTensorsGeneric ,
39+ inputs : Union [ Tensor , Tuple [ Tensor , ...]] ,
4440) -> Union [Tensor , Tuple [Tensor , ...]]:
4541 def perturb_ratio (input : Tensor ) -> Tensor :
4642 return (
@@ -55,7 +51,7 @@ def perturb_ratio(input: Tensor) -> Tensor:
5551 input1 = inputs [0 ]
5652 input2 = inputs [1 ]
5753 else :
58- input1 = cast ( Tensor , inputs )
54+ input1 = inputs
5955
6056 perturbed_input1 = input1 + perturb_ratio (input1 )
6157
@@ -283,12 +279,13 @@ def test_classification_sensitivity_tpl_target_w_baseline(self) -> None:
283279
284280 def sensitivity_max_assert (
285281 self ,
286- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
287- expl_func : Callable ,
282+ expl_func : Callable [..., Union [Tensor , Tuple [Tensor , ...]]],
288283 inputs : TensorOrTupleOfTensorsGeneric ,
289284 expected_sensitivity : Tensor ,
290- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
291- perturb_func : Callable = _perturb_func ,
285+ perturb_func : Union [
286+ Callable [[Tensor ], Tensor ],
287+ Callable [[Tuple [Tensor , ...]], Tuple [Tensor , ...]],
288+ ] = _perturb_func ,
292289 n_perturb_samples : int = 5 ,
293290 max_examples_per_batch : Optional [int ] = None ,
294291 baselines : Optional [BaselineType ] = None ,
0 commit comments