Skip to content

Commit e5d6d6b

Browse files
Zach Carmichael and facebook-github-bot
authored and committed
Address instances of "Overloaded function signature x will never be matched" + minor typing fixes
Summary: Many overloads produced false positives or required changing order due to mypy breaking ties by picking the first matching variant (https://mypy.readthedocs.io/en/stable/more_types.html). This fixes or suppresses these errors. Created T204932142 to address Literal-related issues. Differential Revision: D64517613
1 parent a510bf6 commit e5d6d6b

File tree

13 files changed

+150
-137
lines changed

13 files changed

+150
-137
lines changed

captum/_utils/common.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ def safe_div(
7373
@typing.overload
7474
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
7575
# is incompatible with the return type of the implementation (`bool`).
76-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
76+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
7777
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
78-
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
78+
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
7979

8080

8181
@typing.overload
8282
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
8383
# is incompatible with the return type of the implementation (`bool`).
84-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
84+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
8585
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
86-
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
86+
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
8787

8888

8989
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
@@ -277,7 +277,7 @@ def _format_additional_forward_args(
277277

278278

279279
@overload
280-
def _format_additional_forward_args(
280+
def _format_additional_forward_args( # type: ignore
281281
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
282282
additional_forward_args: Any,
283283
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
@@ -780,10 +780,10 @@ def _reduce_list(
780780
"""
781781
assert len(val_list) > 0, "Cannot reduce empty list!"
782782
if isinstance(val_list[0], torch.Tensor):
783-
# pyre-fixme[16]: `bool` has no attribute `device`.
784-
first_device = val_list[0].device
785-
# pyre-fixme[16]: `bool` has no attribute `to`.
786-
return red_func([elem.to(first_device) for elem in val_list])
783+
first_device = cast(Tensor, val_list[0]).device
784+
return red_func(
785+
[elem.to(first_device) for elem in cast(List[Tensor], val_list)]
786+
)
787787
elif isinstance(val_list[0], bool):
788788
# pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
789789
return any(val_list)

captum/_utils/gradient.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -159,33 +159,34 @@ def _neuron_gradients(
159159

160160
@typing.overload
161161
# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all
162-
# possible arguments of overload defined on line `158`.
162+
# possible arguments of overload defined on line `170`.
163163
def _forward_layer_eval(
164164
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
165165
forward_fn: Callable,
166166
inputs: Union[Tensor, Tuple[Tensor, ...]],
167-
layer: Module,
167+
layer: List[Module],
168168
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
169169
additional_forward_args: Any = None,
170170
device_ids: Union[None, List[int]] = None,
171171
attribute_to_layer_input: bool = False,
172172
grad_enabled: bool = False,
173-
) -> Tuple[Tensor, ...]: ...
173+
) -> List[Tuple[Tensor, ...]]: ...
174174

175175

176176
@typing.overload
177177
# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all
178-
# possible arguments of overload defined on line `170`.
178+
# possible arguments of overload defined on line `158`.
179179
def _forward_layer_eval(
180180
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
181181
forward_fn: Callable,
182182
inputs: Union[Tensor, Tuple[Tensor, ...]],
183-
layer: List[Module],
183+
layer: Module,
184+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
184185
additional_forward_args: Any = None,
185186
device_ids: Union[None, List[int]] = None,
186187
attribute_to_layer_input: bool = False,
187188
grad_enabled: bool = False,
188-
) -> List[Tuple[Tensor, ...]]: ...
189+
) -> Tuple[Tensor, ...]: ...
189190

190191

191192
def _forward_layer_eval(
@@ -434,34 +435,34 @@ def _forward_layer_eval_with_neuron_grads(
434435

435436
@typing.overload
436437
# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does
437-
# not accept all possible arguments of overload defined on line `392`.
438+
# not accept all possible arguments of overload defined on line `405`.
438439
def _forward_layer_eval_with_neuron_grads(
439440
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
440441
forward_fn: Callable,
441442
inputs: Union[Tensor, Tuple[Tensor, ...]],
442-
layer: Module,
443+
layer: List[Module],
443444
additional_forward_args: Any = None,
444445
gradient_neuron_selector: None = None,
445446
grad_enabled: bool = False,
446447
device_ids: Union[None, List[int]] = None,
447448
attribute_to_layer_input: bool = False,
448-
) -> Tuple[Tensor, ...]: ...
449+
) -> List[Tuple[Tensor, ...]]: ...
449450

450451

451452
@typing.overload
452453
# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does
453-
# not accept all possible arguments of overload defined on line `405`.
454+
# not accept all possible arguments of overload defined on line `392`.
454455
def _forward_layer_eval_with_neuron_grads(
455456
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
456457
forward_fn: Callable,
457458
inputs: Union[Tensor, Tuple[Tensor, ...]],
458-
layer: List[Module],
459+
layer: Module,
459460
additional_forward_args: Any = None,
460461
gradient_neuron_selector: None = None,
461462
grad_enabled: bool = False,
462463
device_ids: Union[None, List[int]] = None,
463464
attribute_to_layer_input: bool = False,
464-
) -> List[Tuple[Tensor, ...]]: ...
465+
) -> Tuple[Tensor, ...]: ...
465466

466467

467468
def _forward_layer_eval_with_neuron_grads(

captum/attr/_core/deep_lift.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -118,36 +118,36 @@ def __init__(
118118

119119
@typing.overload
120120
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
121-
# arguments of overload defined on line `120`.
121+
# arguments of overload defined on line `131`.
122122
def attribute(
123123
self,
124124
inputs: TensorOrTupleOfTensorsGeneric,
125125
baselines: BaselineType = None,
126126
target: TargetType = None,
127-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
128127
additional_forward_args: Any = None,
129-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
130-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
128+
*,
129+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
131130
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
132-
return_convergence_delta: Literal[False] = False,
131+
return_convergence_delta: Literal[True],
133132
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
134-
) -> TensorOrTupleOfTensorsGeneric: ...
133+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
135134

136135
@typing.overload
137136
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
138-
# arguments of overload defined on line `131`.
137+
# arguments of overload defined on line `120`.
139138
def attribute(
140139
self,
141140
inputs: TensorOrTupleOfTensorsGeneric,
142141
baselines: BaselineType = None,
143142
target: TargetType = None,
143+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
144144
additional_forward_args: Any = None,
145-
*,
146-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
145+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
146+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
147147
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
148-
return_convergence_delta: Literal[True],
148+
return_convergence_delta: Literal[False] = False,
149149
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
150-
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
150+
) -> TensorOrTupleOfTensorsGeneric: ...
151151

152152
@log_usage()
153153
def attribute( # type: ignore
@@ -636,7 +636,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
636636
# DeepLiftShap.attribute, so we ignore typing here
637637
@typing.overload # type: ignore
638638
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
639-
# arguments of overload defined on line `584`.
639+
# arguments of overload defined on line `597`.
640640
def attribute(
641641
self,
642642
inputs: TensorOrTupleOfTensorsGeneric,
@@ -646,30 +646,31 @@ def attribute(
646646
target: TargetType = None,
647647
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
648648
additional_forward_args: Any = None,
649-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
650-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
649+
*,
650+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
651651
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
652-
return_convergence_delta: Literal[False] = False,
652+
return_convergence_delta: Literal[True],
653653
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
654-
) -> TensorOrTupleOfTensorsGeneric: ...
654+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
655655

656656
@typing.overload
657657
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
658-
# arguments of overload defined on line `597`.
658+
# arguments of overload defined on line `584`.
659659
def attribute(
660660
self,
661661
inputs: TensorOrTupleOfTensorsGeneric,
662662
baselines: Union[
663663
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
664664
],
665665
target: TargetType = None,
666+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
666667
additional_forward_args: Any = None,
667-
*,
668-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
668+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
669+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
669670
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
670-
return_convergence_delta: Literal[True],
671+
return_convergence_delta: Literal[False] = False,
671672
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
672-
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
673+
) -> TensorOrTupleOfTensorsGeneric: ...
673674

674675
@log_usage()
675676
def attribute( # type: ignore

captum/attr/_core/integrated_gradients.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
# a tuple with both attributions and deltas.
8282
@typing.overload
8383
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
84-
# arguments of overload defined on line `82`.
84+
# arguments of overload defined on line `95`.
8585
def attribute(
8686
self,
8787
inputs: TensorOrTupleOfTensorsGeneric,
@@ -92,29 +92,30 @@ def attribute(
9292
n_steps: int = 50,
9393
method: str = "gausslegendre",
9494
internal_batch_size: Union[None, int] = None,
95-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
96-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
95+
*,
96+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
9797
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
98-
return_convergence_delta: Literal[False] = False,
99-
) -> TensorOrTupleOfTensorsGeneric: ...
98+
return_convergence_delta: Literal[True],
99+
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
100100

101101
@typing.overload
102102
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
103-
# arguments of overload defined on line `95`.
103+
# arguments of overload defined on line `82`.
104104
def attribute(
105105
self,
106106
inputs: TensorOrTupleOfTensorsGeneric,
107107
baselines: BaselineType = None,
108108
target: TargetType = None,
109+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
109110
additional_forward_args: Any = None,
110111
n_steps: int = 50,
111112
method: str = "gausslegendre",
112113
internal_batch_size: Union[None, int] = None,
113-
*,
114-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
114+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
115+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
115116
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
116-
return_convergence_delta: Literal[True],
117-
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
117+
return_convergence_delta: Literal[False] = False,
118+
) -> TensorOrTupleOfTensorsGeneric: ...
118119

119120
@log_usage()
120121
def attribute( # type: ignore

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,40 +102,41 @@ def __init__(
102102
# Ignoring mypy error for inconsistent signature with DeepLift
103103
@typing.overload # type: ignore
104104
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
105-
# arguments of overload defined on line `104`.
105+
# arguments of overload defined on line `117`.
106106
def attribute(
107107
self,
108108
inputs: Union[Tensor, Tuple[Tensor, ...]],
109109
baselines: BaselineType = None,
110110
target: TargetType = None,
111111
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
112112
additional_forward_args: Any = None,
113-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
114-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
113+
*,
114+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
115115
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
116-
return_convergence_delta: Literal[False] = False,
116+
return_convergence_delta: Literal[True],
117117
attribute_to_layer_input: bool = False,
118118
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
119119
grad_kwargs: Optional[Dict[str, Any]] = None,
120-
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
120+
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
121121

122122
@typing.overload
123123
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
124-
# arguments of overload defined on line `117`.
124+
# arguments of overload defined on line `104`.
125125
def attribute(
126126
self,
127127
inputs: Union[Tensor, Tuple[Tensor, ...]],
128128
baselines: BaselineType = None,
129129
target: TargetType = None,
130+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
130131
additional_forward_args: Any = None,
131-
*,
132-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
132+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
133+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
133134
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
134-
return_convergence_delta: Literal[True],
135+
return_convergence_delta: Literal[False] = False,
135136
attribute_to_layer_input: bool = False,
136137
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
137138
grad_kwargs: Optional[Dict[str, Any]] = None,
138-
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
139+
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
139140

140141
@log_usage()
141142
# pyre-fixme[43]: This definition does not have the same decorators as the
@@ -452,7 +453,7 @@ def __init__(
452453
# Ignoring mypy error for inconsistent signature with DeepLiftShap
453454
@typing.overload # type: ignore
454455
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
455-
# arguments of overload defined on line `439`.
456+
# arguments of overload defined on line `453`.
456457
def attribute(
457458
self,
458459
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -462,32 +463,33 @@ def attribute(
462463
target: TargetType = None,
463464
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
464465
additional_forward_args: Any = None,
465-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
466-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
466+
*,
467+
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
467468
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
468-
return_convergence_delta: Literal[False] = False,
469+
return_convergence_delta: Literal[True],
469470
attribute_to_layer_input: bool = False,
470471
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
471-
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
472+
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
472473

473474
@typing.overload
474475
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
475-
# arguments of overload defined on line `453`.
476+
# arguments of overload defined on line `439`.
476477
def attribute(
477478
self,
478479
inputs: Union[Tensor, Tuple[Tensor, ...]],
479480
baselines: Union[
480481
Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
481482
],
482483
target: TargetType = None,
484+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
483485
additional_forward_args: Any = None,
484-
*,
485-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
486+
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
487+
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
486488
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
487-
return_convergence_delta: Literal[True],
489+
return_convergence_delta: Literal[False] = False,
488490
attribute_to_layer_input: bool = False,
489491
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
490-
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
492+
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
491493

492494
@log_usage()
493495
# pyre-fixme[43]: This definition does not have the same decorators as the

0 commit comments

Comments (0)