22
33# pyre-strict
44import typing
5- from typing import Any , Callable , cast , Dict , Literal , Optional , Sequence , Tuple , Union
5+ from typing import (
6+ Any ,
7+ Callable ,
8+ cast ,
9+ Dict ,
10+ List ,
11+ Literal ,
12+ Optional ,
13+ Sequence ,
14+ Tuple ,
15+ Union ,
16+ )
617
718import torch
819from captum ._utils .common import (
@@ -96,6 +107,7 @@ def __init__(
96107
97108 # Ignoring mypy error for inconsistent signature with DeepLift
98109 @typing .overload # type: ignore
110+ @log_usage ()
99111 def attribute (
100112 self ,
101113 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -110,6 +122,7 @@ def attribute(
110122 ) -> Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ]: ...
111123
112124 @typing .overload
125+ @log_usage ()
113126 def attribute (
114127 self ,
115128 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -123,8 +136,6 @@ def attribute(
123136 ) -> Union [Tensor , Tuple [Tensor , ...]]: ...
124137
125138 @log_usage ()
126- # pyre-fixme[43]: This definition does not have the same decorators as the
127- # preceding overload(s).
128139 def attribute (
129140 self ,
130141 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -321,8 +332,9 @@ def attribute(
321332 additional_forward_args ,
322333 )
323334
324- # pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter.
325- def chunk_output_fn (out : TensorOrTupleOfTensorsGeneric ) -> Sequence :
335+ def chunk_output_fn (
336+ out : TensorOrTupleOfTensorsGeneric ,
337+ ) -> Sequence [Union [Tensor , List [Tensor ]]]:
326338 if isinstance (out , Tensor ):
327339 return out .chunk (2 )
328340 return tuple (out_sub .chunk (2 ) for out_sub in out )
@@ -434,8 +446,7 @@ def __init__(
434446
435447 # Ignoring mypy error for inconsistent signature with DeepLiftShap
436448 @typing .overload # type: ignore
437- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
438- # arguments of overload defined on line `453`.
449+ @log_usage ()
439450 def attribute (
440451 self ,
441452 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -451,8 +462,7 @@ def attribute(
451462 ) -> Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ]: ...
452463
453464 @typing .overload
454- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
455- # arguments of overload defined on line `439`.
465+ @log_usage ()
456466 def attribute (
457467 self ,
458468 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -467,8 +477,6 @@ def attribute(
467477 ) -> Union [Tensor , Tuple [Tensor , ...]]: ...
468478
469479 @log_usage ()
470- # pyre-fixme[43]: This definition does not have the same decorators as the
471- # preceding overload(s).
472480 def attribute (
473481 self ,
474482 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -654,7 +662,7 @@ def attribute(
654662 ) = DeepLiftShap ._expand_inputs_baselines_targets (
655663 self , baselines , inputs , target , additional_forward_args
656664 )
657- attributions = LayerDeepLift .attribute .__wrapped__ ( # type: ignore
665+ attribs_layer_deeplift = LayerDeepLift .attribute .__wrapped__ ( # type: ignore
658666 self ,
659667 exp_inp ,
660668 exp_base ,
@@ -667,8 +675,12 @@ def attribute(
667675 attribute_to_layer_input = attribute_to_layer_input ,
668676 custom_attribution_func = custom_attribution_func ,
669677 )
678+ delta : Tensor
679+ attributions : Union [Tensor , Tuple [Tensor , ...]]
670680 if return_convergence_delta :
671- attributions , delta = attributions
681+ attributions , delta = attribs_layer_deeplift
682+ else :
683+ attributions = attribs_layer_deeplift
672684 if isinstance (attributions , tuple ):
673685 attributions = tuple (
674686 DeepLiftShap ._compute_mean_across_baselines (
@@ -681,15 +693,17 @@ def attribute(
681693 self , inp_bsz , base_bsz , attributions
682694 )
683695 if return_convergence_delta :
684- # pyre-fixme[61]: `delta` is undefined, or not always defined.
685696 return attributions , delta
686697 else :
687- # pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor,
688- # typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]`
689- # but got `Union[tuple[Tensor], Tensor]`.
690- return attributions
698+ return cast (
699+ Union [
700+ Tensor ,
701+ Tuple [Tensor , ...],
702+ Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ],
703+ ],
704+ attributions ,
705+ )
691706
692707 @property
693- # pyre-fixme[3]: Return type must be annotated.
694708 def multiplies_by_inputs (self ) -> bool :
695709 return self ._multiply_by_inputs
0 commit comments