@@ -321,8 +321,9 @@ def attribute(
321321 additional_forward_args ,
322322 )
323323
324- # pyre-fixme[24]: Generic type `Sequence` expects 1 type parameter.
325- def chunk_output_fn (out : TensorOrTupleOfTensorsGeneric ) -> Sequence :
324+ def chunk_output_fn (
325+ out : TensorOrTupleOfTensorsGeneric ,
326+ ) -> Sequence [Union [Tensor , Sequence [Tensor ]]]:
326327 if isinstance (out , Tensor ):
327328 return out .chunk (2 )
328329 return tuple (out_sub .chunk (2 ) for out_sub in out )
@@ -434,8 +435,6 @@ def __init__(
434435
435436 # Ignoring mypy error for inconsistent signature with DeepLiftShap
436437 @typing .overload # type: ignore
437- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
438- # arguments of overload defined on line `453`.
439438 def attribute (
440439 self ,
441440 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -450,9 +449,7 @@ def attribute(
450449 custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
451450 ) -> Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ]: ...
452451
453- @typing .overload
454- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
455- # arguments of overload defined on line `439`.
452+ @typing .overload # type: ignore
456453 def attribute (
457454 self ,
458455 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -654,7 +651,7 @@ def attribute(
654651 ) = DeepLiftShap ._expand_inputs_baselines_targets (
655652 self , baselines , inputs , target , additional_forward_args
656653 )
657- attributions = LayerDeepLift .attribute .__wrapped__ ( # type: ignore
654+ attribs_layer_deeplift = LayerDeepLift .attribute .__wrapped__ ( # type: ignore
658655 self ,
659656 exp_inp ,
660657 exp_base ,
@@ -667,8 +664,12 @@ def attribute(
667664 attribute_to_layer_input = attribute_to_layer_input ,
668665 custom_attribution_func = custom_attribution_func ,
669666 )
667+ delta : Tensor
668+ attributions : Union [Tensor , Tuple [Tensor , ...]]
670669 if return_convergence_delta :
671- attributions , delta = attributions
670+ attributions , delta = attribs_layer_deeplift
671+ else :
672+ attributions = attribs_layer_deeplift
672673 if isinstance (attributions , tuple ):
673674 attributions = tuple (
674675 DeepLiftShap ._compute_mean_across_baselines (
@@ -681,15 +682,17 @@ def attribute(
681682 self , inp_bsz , base_bsz , attributions
682683 )
683684 if return_convergence_delta :
684- # pyre-fixme[61]: `delta` is undefined, or not always defined.
685685 return attributions , delta
686686 else :
687- # pyre-fixme[7]: Expected `Union[Tuple[Union[Tensor,
688- # typing.Tuple[Tensor, ...]], Tensor], Tensor, typing.Tuple[Tensor, ...]]`
689- # but got `Union[tuple[Tensor], Tensor]`.
690- return attributions
687+ return cast (
688+ Union [
689+ Tensor ,
690+ Tuple [Tensor , ...],
691+ Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ],
692+ ],
693+ attributions ,
694+ )
691695
692696 @property
693- # pyre-fixme[3]: Return type must be annotated.
694697 def multiplies_by_inputs (self ) -> bool :
695698 return self ._multiply_by_inputs
0 commit comments