Skip to content

Commit 18b209d

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Fix remaining pyre errors in infidelity.py
Summary: Fix pyre/mypy errors in infidelity.py. Introduce new BaselineTupleType Differential Revision: D64998803
1 parent e3cf2a1 commit 18b209d

File tree

2 files changed

+62
-71
lines changed

2 files changed

+62
-71
lines changed

captum/_utils/typing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
1616
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
1717
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
18-
BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
18+
BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
19+
BaselineType = Union[None, Tensor, int, float, BaselineTupleType]
1920

2021
TensorLikeList1D = List[float]
2122
TensorLikeList2D = List[TensorLikeList1D]

captum/metrics/_core/infidelity.py

Lines changed: 60 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
ExpansionTypes,
1616
safe_div,
1717
)
18-
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
18+
from captum._utils.typing import (
19+
BaselineTupleType,
20+
BaselineType,
21+
TargetType,
22+
TensorOrTupleOfTensorsGeneric,
23+
)
1924
from captum.log import log_usage
2025
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
2126
from torch import Tensor
@@ -35,14 +40,14 @@ def infidelity_perturb_func_decorator(
3540
]:
3641
r"""An auxiliary, decorator function that helps with computing
3742
perturbations given perturbed inputs. It can be useful for cases
38-
when `pertub_func` returns only perturbed inputs and we
43+
when `perturb_func` returns only perturbed inputs and we
3944
internally compute the perturbations as
4045
(input - perturbed_input) / (input - baseline) if
4146
multiply_by_inputs is set to True and
4247
(input - perturbed_input) otherwise.
4348
44-
If users decorate their `pertub_func` with
45-
`@infidelity_perturb_func_decorator` function then their `pertub_func`
49+
If users decorate their `perturb_func` with
50+
`@infidelity_perturb_func_decorator` function then their `perturb_func`
4651
needs to only return perturbed inputs.
4752
4853
Args:
@@ -54,15 +59,15 @@ def infidelity_perturb_func_decorator(
5459
"""
5560

5661
def sub_infidelity_perturb_func_decorator(
57-
pertub_func: Callable[..., TensorOrTupleOfTensorsGeneric]
62+
perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric]
5863
) -> Callable[
5964
[TensorOrTupleOfTensorsGeneric, BaselineType],
6065
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
6166
]:
6267
r"""
6368
Args:
6469
65-
pertub_func(Callable): Input perturbation function that takes inputs
70+
perturb_func(Callable): Input perturbation function that takes inputs
6671
and optionally baselines and returns perturbed inputs
6772
6873
Returns:
@@ -87,9 +92,9 @@ def default_perturb_func(
8792
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
8893
r""" """
8994
inputs_perturbed: TensorOrTupleOfTensorsGeneric = (
90-
pertub_func(inputs, baselines)
95+
perturb_func(inputs, baselines)
9196
if baselines is not None
92-
else pertub_func(inputs)
97+
else perturb_func(inputs)
9398
)
9499
inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)
95100
inputs_formatted = _format_tensor_into_tuples(inputs)
@@ -135,16 +140,14 @@ def default_perturb_func(
135140

136141
@log_usage()
137142
def infidelity(
138-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
139-
forward_func: Callable,
143+
forward_func: Callable[..., Tensor],
140144
perturb_func: Callable[
141145
..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
142146
],
143147
inputs: TensorOrTupleOfTensorsGeneric,
144148
attributions: TensorOrTupleOfTensorsGeneric,
145149
baselines: BaselineType = None,
146-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
147-
additional_forward_args: Any = None,
150+
additional_forward_args: object = None,
148151
target: TargetType = None,
149152
n_perturb_samples: int = 10,
150153
max_examples_per_batch: Optional[int] = None,
@@ -417,38 +420,35 @@ def infidelity(
417420
>>> infid = infidelity(net, perturb_fn, input, attribution)
418421
"""
419422
# perform argument formattings
420-
inputs = _format_tensor_into_tuples(inputs) # type: ignore
423+
inputs_formatted = _format_tensor_into_tuples(inputs)
424+
baselines_formatted: BaselineTupleType = None
421425
if baselines is not None:
422-
baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
426+
baselines_formatted = _format_baseline(baselines, inputs_formatted)
423427
additional_forward_args = _format_additional_forward_args(additional_forward_args)
424-
attributions = _format_tensor_into_tuples(attributions) # type: ignore
428+
attributions_formatted = _format_tensor_into_tuples(attributions)
425429

