     ExpansionTypes,
     safe_div,
 )
-from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineTupleType,
+    BaselineType,
+    TargetType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.log import log_usage
 from captum.metrics._utils.batching import _divide_and_aggregate_metrics
 from torch import Tensor
@@ -35,14 +40,14 @@ def infidelity_perturb_func_decorator(
 ]:
     r"""An auxiliary, decorator function that helps with computing
     perturbations given perturbed inputs. It can be useful for cases
-    when `pertub_func` returns only perturbed inputs and we
+    when `perturb_func` returns only perturbed inputs and we
     internally compute the perturbations as
     (input - perturbed_input) / (input - baseline) if
     multiply_by_inputs is set to True and
     (input - perturbed_input) otherwise.

-    If users decorate their `pertub_func` with
-    `@infidelity_perturb_func_decorator` function then their `pertub_func`
+    If users decorate their `perturb_func` with
+    `@infidelity_perturb_func_decorator` function then their `perturb_func`
     needs to only return perturbed inputs.

     Args:
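
For illustration, a minimal sketch of the usage pattern this docstring describes, assuming a single-tensor input and an illustrative Gaussian-noise perturbation (the function name and noise scale below are not part of this diff):

    import torch
    from captum.metrics import infidelity_perturb_func_decorator

    @infidelity_perturb_func_decorator(multiply_by_inputs=False)
    def perturb_fn(inputs):
        # Return only the perturbed inputs; the decorator derives the
        # perturbations as (inputs - perturbed_inputs) internally.
        noise = 0.01 * torch.randn_like(inputs)
        return inputs - noise
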
@@ -54,15 +59,15 @@ def infidelity_perturb_func_decorator(
5459 """
5560
5661 def sub_infidelity_perturb_func_decorator (
57- pertub_func : Callable [..., TensorOrTupleOfTensorsGeneric ]
62+ perturb_func : Callable [..., TensorOrTupleOfTensorsGeneric ]
5863 ) -> Callable [
5964 [TensorOrTupleOfTensorsGeneric , BaselineType ],
6065 Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]],
6166 ]:
6267 r"""
6368 Args:
6469
65- pertub_func (Callable): Input perturbation function that takes inputs
70+ perturb_func (Callable): Input perturbation function that takes inputs
6671 and optionally baselines and returns perturbed inputs
6772
6873 Returns:
@@ -87,9 +92,9 @@ def default_perturb_func(
         ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
             r""" """
             inputs_perturbed: TensorOrTupleOfTensorsGeneric = (
-                pertub_func(inputs, baselines)
+                perturb_func(inputs, baselines)
                 if baselines is not None
-                else pertub_func(inputs)
+                else perturb_func(inputs)
             )
             inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)
             inputs_formatted = _format_tensor_into_tuples(inputs)
@@ -135,16 +140,14 @@ def default_perturb_func(

 @log_usage()
 def infidelity(
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    forward_func: Callable,
+    forward_func: Callable[..., Tensor],
     perturb_func: Callable[
         ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
     ],
     inputs: TensorOrTupleOfTensorsGeneric,
     attributions: TensorOrTupleOfTensorsGeneric,
     baselines: BaselineType = None,
-    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-    additional_forward_args: Any = None,
+    additional_forward_args: object = None,
     target: TargetType = None,
     n_perturb_samples: int = 10,
     max_examples_per_batch: Optional[int] = None,
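
As a note on the tightened `forward_func` annotation above, any callable returning a Tensor satisfies `Callable[..., Tensor]`. A minimal sketch (the toy model is illustrative only, not part of this diff):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2))

    def forward_func(x: torch.Tensor) -> torch.Tensor:
        # A plain nn.Module also works, since calling it returns a Tensor.
        return net(x)
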
@@ -417,38 +420,35 @@ def infidelity(
         >>> infid = infidelity(net, perturb_fn, input, attribution)
     """
     # perform argument formattings
-    inputs = _format_tensor_into_tuples(inputs)  # type: ignore
+    inputs_formatted = _format_tensor_into_tuples(inputs)
+    baselines_formatted: BaselineTupleType = None
     if baselines is not None:
-        baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
+        baselines_formatted = _format_baseline(baselines, inputs_formatted)
     additional_forward_args = _format_additional_forward_args(additional_forward_args)
-    attributions = _format_tensor_into_tuples(attributions)  # type: ignore
+    attributions_formatted = _format_tensor_into_tuples(attributions)

     # Make sure that inputs and corresponding attributions have matching sizes.
-    assert len(inputs) == len(attributions), (
-        """The number of tensors in the inputs and
-        attributions must match. Found number of tensors in the inputs is: {} and in the
-        attributions: {}"""
-    ).format(len(inputs), len(attributions))
-    for inp, attr in zip(inputs, attributions):
+    assert len(inputs_formatted) == len(attributions_formatted), (
+        "The number of tensors in the inputs and attributions must match. "
+        f"Found number of tensors in the inputs is: {len(inputs_formatted)} and in "
+        f"the attributions: {len(attributions_formatted)}"
+    )
+    for inp, attr in zip(inputs_formatted, attributions_formatted):
         assert inp.shape == attr.shape, (
-            """Inputs and attributions must have
-            matching shapes. One of the input tensor's shape is {} and the
-            attribution tensor's shape is: {}"""
-            # pyre-fixme[16]: Module `attr` has no attribute `shape`.
-        ).format(inp.shape, attr.shape)
+            "Inputs and attributions must have matching shapes. "
+            f"One of the input tensor's shape is {inp.shape} and the "
+            f"attribution tensor's shape is: {attr.shape}"
+        )

-    bsz = inputs[0].size(0)
+    bsz = inputs_formatted[0].size(0)

     _next_infidelity_tensors = _make_next_infidelity_tensors_func(
         forward_func,
         bsz,
-        # error: Argument 3 to "_make_next_infidelity_tensors_func" has incompatible
-        # type "Callable[..., tuple[Tensor, Tensor]]"; expected
-        # "Callable[..., tuple[tuple[Tensor, ...], tuple[Tensor, ...]]]" [arg-type]
-        perturb_func,  # type: ignore
-        inputs,
-        baselines,
-        attributions,
+        perturb_func,
+        inputs_formatted,
+        baselines_formatted,
+        attributions_formatted,
         additional_forward_args,
         target,
         normalize,
@@ -458,7 +458,7 @@ def infidelity(
     # if not normalize, directly return aggrgated MSE ((a-b)^2,)
     # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
     agg_tensors = _divide_and_aggregate_metrics(
-        cast(Tuple[Tensor, ...], inputs),
+        inputs_formatted,
         n_perturb_samples,
         _next_infidelity_tensors,
         agg_func=_sum_infidelity_tensors,
@@ -472,11 +472,7 @@ def infidelity(
         beta = safe_div(beta_num, beta_denorm)

         infidelity_values = (
-            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
-            # `int`.
-            beta ** 2 * agg_tensors[0]
-            - 2 * beta * agg_tensors[1]
-            + agg_tensors[2]
+            beta * beta * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]
         )
     else:
         infidelity_values = agg_tensors[0]
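
For context, when `normalize=True` the three aggregated tensors hold per-example sums of a*a, a*b and b*b (a being the attribution-times-perturbation term and b the change in model output), and `beta = sum(ab) / sum(aa)` is the least-squares optimal scaling, so the expression above expands `(beta*a - b)^2`. A minimal standalone sketch of that step, assuming `agg_tensors` already holds those three tensors:

    from captum._utils.common import safe_div

    def normalized_infidelity(agg_tensors):
        # agg_tensors = (sum of a*a, sum of a*b, sum of b*b) per example
        aa, ab, bb = agg_tensors
        beta = safe_div(ab, aa)  # optimal scaling of the attribution term
        return beta * beta * aa - 2 * beta * ab + bb
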
@@ -491,8 +487,8 @@ def _generate_perturbations(
     perturb_func: Callable[
         ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
     ],
-    inputs: TensorOrTupleOfTensorsGeneric,
-    baselines: BaselineType,
+    inputs: Tuple[Tensor, ...],
+    baselines: BaselineTupleType,
 ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
     r"""
     The perturbations are generated for each example
@@ -507,14 +503,12 @@ def call_perturb_func() -> (
         Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
     ):
         r""" """
-        baselines_pert = None
+        baselines_pert: BaselineType = None
         inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
         if len(inputs_expanded) == 1:
             inputs_pert = inputs_expanded[0]
             if baselines_expanded is not None:
-                # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
-                # parameter.
-                baselines_pert = cast(Tuple, baselines_expanded)[0]
+                baselines_pert = baselines_expanded[0]
         else:
             inputs_pert = inputs_expanded
             baselines_pert = baselines_expanded
@@ -539,9 +533,7 @@ def call_perturb_func() -> (
             and baseline.shape[0] > 1
             else baseline
         )
-        # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
-        # parameter.
-        for input, baseline in zip(inputs, cast(Tuple, baselines))
+        for input, baseline in zip(inputs, baselines)
     )

     return call_perturb_func()
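
For context, a `perturb_func` that is not wrapped by the decorator is expected to return both the perturbations and the perturbed inputs, which is what `call_perturb_func` forwards here. A minimal sketch of such a function (the noise scale is illustrative):

    import torch

    def perturb_fn(inputs):
        # Returns (perturbations, perturbed_inputs), as expected by infidelity.
        noise = 0.003 * torch.randn_like(inputs)
        return noise, inputs - noise
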
@@ -554,34 +546,32 @@ def _validate_inputs_and_perturbations(
 ) -> None:
     # asserts the sizes of the perturbations and inputs
     assert len(perturbations) == len(inputs), (
-        """The number of perturbed
-        inputs and corresponding perturbations must have the same number of
-        elements. Found number of inputs is: {} and perturbations:
-        {}"""
-    ).format(len(perturbations), len(inputs))
+        "The number of perturbed "
+        "inputs and corresponding perturbations must have the same number of "
+        f"elements. Found number of inputs is: {len(perturbations)} and "
+        f"perturbations: {len(inputs)}"
+    )

     # asserts the shapes of the perturbations and perturbed inputs
     for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
         assert perturb[0].shape == input_perturbed[0].shape, (
-            """Perturbed input
-            and corresponding perturbation must have the same shape and
-            dimensionality. Found perturbation shape is: {} and the input shape
-            is: {}"""
-        ).format(perturb[0].shape, input_perturbed[0].shape)
+            "Perturbed input "
+            "and corresponding perturbation must have the same shape and "
+            f"dimensionality. Found perturbation shape is: {perturb[0].shape} "
+            f"and the input shape is: {input_perturbed[0].shape}"
+        )


 def _make_next_infidelity_tensors_func(
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    forward_func: Callable,
+    forward_func: Callable[..., Tensor],
     bsz: int,
     perturb_func: Callable[
         ..., Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
     ],
-    inputs: TensorOrTupleOfTensorsGeneric,
-    baselines: BaselineType,
-    attributions: TensorOrTupleOfTensorsGeneric,
-    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-    additional_forward_args: Any = None,
+    inputs: Tuple[Tensor, ...],
+    baselines: BaselineTupleType,
+    attributions: Tuple[Tensor, ...],
+    additional_forward_args: object = None,
     target: TargetType = None,
     normalize: bool = False,
 ) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]:
@@ -597,7 +587,7 @@ def _next_infidelity_tensors(
         inputs_perturbed_formatted = _format_tensor_into_tuples(inputs_perturbed)

         _validate_inputs_and_perturbations(
-            cast(Tuple[Tensor, ...], inputs),
+            inputs,
             inputs_perturbed_formatted,
             perturbations_formatted,
         )
@@ -666,7 +656,7 @@ def _next_infidelity_tensors(
     return _next_infidelity_tensors


-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
-def _sum_infidelity_tensors(agg_tensors, tensors):
+def _sum_infidelity_tensors(
+    agg_tensors: Tuple[Tensor, ...], tensors: Tuple[Tensor, ...]
+) -> Tuple[Tensor, ...]:
     return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))
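
For example, the aggregation is just an element-wise tuple sum, which is how `_divide_and_aggregate_metrics` accumulates the per-batch metric tensors:

    import torch

    a = (torch.ones(2), torch.zeros(2))
    b = (torch.full((2,), 2.0), torch.ones(2))
    _sum_infidelity_tensors(a, b)  # -> (tensor([3., 3.]), tensor([1., 1.]))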