Skip to content

Commit 652dc6e

Browse files
craymichael authored and facebook-github-bot committed
Correct remaining typing.Literal imports (#1412)
Summary: Change remaining imports of `Literal` to be from `typing` library Reviewed By: vivekmig Differential Revision: D64807610
1 parent e2c06ed commit 652dc6e

File tree

12 files changed

+42
-143
lines changed

12 files changed

+42
-143
lines changed

captum/_utils/common.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,23 @@
55
from enum import Enum
66
from functools import reduce
77
from inspect import signature
8-
from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union
8+
from typing import (
9+
Any,
10+
Callable,
11+
cast,
12+
Dict,
13+
List,
14+
Literal,
15+
overload,
16+
Sequence,
17+
Tuple,
18+
Union,
19+
)
920

1021
import numpy as np
1122
import torch
1223
from captum._utils.typing import (
1324
BaselineType,
14-
Literal,
1525
TargetType,
1626
TensorOrTupleOfTensorsGeneric,
1727
TupleOrTensorOrBoolGeneric,
@@ -71,23 +81,17 @@ def safe_div(
7181

7282

7383
@typing.overload
74-
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
75-
# is incompatible with the return type of the implementation (`bool`).
76-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
77-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
7884
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
7985

8086

8187
@typing.overload
82-
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
83-
# is incompatible with the return type of the implementation (`bool`).
84-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
85-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
8688
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
8789

8890

8991
@typing.overload
90-
def _is_tuple(inputs: TensorOrTupleOfTensorsGeneric) -> bool: ...
92+
def _is_tuple(
93+
inputs: TensorOrTupleOfTensorsGeneric,
94+
) -> bool: ... # type: ignore
9195

9296

9397
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
@@ -480,22 +484,14 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
480484

481485

482486
@typing.overload
483-
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
484-
# possible arguments of overload defined on line `449`.
485487
def _format_output(
486-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
487-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
488488
is_inputs_tuple: Literal[True],
489489
output: Tuple[Tensor, ...],
490490
) -> Tuple[Tensor, ...]: ...
491491

492492

493493
@typing.overload
494-
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
495-
# possible arguments of overload defined on line `455`.
496494
def _format_output(
497-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
498-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
499495
is_inputs_tuple: Literal[False],
500496
output: Tuple[Tensor, ...],
501497
) -> Tensor: ...
@@ -526,22 +522,14 @@ def _format_output(
526522

527523

528524
@typing.overload
529-
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
530-
# possible arguments of overload defined on line `483`.
531525
def _format_outputs(
532-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
533-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
534526
is_multiple_inputs: Literal[False],
535527
outputs: List[Tuple[Tensor, ...]],
536528
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
537529

538530

539531
@typing.overload
540-
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
541-
# possible arguments of overload defined on line `489`.
542532
def _format_outputs(
543-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
544-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
545533
is_multiple_inputs: Literal[True],
546534
outputs: List[Tuple[Tensor, ...]],
547535
) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...

captum/_utils/gradient.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55
import typing
66
import warnings
77
from collections import defaultdict
8-
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
8+
from typing import (
9+
Any,
10+
Callable,
11+
cast,
12+
Dict,
13+
List,
14+
Literal,
15+
Optional,
16+
Sequence,
17+
Tuple,
18+
Union,
19+
)
920

1021
import torch
1122
from captum._utils.common import (
@@ -16,7 +27,6 @@
1627
)
1728
from captum._utils.sample_gradient import SampleGradientWrapper
1829
from captum._utils.typing import (
19-
Literal,
2030
ModuleOrModuleList,
2131
TargetType,
2232
TensorOrTupleOfTensorsGeneric,
@@ -226,9 +236,6 @@ def _forward_layer_distributed_eval(
226236
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
227237
additional_forward_args: Any = None,
228238
attribute_to_layer_input: bool = False,
229-
# pyre-fixme[9]: forward_hook_with_return has type `Literal[]`; used as `bool`.
230-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
231-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
232239
forward_hook_with_return: Literal[False] = False,
233240
require_layer_grads: bool = False,
234241
) -> Dict[Module, Dict[device, Tuple[Tensor, ...]]]: ...
@@ -246,8 +253,6 @@ def _forward_layer_distributed_eval(
246253
additional_forward_args: Any = None,
247254
attribute_to_layer_input: bool = False,
248255
*,
249-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
250-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
251256
forward_hook_with_return: Literal[True],
252257
require_layer_grads: bool = False,
253258
) -> Tuple[Dict[Module, Dict[device, Tuple[Tensor, ...]]], Tensor]: ...
@@ -675,7 +680,6 @@ def compute_layer_gradients_and_eval(
675680
target_ind=target_ind,
676681
additional_forward_args=additional_forward_args,
677682
attribute_to_layer_input=attribute_to_layer_input,
678-
# pyre-fixme[6]: For 7th argument expected `Literal[]` but got `bool`.
679683
forward_hook_with_return=True,
680684
require_layer_grads=True,
681685
)

captum/_utils/progress.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import sys
66
import warnings
77
from time import time
8-
from typing import Any, cast, Iterable, Optional, Sized, TextIO
9-
10-
from captum._utils.typing import Literal
8+
from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO
119

1210
try:
1311
from tqdm.auto import tqdm
@@ -75,10 +73,7 @@ def __enter__(self) -> "NullProgress":
7573
return self
7674

7775
# pyre-fixme[2]: Parameter must be annotated.
78-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
79-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
8076
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
81-
# pyre-fixme[7]: Expected `Literal[]` but got `bool`.
8277
return False
8378

8479
# pyre-fixme[3]: Return type must be annotated.
@@ -139,11 +134,8 @@ def __enter__(self) -> "SimpleProgress":
139134
return self
140135

141136
# pyre-fixme[2]: Parameter must be annotated.
142-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
143-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
144137
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
145138
self.close()
146-
# pyre-fixme[7]: Expected `Literal[]` but got `bool`.
147139
return False
148140

149141
# pyre-fixme[3]: Return type must be annotated.

captum/attr/_core/layer/layer_conductance.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
import typing
5-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
66

77
import torch
88
from captum._utils.common import (
@@ -12,7 +12,7 @@
1212
_format_output,
1313
)
1414
from captum._utils.gradient import compute_layer_gradients_and_eval
15-
from captum._utils.typing import BaselineType, Literal, TargetType
15+
from captum._utils.typing import BaselineType, TargetType
1616
from captum.attr._utils.approximation_methods import approximation_parameters
1717
from captum.attr._utils.attribution import GradientAttribution, LayerAttribution
1818
from captum.attr._utils.batching import _batch_attribution
@@ -86,8 +86,6 @@ def attribute(
8686
method: str = "gausslegendre",
8787
internal_batch_size: Union[None, int] = None,
8888
*,
89-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
90-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
9189
return_convergence_delta: Literal[True],
9290
attribute_to_layer_input: bool = False,
9391
grad_kwargs: Optional[Dict[str, Any]] = None,
@@ -105,9 +103,6 @@ def attribute(
105103
n_steps: int = 50,
106104
method: str = "gausslegendre",
107105
internal_batch_size: Union[None, int] = None,
108-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
109-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
110-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
111106
return_convergence_delta: Literal[False] = False,
112107
attribute_to_layer_input: bool = False,
113108
grad_kwargs: Optional[Dict[str, Any]] = None,

captum/attr/_core/layer/layer_deep_lift.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
import typing
5-
from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, Union
5+
from typing import Any, Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union
66

77
import torch
88
from captum._utils.common import (
@@ -13,12 +13,7 @@
1313
ExpansionTypes,
1414
)
1515
from captum._utils.gradient import compute_layer_gradients_and_eval
16-
from captum._utils.typing import (
17-
BaselineType,
18-
Literal,
19-
TargetType,
20-
TensorOrTupleOfTensorsGeneric,
21-
)
16+
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
2217
from captum.attr._core.deep_lift import DeepLift, DeepLiftShap
2318
from captum.attr._utils.attribution import LayerAttribution
2419
from captum.attr._utils.common import (
@@ -101,8 +96,6 @@ def __init__(
10196

10297
# Ignoring mypy error for inconsistent signature with DeepLift
10398
@typing.overload # type: ignore
104-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
105-
# arguments of overload defined on line `117`.
10699
def attribute(
107100
self,
108101
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -111,27 +104,20 @@ def attribute(
111104
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
112105
additional_forward_args: Any = None,
113106
*,
114-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
115-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
116107
return_convergence_delta: Literal[True],
117108
attribute_to_layer_input: bool = False,
118109
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
119110
grad_kwargs: Optional[Dict[str, Any]] = None,
120111
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
121112

122113
@typing.overload
123-
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
124-
# arguments of overload defined on line `104`.
125114
def attribute(
126115
self,
127116
inputs: Union[Tensor, Tuple[Tensor, ...]],
128117
baselines: BaselineType = None,
129118
target: TargetType = None,
130119
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
131120
additional_forward_args: Any = None,
132-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
133-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
134-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
135121
return_convergence_delta: Literal[False] = False,
136122
attribute_to_layer_input: bool = False,
137123
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -382,8 +368,6 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
382368
inputs,
383369
additional_forward_args,
384370
target,
385-
# pyre-fixme[31]: Expression `Literal[False])]` is not a valid type.
386-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
387371
cast(Union[Literal[True], Literal[False]], len(attributions) > 1),
388372
)
389373

@@ -464,8 +448,6 @@ def attribute(
464448
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
465449
additional_forward_args: Any = None,
466450
*,
467-
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
468-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
469451
return_convergence_delta: Literal[True],
470452
attribute_to_layer_input: bool = False,
471453
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -483,9 +465,6 @@ def attribute(
483465
target: TargetType = None,
484466
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
485467
additional_forward_args: Any = None,
486-
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
487-
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
488-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
489468
return_convergence_delta: Literal[False] = False,
490469
attribute_to_layer_input: bool = False,
491470
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -686,10 +665,6 @@ def attribute(
686665
target=exp_target,
687666
additional_forward_args=exp_addit_args,
688667
return_convergence_delta=cast(
689-
# pyre-fixme[31]: Expression `Literal[(True, False)]` is not a valid
690-
# type.
691-
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take
692-
# parameters.
693668
Literal[True, False],
694669
return_convergence_delta,
695670
),

0 commit comments

Comments (0)