426430
# Make sure that inputs and corresponding attributions have matching sizes.
427-
assert len(inputs) == len(attributions), (
428-
"""The number of tensors in the inputs and
429-
attributions must match. Found number of tensors in the inputs is: {} and in the
430-
attributions: {}"""
431-
).format(len(inputs), len(attributions))
432-
for inp, attr in zip(inputs, attributions):
431+
assert len(inputs_formatted) == len(attributions_formatted), (
432+
"The number of tensors in the inputs and attributions must match. "
433+
f"Found number of tensors in the inputs is: {len(inputs_formatted)} and in "
434+
f"the attributions: {len(attributions_formatted)}"
435+
)
436+
for inp, attr in zip(inputs_formatted, attributions_formatted):
433437
assert inp.shape == attr.shape, (
434-
"""Inputs and attributions must have
435-
matching shapes. One of the input tensor's shape is {} and the
436-
attribution tensor's shape is: {}"""
437-
# pyre-fixme[16]: Module `attr` has no attribute `shape`.
438-
).format(inp.shape, attr.shape)
438+
"Inputs and attributions must have matching shapes. "
439+
f"One of the input tensor's shape is {inp.shape} and the "
440+
f"attribution tensor's shape is: {attr.shape}"
441+
)
439442

440-
bsz = inputs[0].size(0)
443+
bsz = inputs_formatted[0].size(0)
441444

