Skip to content

Commit 0f7ea16

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Add additional overload signatures for shared methods to resolve pyre errors (#1406)
Summary: Pull Request resolved: #1406 Add a few additional overload signatures to shared methods for resolving pyre errors Also remove separate cases for typing Literal since the split was necessary due to previous support for Python < 3.8 Reviewed By: csauper Differential Revision: D64677349 fbshipit-source-id: b4ad77e2b57d6769844583541c2cb3a0c377519d
1 parent e63d39f commit 0f7ea16

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

captum/_utils/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
8686
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
8787

8888

89+
@typing.overload
90+
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
91+
92+
8993
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
9094
return isinstance(inputs, tuple)
9195

captum/_utils/typing.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,11 @@
22

33
# pyre-strict
44

5-
from typing import (
6-
List,
7-
Optional,
8-
overload,
9-
Protocol,
10-
Tuple,
11-
TYPE_CHECKING,
12-
TypeVar,
13-
Union,
14-
)
5+
from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union
156

167
from torch import Tensor
178
from torch.nn import Module
189

19-
if TYPE_CHECKING:
20-
from typing import Literal
21-
else:
22-
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}
23-
2410
TensorOrTupleOfTensorsGeneric = TypeVar(
2511
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
2612
)

captum/attr/_utils/common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def _format_input_baseline( # type: ignore
8282
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ...
8383

8484

85+
@typing.overload
86+
def _format_input_baseline( # type: ignore
87+
inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType
88+
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ...
89+
90+
8591
def _format_input_baseline(
8692
inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
8793
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
@@ -236,6 +242,21 @@ def _compute_conv_delta_and_format_attrs(
236242
) -> Union[Tensor, Tuple[Tensor, Tensor]]: ...
237243

238244

245+
@typing.overload
246+
def _compute_conv_delta_and_format_attrs(
247+
attr_algo: "GradientAttribution",
248+
return_convergence_delta: bool,
249+
attributions: Tuple[Tensor, ...],
250+
start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
251+
end_point: Union[Tensor, Tuple[Tensor, ...]],
252+
additional_forward_args: Any,
253+
target: TargetType,
254+
is_inputs_tuple: bool = False,
255+
) -> Union[
256+
Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
257+
]: ...
258+
259+
239260
# FIXME: GradientAttribution is provided as a string due to a circular import.
240261
# This should be fixed when common is refactored into separate files.
241262
def _compute_conv_delta_and_format_attrs(

0 commit comments

Comments
 (0)