Skip to content

Commit fe0141b

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix layer deeplift pyre fixme issues
Differential Revision: D67705583
1 parent 80f0832 commit fe0141b

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@
22

33
# pyre-strict
44
import typing
5-
from typing import Any, Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union
5+
from typing import (
6+
Any,
7+
Callable,
8+
cast,
9+
Dict,
10+
List,
11+
Literal,
12+
Optional,
13+
Sequence,
14+
Tuple,
15+
Union,
16+
)
617

718
import torch
819
from captum._utils.common import (
@@ -96,6 +107,7 @@ def __init__(
96107

97108
# Ignoring mypy error for inconsistent signature with DeepLift
98109
@typing.overload # type: ignore
110+
@log_usage()
99111
def attribute(
100112
self,
101113
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -110,6 +122,7 @@ def attribute(
110122
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
111123

112124
@typing.overload
125+
@log_usage()
113126
def attribute(
114127
self,
115128
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -123,8 +136,6 @@ def attribute(
123136
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
124137

125138
@log_usage()
126-
# pyre-fixme[43]: This definition does not have the same decorators as the
127-
# preceding overload(s).
128139
def attribute(
129140
self,
130141
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -321,8 +332,9 @@ def attribute(
321332
additional_forward_args,
322333
)
323334

324-
# pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter.
325-
def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
335+
def chunk_output_fn(
336+
out: TensorOrTupleOfTensorsGeneric,
337+
) -> Sequence[Union[Tensor, List[Tensor]]]:
326338
if isinstance(out, Tensor):
327339
return out.chunk(2)
328340
return tuple(out_sub.chunk(2) for out_sub in out)
@@ -434,8 +446,7 @@ def __init__(
434446

435447
# Ignoring mypy error for inconsistent signature with DeepLiftShap
436448
@typing.overload # type: ignore
437-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
438-
# arguments of overload defined on line `453`.
449+
@log_usage()
439450
def attribute(
440451
self,
441452
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -451,8 +462,7 @@ def attribute(
451462
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
452463

453464
@typing.overload
454-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
455-
# arguments of overload defined on line `439`.
465+
@log_usage()
456466
def attribute(
457467
self,
458468
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -467,8 +477,6 @@ def attribute(
467477
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
468478

469479
@log_usage()
470-
# pyre-fixme[43]: This definition does not have the same decorators as the
471-
# preceding overload(s).
472480
def attribute(
473481
self,
474482
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -654,7 +662,7 @@ def attribute(
654662
) = DeepLiftShap._expand_inputs_baselines_targets(
655663
self, baselines, inputs, target, additional_forward_args
656664
)
657-
attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore
665+
attribs_layer_deeplift = LayerDeepLift.attribute.__wrapped__( # type: ignore
658666
self,
659667
exp_inp,
660668
exp_base,
@@ -667,8 +675,12 @@ def attribute(
667675
attribute_to_layer_input=attribute_to_layer_input,
668676
custom_attribution_func=custom_attribution_func,
669677
)
678+
delta: Tensor
679+
attributions: Union[Tensor, Tuple[Tensor, ...]]
670680
if return_convergence_delta:
671-
attributions, delta = attributions
681+
attributions, delta = attribs_layer_deeplift
682+
else:
683+
attributions = attribs_layer_deeplift
672684
if isinstance(attributions, tuple):
673685
attributions = tuple(
674686
DeepLiftShap._compute_mean_across_baselines(
@@ -681,15 +693,17 @@ def attribute(
681693
self, inp_bsz, base_bsz, attributions
682694
)
683695
if return_convergence_delta:
684-
# pyre-fixme[61]: `delta` is undefined, or not always defined.
685696
return attributions, delta
686697
else:
687-
# pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor,
688-
# typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]`
689-
# but got `Union[tuple[Tensor], Tensor]`.
690-
return attributions
698+
return cast(
699+
Union[
700+
Tensor,
701+
Tuple[Tensor, ...],
702+
Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor],
703+
],
704+
attributions,
705+
)
691706

692707
@property
693-
# pyre-fixme[3]: Return type must be annotated.
694708
def multiplies_by_inputs(self) -> bool:
695709
return self._multiply_by_inputs

0 commit comments

Comments
 (0)