442445
_next_infidelity_tensors = _make_next_infidelity_tensors_func(
443446
forward_func,
444447
bsz,
445-
# error: Argument 3 to "_make_next_infidelity_tensors_func" has incompatible
446-
# type "Callable[..., tuple[Tensor, Tensor]]"; expected
447-
# "Callable[..., tuple[tuple[Tensor, ...], tuple[Tensor, ...]]]" [arg-type]
448-
perturb_func, # type: ignore
449-
inputs,
450-
baselines,
451-
attributions,
448+
perturb_func,
449+
inputs_formatted,
450+
baselines_formatted,
451+
attributions_formatted,
452452
additional_forward_args,
453453
target,
454454
normalize,
@@ -458,7 +458,7 @@ def infidelity(
458458
# if not normalize, directly return aggrgated MSE ((a-b)^2,)
459459
# else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
460460
agg_tensors = _divide_and_aggregate_metrics(
461-
cast(Tuple[Tensor, ...], inputs),
461+
inputs_formatted,
462462
n_perturb_samples,
463463
_next_infidelity_tensors,
464464
agg_func=_sum_infidelity_tensors,
@@ -472,11 +472,7 @@ def infidelity(
472472
beta = safe_div(beta_num, beta_denorm)
473473

474474
infidelity_values = (
475-
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
476-
# `int`.
477-
beta**2 * agg_tensors[0]
478-
- 2 * beta * agg_tensors[1]
479-
+ agg_tensors[2]
475+
beta * beta * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]
480476
)
481477
else:
482478
infidelity_values = agg_tensors[0]
@@ -491,8 +487,8 @@ def _generate_perturbations(
491487
perturb_func: Callable[
492488
..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
493489
],
494-
inputs: TensorOrTupleOfTensorsGeneric,
495-
baselines: BaselineType,
490+
inputs: Tuple[Tensor, ...],
491+
baselines: BaselineTupleType,
496492
) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
497493
r"""
498494
The perturbations are generated for each example
@@ -507,14 +503,12 @@ def call_perturb_func() -> (
507503
Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
508504
):
509505
r""" """
510-
baselines_pert = None
506+
baselines_pert: BaselineType = None
511507
inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
512508
if len(inputs_expanded) == 1:
513509
inputs_pert = inputs_expanded[0]
514510
if baselines_expanded is not None:
515-
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
516-
# parameter.
517-
baselines_pert = cast(Tuple, baselines_expanded)[0]
511+
baselines_pert = baselines_expanded[0]
518512
else:
519513
inputs_pert = inputs_expanded
520514
baselines_pert = baselines_expanded
@@ -539,9 +533,7 @@ def call_perturb_func() -> (
539533
and baseline.shape[0] > 1
540534
else baseline
541535
)
542-
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
543-
# parameter.
544-
for input, baseline in zip(inputs, cast(Tuple, baselines))
536+
for input, baseline in zip(inputs, baselines)
545537
)
546538

547539
return call_perturb_func()
@@ -554,34 +546,32 @@ def _validate_inputs_and_perturbations(
554546
) -> None:
555547
# asserts the sizes of the perturbations and inputs
556548
assert len(perturbations) == len(inputs), (
557-
"""The number of perturbed
558-
inputs and corresponding perturbations must have the same number of
559-
elements. Found number of inputs is: {} and perturbations:
560-
{}"""
561-
).format(len(perturbations), len(inputs))
549+
"The number of perturbed "
550+
"inputs and corresponding perturbations must have the same number of "
551+
f"elements. Found number of inputs is: {len(perturbations)} and "
552+
f"perturbations: {len(inputs)}"
553+
)
562554

563555
# asserts the shapes of the perturbations and perturbed inputs
564556
for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
565557
assert perturb[0].shape == input_perturbed[0].shape, (
566-
"""Perturbed input
567-
and corresponding perturbation must have the same shape and
568-
dimensionality. Found perturbation shape is: {} and the input shape
569-
is: {}"""
570-
).format(perturb[0].shape, input_perturbed[0].shape)
558+
"Perturbed input "
559+
"and corresponding perturbation must have the same shape and "
560+
f"dimensionality. Found perturbation shape is: {perturb[0].shape} "
561+
f"and the input shape is: {input_perturbed[0].shape}"
562+
)
571563

572564

573565
def _make_next_infidelity_tensors_func(
574-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
575-
forward_func: Callable,
566+
forward_func: Callable[..., Tensor],
576567
bsz: int,
577568
perturb_func: Callable[
578569
..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
579570
],
580-
inputs: TensorOrTupleOfTensorsGeneric,
581-
baselines: BaselineType,
582-
attributions: TensorOrTupleOfTensorsGeneric,
583-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
584-
additional_forward_args: Any = None,
571+
inputs: Tuple[Tensor, ...],
572+
baselines: BaselineTupleType,
573+
attributions: Tuple[Tensor, ...],
574+
additional_forward_args: object = None,
585575
target: TargetType = None,
586576
normalize: bool = False,
587577
) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]:
@@ -597,7 +587,7 @@ def _next_infidelity_tensors(
597587
inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)
598588

599589
_validate_inputs_and_perturbations(
600-
cast(Tuple[Tensor, ...], inputs),
590+
inputs,
601591
inputs_perturbed_formatted,
602592
perturbations_formatted,
603593
)
@@ -666,7 +656,7 @@ def _next_infidelity_tensors(
666656
return _next_infidelity_tensors
667657

668658

669-
# pyre-fixme[3]: Return type must be annotated.
670-
# pyre-fixme[2]: Parameter must be annotated.
671-
def _sum_infidelity_tensors(agg_tensors, tensors):
659+
def _sum_infidelity_tensors(
660+
agg_tensors: Tuple[Tensor, ...], tensors: Tuple[Tensor, ...]
661+
) -> Tuple[Tensor, ...]:
672662
return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))

0 commit comments

Comments
 (0)