22
33# pyre-strict
44import typing
5- from typing import Any , Callable , cast , Dict , Optional , Sequence , Tuple , Union
5+ from typing import Any , Callable , cast , Dict , Literal , Optional , Sequence , Tuple , Union
66
77import torch
88from captum ._utils .common import (
1313 ExpansionTypes ,
1414)
1515from captum ._utils .gradient import compute_layer_gradients_and_eval
16- from captum ._utils .typing import (
17- BaselineType ,
18- Literal ,
19- TargetType ,
20- TensorOrTupleOfTensorsGeneric ,
21- )
16+ from captum ._utils .typing import BaselineType , TargetType , TensorOrTupleOfTensorsGeneric
2217from captum .attr ._core .deep_lift import DeepLift , DeepLiftShap
2318from captum .attr ._utils .attribution import LayerAttribution
2419from captum .attr ._utils .common import (
@@ -101,8 +96,6 @@ def __init__(
10196
10297 # Ignoring mypy error for inconsistent signature with DeepLift
10398 @typing .overload # type: ignore
104- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
105- # arguments of overload defined on line `117`.
10699 def attribute (
107100 self ,
108101 inputs : Union [Tensor , Tuple [Tensor , ...]],
@@ -111,27 +104,20 @@ def attribute(
111104 # pyre-fixme[2]: Parameter annotation cannot be `Any`.
112105 additional_forward_args : Any = None ,
113106 * ,
114- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
115- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
116107 return_convergence_delta : Literal [True ],
117108 attribute_to_layer_input : bool = False ,
118109 custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
119110 grad_kwargs : Optional [Dict [str , Any ]] = None ,
120111 ) -> Tuple [Union [Tensor , Tuple [Tensor , ...]], Tensor ]: ...
121112
122113 @typing .overload
123- # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
124- # arguments of overload defined on line `104`.
125114 def attribute (
126115 self ,
127116 inputs : Union [Tensor , Tuple [Tensor , ...]],
128117 baselines : BaselineType = None ,
129118 target : TargetType = None ,
130119 # pyre-fixme[2]: Parameter annotation cannot be `Any`.
131120 additional_forward_args : Any = None ,
132- # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
133- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
134- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
135121 return_convergence_delta : Literal [False ] = False ,
136122 attribute_to_layer_input : bool = False ,
137123 custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
@@ -382,8 +368,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
382368 inputs ,
383369 additional_forward_args ,
384370 target ,
385- # pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
386- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
387371 cast (Union [Literal [True ], Literal [False ]], len (attributions ) > 1 ),
388372 )
389373
@@ -464,8 +448,6 @@ def attribute(
464448 # pyre-fixme[2]: Parameter annotation cannot be `Any`.
465449 additional_forward_args : Any = None ,
466450 * ,
467- # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
468- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
469451 return_convergence_delta : Literal [True ],
470452 attribute_to_layer_input : bool = False ,
471453 custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
@@ -483,9 +465,6 @@ def attribute(
483465 target : TargetType = None ,
484466 # pyre-fixme[2]: Parameter annotation cannot be `Any`.
485467 additional_forward_args : Any = None ,
486- # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
487- # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
488- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
489468 return_convergence_delta : Literal [False ] = False ,
490469 attribute_to_layer_input : bool = False ,
491470 custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
@@ -686,10 +665,6 @@ def attribute(
686665 target = exp_target ,
687666 additional_forward_args = exp_addit_args ,
688667 return_convergence_delta = cast (
689- # pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid
690- # type.
691- # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take
692- # parameters.
693668 Literal [True , False ],
694669 return_convergence_delta ,
695670 ),
0 commit comments