
Commit 2f2d40c

craymichael authored and facebook-github-bot committed
Remove mypy note from infidelity.py (#1415)
Summary:
Pull Request resolved: #1415

Adds enough typing to get rid of
`captum/metrics/_core/infidelity.py:498: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs [annotation-unchecked]`

Reviewed By: cyrjano

Differential Revision: D64998800

fbshipit-source-id: 70b5d0a7ecd64dc1064d788b14f149bd11497e1b
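The note above is mypy's [annotation-unchecked] warning: it is emitted when a type annotation appears inside a function whose own signature is untyped, because the bodies of untyped functions are skipped by default. A minimal sketch of the pattern (hypothetical names, not code from this commit):

    from typing import Callable

    # Before: the inner function has no annotations, so mypy skips its body and
    # reports the [annotation-unchecked] note at the annotated assignment inside it.
    def make_scaler(factor):
        def scale(x):
            result: float = x * factor  # note reported here
            return result
        return scale

    # After: annotating the signatures (as this commit does for the nested helpers
    # in infidelity.py) makes mypy check the bodies, so the note disappears.
    def make_scaler_typed(factor: float) -> Callable[[float], float]:
        def scale(x: float) -> float:
            result: float = x * factor
            return result
        return scale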
1 parent dff8473 commit 2f2d40c

File tree: 1 file changed (+62, -50 lines)


captum/metrics/_core/infidelity.py

Lines changed: 62 additions & 50 deletions
@@ -21,14 +21,24 @@
 from torch import Tensor
 
 
-# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callable:
+def infidelity_perturb_func_decorator(
+    multiply_by_inputs: bool = True,
+    # pyre-ignore[34]: The type variable `Variable[TensorOrTupleOfTensorsGeneric
+    # <: [torch._tensor.Tensor, typing.Tuple[torch._tensor.Tensor, ...]]]` isn't
+    # present in the function's parameters.
+) -> Callable[
+    [Callable[..., TensorOrTupleOfTensorsGeneric]],
+    Callable[
+        [TensorOrTupleOfTensorsGeneric, BaselineType],
+        Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
+    ],
+]:
     r"""An auxiliary, decorator function that helps with computing
     perturbations given perturbed inputs. It can be useful for cases
     when `pertub_func` returns only perturbed inputs and we
     internally compute the perturbations as
     (input - perturbed_input) / (input - baseline) if
-    multipy_by_inputs is set to True and
+    multiply_by_inputs is set to True and
     (input - perturbed_input) otherwise.
 
     If users decorate their `pertub_func` with
@@ -37,14 +47,18 @@ def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callabl
 
     Args:
 
-        multipy_by_inputs (bool): Indicates whether model inputs'
+        multiply_by_inputs (bool): Indicates whether model inputs'
                 multiplier is factored in the computation of
                 attribution scores.
 
     """
 
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
+    def sub_infidelity_perturb_func_decorator(
+        pertub_func: Callable[..., TensorOrTupleOfTensorsGeneric]
+    ) -> Callable[
+        [TensorOrTupleOfTensorsGeneric, BaselineType],
+        Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
+    ]:
         r"""
         Args:
 
@@ -68,23 +82,18 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
 
         """
 
-        # pyre-fixme[3]: Return type must be annotated.
         def default_perturb_func(
             inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None
-        ):
+        ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
             r""" """
-            inputs_perturbed = (
+            inputs_perturbed: TensorOrTupleOfTensorsGeneric = (
                 pertub_func(inputs, baselines)
                 if baselines is not None
                 else pertub_func(inputs)
             )
-            inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
-            # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used
-            # as `Tuple[Tensor, ...]`.
-            inputs = _format_tensor_into_tuples(inputs)
-            # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
-            # `TensorOrTupleOfTensorsGeneric`.
-            baselines = _format_baseline(baselines, inputs)
+            inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)
+            inputs_formatted = _format_tensor_into_tuples(inputs)
+            baselines = _format_baseline(baselines, inputs_formatted)
             if baselines is None:
                 perturbations = tuple(
                     (
@@ -93,12 +102,12 @@ def default_perturb_func(
                             input,
                             default_denom=1.0,
                         )
-                        if multipy_by_inputs
+                        if multiply_by_inputs
                         else input - input_perturbed
                     )
-                    # pyre-fixme[6]: For 2nd argument expected
-                    # `Iterable[Variable[_T2]]` but got `None`.
-                    for input, input_perturbed in zip(inputs, inputs_perturbed)
+                    for input, input_perturbed in zip(
+                        inputs_formatted, inputs_perturbed_formatted
+                    )
                 )
             else:
                 perturbations = tuple(
@@ -108,18 +117,16 @@ def default_perturb_func(
                             input - baseline,
                             default_denom=1.0,
                         )
-                        if multipy_by_inputs
+                        if multiply_by_inputs
                         else input - input_perturbed
                     )
                     for input, input_perturbed, baseline in zip(
-                        inputs,
-                        # pyre-fixme[6]: For 2nd argument expected
-                        # `Iterable[Variable[_T2]]` but got `None`.
-                        inputs_perturbed,
+                        inputs_formatted,
+                        inputs_perturbed_formatted,
                         baselines,
                     )
                 )
-            return perturbations, inputs_perturbed
+            return perturbations, inputs_perturbed_formatted
 
         return default_perturb_func
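Not part of the diff, but for orientation: a small sketch of what the decorated perturb function returns under the annotations above. The perturb function and noise scale are illustrative assumptions, not code from this commit:

    import torch
    from captum.metrics import infidelity_perturb_func_decorator

    # Illustrative: the wrapped function returns only perturbed inputs; the
    # decorator derives the (perturbations, perturbed_inputs) pair, matching the
    # annotated return type Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]].
    @infidelity_perturb_func_decorator(multiply_by_inputs=False)
    def noisy_perturb(inputs):
        noise = torch.randn_like(inputs) * 0.01
        return inputs - noise

    perturbations, inputs_perturbed = noisy_perturb(torch.rand(4, 3))
    # Both results are tuples of tensors, even for a single-tensor input.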

@@ -130,8 +137,9 @@ def default_perturb_func(
 def infidelity(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     forward_func: Callable,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    perturb_func: Callable,
+    perturb_func: Callable[
+        ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
+    ],
     inputs: TensorOrTupleOfTensorsGeneric,
     attributions: TensorOrTupleOfTensorsGeneric,
     baselines: BaselineType = None,
@@ -188,25 +196,25 @@ def infidelity(
 
                 >>> from captum.metrics import infidelity_perturb_func_decorator
 
-                >>> @infidelity_perturb_func_decorator(<multipy_by_inputs flag>)
+                >>> @infidelity_perturb_func_decorator(<multiply_by_inputs flag>)
                 >>> def my_perturb_func(inputs):
                 >>>    <MY-LOGIC-HERE>
                 >>>    return perturbed_inputs
 
-                In case `multipy_by_inputs` is False we compute perturbations by
-                `input - perturbed_input` difference and in case `multipy_by_inputs`
+                In case `multiply_by_inputs` is False we compute perturbations by
+                `input - perturbed_input` difference and in case `multiply_by_inputs`
                 flag is True we compute it by dividing
                 (input - perturbed_input) by (input - baselines).
                 The user needs to only return perturbed inputs in `perturb_func`
                 as described above.
 
                 `infidelity_perturb_func_decorator` needs to be used with
-                `multipy_by_inputs` flag set to False in case infidelity
+                `multiply_by_inputs` flag set to False in case infidelity
                 score is being computed for attribution maps that are local aka
                 that do not factor in inputs in the final attribution score.
                 Such attribution algorithms include Saliency, GradCam, Guided Backprop,
                 or Integrated Gradients and DeepLift attribution scores that are already
-                computed with `multipy_by_inputs=False` flag.
+                computed with `multiply_by_inputs=False` flag.
 
                 If there are more than one inputs passed to infidelity function those
                 will be passed to `perturb_func` as tuples in the same order as they
@@ -283,10 +291,10 @@ def infidelity(
                 meaning that the inputs multiplier isn't factored in the
                 attribution scores.
                 This can be done duing the definition of the attribution algorithm
-                by passing `multipy_by_inputs=False` flag.
+                by passing `multiply_by_inputs=False` flag.
                 For example in case of Integrated Gradients (IG) we can obtain
                 local attribution scores if we define the constructor of IG as:
-                ig = IntegratedGradients(multipy_by_inputs=False)
+                ig = IntegratedGradients(multiply_by_inputs=False)
 
                 Some attribution algorithms are inherently local.
                 Examples of inherently local attribution methods include:
@@ -434,7 +442,10 @@ def infidelity(
     _next_infidelity_tensors = _make_next_infidelity_tensors_func(
         forward_func,
         bsz,
-        perturb_func,
+        # error: Argument 3 to "_make_next_infidelity_tensors_func" has incompatible
+        # type "Callable[..., tuple[Tensor, Tensor]]"; expected
+        # "Callable[..., tuple[tuple[Tensor, ...], tuple[Tensor, ...]]]" [arg-type]
+        perturb_func,  # type: ignore
         inputs,
         baselines,
         attributions,
@@ -477,8 +488,9 @@ def infidelity(
 
 def _generate_perturbations(
     current_n_perturb_samples: int,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    perturb_func: Callable,
+    perturb_func: Callable[
+        ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
+    ],
     inputs: TensorOrTupleOfTensorsGeneric,
     baselines: BaselineType,
 ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
@@ -491,8 +503,9 @@ def _generate_perturbations(
     repeated instances per example.
     """
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def call_perturb_func():
+    def call_perturb_func() -> (
+        Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
+    ):
         r""" """
         baselines_pert = None
         inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
@@ -561,8 +574,9 @@ def _make_next_infidelity_tensors_func(
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     forward_func: Callable,
     bsz: int,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    perturb_func: Callable,
+    perturb_func: Callable[
+        ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
+    ],
     inputs: TensorOrTupleOfTensorsGeneric,
     baselines: BaselineType,
     attributions: TensorOrTupleOfTensorsGeneric,
@@ -579,15 +593,13 @@ def _next_infidelity_tensors(
             current_n_perturb_samples, perturb_func, inputs, baselines
         )
 
-        perturbations = _format_tensor_into_tuples(perturbations)
-        inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
+        perturbations_formatted = _format_tensor_into_tuples(perturbations)
+        inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)
 
         _validate_inputs_and_perturbations(
             cast(Tuple[Tensor, ...], inputs),
-            # pyre-fixme[22]: The cast is redundant.
-            cast(Tuple[Tensor, ...], inputs_perturbed),
-            # pyre-fixme[22]: The cast is redundant.
-            cast(Tuple[Tensor, ...], perturbations),
+            inputs_perturbed_formatted,
+            perturbations_formatted,
         )
 
         targets_expanded = _expand_target(
@@ -603,7 +615,7 @@ def _next_infidelity_tensors(
 
         inputs_perturbed_fwd = _run_forward(
             forward_func,
-            inputs_perturbed,
+            inputs_perturbed_formatted,
             targets_expanded,
             additional_forward_args_expanded,
         )
@@ -624,7 +636,7 @@ def _next_infidelity_tensors(
         attributions_times_perturb = tuple(
             (attribution_expanded * perturbation).view(attribution_expanded.size(0), -1)
             for attribution_expanded, perturbation in zip(
-                attributions_expanded, perturbations
+                attributions_expanded, perturbations_formatted
             )
         )
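Not part of the commit, but as a sanity check of the updated `perturb_func` annotation (a callable returning a `(perturbations, perturbed_inputs)` pair), a minimal usage sketch; the model, noise scale, and target are made-up placeholders:

    import torch
    import torch.nn as nn
    from captum.attr import Saliency
    from captum.metrics import infidelity

    # Toy setup, for illustration only.
    model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2))
    inputs = torch.rand(8, 3)

    def perturb_fn(inputs):
        # Returns the (perturbations, perturbed_inputs) pair that the
        # perturb_func annotation added in this commit describes.
        noise = torch.randn_like(inputs) * 0.01
        return noise, inputs - noise

    attributions = Saliency(model).attribute(inputs, target=1)
    score = infidelity(model, perturb_fn, inputs, attributions, target=1)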
