2121from torch import Tensor
2222
2323
24- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
25- def infidelity_perturb_func_decorator (multipy_by_inputs : bool = True ) -> Callable :
24+ def infidelity_perturb_func_decorator (
25+ multiply_by_inputs : bool = True ,
26+ # pyre-ignore[34]: The type variable `Variable[TensorOrTupleOfTensorsGeneric
27+ # <: [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` isn't
28+ # present in the function's parameters.
29+ ) -> Callable [
30+ [Callable [..., TensorOrTupleOfTensorsGeneric ]],
31+ Callable [
32+ [TensorOrTupleOfTensorsGeneric , BaselineType ],
33+ Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]],
34+ ],
35+ ]:
2636 r"""An auxiliary, decorator function that helps with computing
2737 perturbations given perturbed inputs. It can be useful for cases
2838 when `pertub_func` returns only perturbed inputs and we
2939 internally compute the perturbations as
3040 (input - perturbed_input) / (input - baseline) if
31- multipy_by_inputs is set to True and
41+ multiply_by_inputs is set to True and
3242 (input - perturbed_input) otherwise.
3343
3444 If users decorate their `pertub_func` with
@@ -37,14 +47,18 @@ def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callabl
3747
3848 Args:
3949
40- multipy_by_inputs (bool): Indicates whether model inputs'
50+ multiply_by_inputs (bool): Indicates whether model inputs'
4151 multiplier is factored in the computation of
4252 attribution scores.
4353
4454 """
4555
46- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
47- def sub_infidelity_perturb_func_decorator (pertub_func : Callable ) -> Callable :
56+ def sub_infidelity_perturb_func_decorator (
57+ pertub_func : Callable [..., TensorOrTupleOfTensorsGeneric ]
58+ ) -> Callable [
59+ [TensorOrTupleOfTensorsGeneric , BaselineType ],
60+ Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]],
61+ ]:
4862 r"""
4963 Args:
5064
@@ -68,23 +82,18 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
6882
6983 """
7084
71- # pyre-fixme[3]: Return type must be annotated.
7285 def default_perturb_func (
7386 inputs : TensorOrTupleOfTensorsGeneric , baselines : BaselineType = None
74- ):
87+ ) -> Tuple [ Tuple [ Tensor , ...], Tuple [ Tensor , ...]] :
7588 r""" """
76- inputs_perturbed = (
89+ inputs_perturbed : TensorOrTupleOfTensorsGeneric = (
7790 pertub_func (inputs , baselines )
7891 if baselines is not None
7992 else pertub_func (inputs )
8093 )
81- inputs_perturbed = _format_tensor_into_tuples (inputs_perturbed )
82- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used
83- # as `Tuple[Tensor, ...]`.
84- inputs = _format_tensor_into_tuples (inputs )
85- # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
86- # `TensorOrTupleOfTensorsGeneric`.
87- baselines = _format_baseline (baselines , inputs )
94+ inputs_perturbed_formatted = _format_tensor_into_tuples (inputs_perturbed )
95+ inputs_formatted = _format_tensor_into_tuples (inputs )
96+ baselines = _format_baseline (baselines , inputs_formatted )
8897 if baselines is None :
8998 perturbations = tuple (
9099 (
@@ -93,12 +102,12 @@ def default_perturb_func(
93102 input ,
94103 default_denom = 1.0 ,
95104 )
96- if multipy_by_inputs
105+ if multiply_by_inputs
97106 else input - input_perturbed
98107 )
99- # pyre-fixme[6]: For 2nd argument expected
100- # `Iterable[Variable[_T2]]` but got `None`.
101- for input , input_perturbed in zip ( inputs , inputs_perturbed )
108+ for input , input_perturbed in zip (
109+ inputs_formatted , inputs_perturbed_formatted
110+ )
102111 )
103112 else :
104113 perturbations = tuple (
@@ -108,18 +117,16 @@ def default_perturb_func(
108117 input - baseline ,
109118 default_denom = 1.0 ,
110119 )
111- if multipy_by_inputs
120+ if multiply_by_inputs
112121 else input - input_perturbed
113122 )
114123 for input , input_perturbed , baseline in zip (
115- inputs ,
116- # pyre-fixme[6]: For 2nd argument expected
117- # `Iterable[Variable[_T2]]` but got `None`.
118- inputs_perturbed ,
124+ inputs_formatted ,
125+ inputs_perturbed_formatted ,
119126 baselines ,
120127 )
121128 )
122- return perturbations , inputs_perturbed
129+ return perturbations , inputs_perturbed_formatted
123130
124131 return default_perturb_func
125132
@@ -130,8 +137,9 @@ def default_perturb_func(
130137def infidelity (
131138 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
132139 forward_func : Callable ,
133- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
134- perturb_func : Callable ,
140+ perturb_func : Callable [
141+ ..., Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]
142+ ],
135143 inputs : TensorOrTupleOfTensorsGeneric ,
136144 attributions : TensorOrTupleOfTensorsGeneric ,
137145 baselines : BaselineType = None ,
@@ -188,25 +196,25 @@ def infidelity(
188196
189197 >>> from captum.metrics import infidelity_perturb_func_decorator
190198
191- >>> @infidelity_perturb_func_decorator(<multipy_by_inputs flag>)
199+ >>> @infidelity_perturb_func_decorator(<multiply_by_inputs flag>)
192200 >>> def my_perturb_func(inputs):
193201 >>> <MY-LOGIC-HERE>
194202 >>> return perturbed_inputs
195203
196- In case `multipy_by_inputs ` is False we compute perturbations by
197- `input - perturbed_input` difference and in case `multipy_by_inputs `
204+ In case `multiply_by_inputs ` is False we compute perturbations by
205+ `input - perturbed_input` difference and in case `multiply_by_inputs `
198206 flag is True we compute it by dividing
199207 (input - perturbed_input) by (input - baselines).
200208 The user needs to only return perturbed inputs in `perturb_func`
201209 as described above.
202210
203211 `infidelity_perturb_func_decorator` needs to be used with
204- `multipy_by_inputs ` flag set to False in case infidelity
212+ `multiply_by_inputs ` flag set to False in case infidelity
205213 score is being computed for attribution maps that are local aka
206214 that do not factor in inputs in the final attribution score.
207215 Such attribution algorithms include Saliency, GradCam, Guided Backprop,
208216 or Integrated Gradients and DeepLift attribution scores that are already
209- computed with `multipy_by_inputs =False` flag.
217+ computed with `multiply_by_inputs =False` flag.
210218
211219 If there are more than one inputs passed to infidelity function those
212220 will be passed to `perturb_func` as tuples in the same order as they
@@ -283,10 +291,10 @@ def infidelity(
283291 meaning that the inputs multiplier isn't factored in the
284292 attribution scores.
285293 This can be done duing the definition of the attribution algorithm
286- by passing `multipy_by_inputs =False` flag.
294+ by passing `multiply_by_inputs =False` flag.
287295 For example in case of Integrated Gradients (IG) we can obtain
288296 local attribution scores if we define the constructor of IG as:
289- ig = IntegratedGradients(multipy_by_inputs =False)
297+ ig = IntegratedGradients(multiply_by_inputs =False)
290298
291299 Some attribution algorithms are inherently local.
292300 Examples of inherently local attribution methods include:
@@ -434,7 +442,10 @@ def infidelity(
434442 _next_infidelity_tensors = _make_next_infidelity_tensors_func (
435443 forward_func ,
436444 bsz ,
437- perturb_func ,
445+ # error: Argument 3 to "_make_next_infidelity_tensors_func" has incompatible
446+ # type "Callable[..., tuple[Tensor, Tensor]]"; expected
447+ # "Callable[..., tuple[tuple[Tensor, ...], tuple[Tensor, ...]]]" [arg-type]
448+ perturb_func , # type: ignore
438449 inputs ,
439450 baselines ,
440451 attributions ,
@@ -477,8 +488,9 @@ def infidelity(
477488
478489def _generate_perturbations (
479490 current_n_perturb_samples : int ,
480- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
481- perturb_func : Callable ,
491+ perturb_func : Callable [
492+ ..., Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]
493+ ],
482494 inputs : TensorOrTupleOfTensorsGeneric ,
483495 baselines : BaselineType ,
484496) -> Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]:
@@ -491,8 +503,9 @@ def _generate_perturbations(
491503 repeated instances per example.
492504 """
493505
494- # pyre-fixme[3]: Return type must be annotated.
495- def call_perturb_func ():
506+ def call_perturb_func () -> (
507+ Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]
508+ ):
496509 r""" """
497510 baselines_pert = None
498511 inputs_pert : Union [Tensor , Tuple [Tensor , ...]]
@@ -561,8 +574,9 @@ def _make_next_infidelity_tensors_func(
561574 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
562575 forward_func : Callable ,
563576 bsz : int ,
564- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
565- perturb_func : Callable ,
577+ perturb_func : Callable [
578+ ..., Tuple [TensorOrTupleOfTensorsGeneric , TensorOrTupleOfTensorsGeneric ]
579+ ],
566580 inputs : TensorOrTupleOfTensorsGeneric ,
567581 baselines : BaselineType ,
568582 attributions : TensorOrTupleOfTensorsGeneric ,
@@ -579,15 +593,13 @@ def _next_infidelity_tensors(
579593 current_n_perturb_samples , perturb_func , inputs , baselines
580594 )
581595
582- perturbations = _format_tensor_into_tuples (perturbations )
583- inputs_perturbed = _format_tensor_into_tuples (inputs_perturbed )
596+ perturbations_formatted = _format_tensor_into_tuples (perturbations )
597+ inputs_perturbed_formatted = _format_tensor_into_tuples (inputs_perturbed )
584598
585599 _validate_inputs_and_perturbations (
586600 cast (Tuple [Tensor , ...], inputs ),
587- # pyre-fixme[22]: The cast is redundant.
588- cast (Tuple [Tensor , ...], inputs_perturbed ),
589- # pyre-fixme[22]: The cast is redundant.
590- cast (Tuple [Tensor , ...], perturbations ),
601+ inputs_perturbed_formatted ,
602+ perturbations_formatted ,
591603 )
592604
593605 targets_expanded = _expand_target (
@@ -603,7 +615,7 @@ def _next_infidelity_tensors(
603615
604616 inputs_perturbed_fwd = _run_forward (
605617 forward_func ,
606- inputs_perturbed ,
618+ inputs_perturbed_formatted ,
607619 targets_expanded ,
608620 additional_forward_args_expanded ,
609621 )
@@ -624,7 +636,7 @@ def _next_infidelity_tensors(
624636 attributions_times_perturb = tuple (
625637 (attribution_expanded * perturbation ).view (attribution_expanded .size (0 ), - 1 )
626638 for attribution_expanded , perturbation in zip (
627- attributions_expanded , perturbations
639+ attributions_expanded , perturbations_formatted
628640 )
629641 )
630642
0 commit comments