Skip to content

Commit fcc9d26

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix layer deeplift pyre fixme issues (#1470)
Summary: Pull Request resolved: #1470 Fixing unresolved pyre fixme issues in corresponding file Reviewed By: cyrjano Differential Revision: D67705583 fbshipit-source-id: ac80a67112f2abdb5356ee1e01ee1ec26afd2bf9
1 parent 5e7a7e7 commit fcc9d26

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,9 @@ def attribute(
321321
additional_forward_args,
322322
)
323323

324-
# pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter.
325-
def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
324+
def chunk_output_fn(
325+
out: TensorOrTupleOfTensorsGeneric,
326+
) -> Sequence[Union[Tensor, Sequence[Tensor]]]:
326327
if isinstance(out, Tensor):
327328
return out.chunk(2)
328329
return tuple(out_sub.chunk(2) for out_sub in out)
@@ -434,8 +435,6 @@ def __init__(
434435

435436
# Ignoring mypy error for inconsistent signature with DeepLiftShap
436437
@typing.overload # type: ignore
437-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
438-
# arguments of overload defined on line `453`.
439438
def attribute(
440439
self,
441440
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -450,9 +449,7 @@ def attribute(
450449
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
451450
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
452451

453-
@typing.overload
454-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
455-
# arguments of overload defined on line `439`.
452+
@typing.overload # type: ignore
456453
def attribute(
457454
self,
458455
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -654,7 +651,7 @@ def attribute(
654651
) = DeepLiftShap._expand_inputs_baselines_targets(
655652
self, baselines, inputs, target, additional_forward_args
656653
)
657-
attributions = LayerDeepLift.attribute.__wrapped__( # type: ignore
654+
attribs_layer_deeplift = LayerDeepLift.attribute.__wrapped__( # type: ignore
658655
self,
659656
exp_inp,
660657
exp_base,
@@ -667,8 +664,12 @@ def attribute(
667664
attribute_to_layer_input=attribute_to_layer_input,
668665
custom_attribution_func=custom_attribution_func,
669666
)
667+
delta: Tensor
668+
attributions: Union[Tensor, Tuple[Tensor, ...]]
670669
if return_convergence_delta:
671-
attributions, delta = attributions
670+
attributions, delta = attribs_layer_deeplift
671+
else:
672+
attributions = attribs_layer_deeplift
672673
if isinstance(attributions, tuple):
673674
attributions = tuple(
674675
DeepLiftShap._compute_mean_across_baselines(
@@ -681,15 +682,17 @@ def attribute(
681682
self, inp_bsz, base_bsz, attributions
682683
)
683684
if return_convergence_delta:
684-
# pyre-fixme[61]: `delta` is undefined, or not always defined.
685685
return attributions, delta
686686
else:
687-
# pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor,
688-
# typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]`
689-
# but got `Union[tuple[Tensor], Tensor]`.
690-
return attributions
687+
return cast(
688+
Union[
689+
Tensor,
690+
Tuple[Tensor, ...],
691+
Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor],
692+
],
693+
attributions,
694+
)
691695

692696
@property
693-
# pyre-fixme[3]: Return type must be annotated.
694697
def multiplies_by_inputs(self) -> bool:
695698
return self._multiply_by_inputs

0 commit comments

Comments
 (0)