Commit dfaf02d

Vivek Miglani authored and facebook-github-bot committed
Fix pyre errors in Integrated Gradients (#1398)
Summary: Initial work on fixing Pyre errors in Integrated Gradients
Reviewed By: csauper
Differential Revision: D64677345
1 parent 54393ff commit dfaf02d

File tree

1 file changed: +14 -31 lines changed


captum/attr/_core/integrated_gradients.py

Lines changed: 14 additions & 31 deletions
@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, List, Literal, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -12,12 +12,7 @@
     _format_output,
     _is_tuple,
 )
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import GradientAttribution
 from captum.attr._utils.batching import _batch_attribution
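Why the import change works: `Literal` has been part of the standard `typing` module since Python 3.8, so the module can import it directly instead of going through `captum._utils.typing`, and Pyre then recognizes `Literal[...]` annotations natively. A minimal, hypothetical sketch of the kind of annotation this enables (the `Method` alias and `pick_method` function are illustrative only, not Captum code):

from typing import Literal

# "gausslegendre" is the default method string seen in the diff below;
# the second value is just an illustrative placeholder.
Method = Literal["gausslegendre", "riemann_trapezoid"]

def pick_method(method: Method = "gausslegendre") -> str:
    # A type checker rejects any argument that is not one of the listed literals.
    return method

print(pick_method())  # gausslegendre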
@@ -49,8 +44,7 @@ class IntegratedGradients(GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         multiply_by_inputs: bool = True,
     ) -> None:
         r"""
@@ -80,21 +74,16 @@ def __init__(
     # and when return_convergence_delta is True, the return type is
     # a tuple with both attributions and deltas.
     @typing.overload
-    # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
-    #  arguments of overload defined on line `95`.
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
         *,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[True],
     ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
 
@@ -111,9 +100,6 @@ def attribute(
         n_steps: int = 50,
         method: str = "gausslegendre",
         internal_batch_size: Union[None, int] = None,
-        # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         return_convergence_delta: Literal[False] = False,
     ) -> TensorOrTupleOfTensorsGeneric: ...
 
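These two overloads are what give callers a precise return type from the boolean flag: with the literal `True`, `attribute` is declared to return an `(attributions, delta)` tuple; with `False` (the default), just the attributions. The dropped pyre-fixme[31]/[24] suppressions were presumably only needed while `Literal` came from the Captum typing shim rather than `typing`. A standalone sketch of the same overload pattern with toy types (not the real Captum signatures):

import typing
from typing import Literal, Tuple, Union

@typing.overload
def attribute(*, return_convergence_delta: Literal[True]) -> Tuple[str, float]: ...

@typing.overload
def attribute(return_convergence_delta: Literal[False] = False) -> str: ...

def attribute(return_convergence_delta: bool = False) -> Union[str, Tuple[str, float]]:
    # The implementation takes a plain bool; the overloads above tell the
    # type checker which return type each literal value of the flag selects.
    return ("attributions", 0.01) if return_convergence_delta else "attributions"

print(attribute())                               # checked as str
print(attribute(return_convergence_delta=True))  # checked as Tuple[str, float]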

@@ -275,37 +261,35 @@ def attribute( # type: ignore
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
 
         # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
         #  `Tuple[Tensor, ...]`.
-        inputs, baselines = _format_input_baseline(inputs, baselines)
+        formatted_inputs, formatted_baselines = _format_input_baseline(
+            inputs, baselines
+        )
 
         # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
         #  `TensorOrTupleOfTensorsGeneric`.
-        _validate_input(inputs, baselines, n_steps, method)
+        _validate_input(formatted_inputs, formatted_baselines, n_steps, method)
 
         if internal_batch_size is not None:
-            num_examples = inputs[0].shape[0]
+            num_examples = formatted_inputs[0].shape[0]
             attributions = _batch_attribution(
                 self,
                 num_examples,
                 internal_batch_size,
                 n_steps,
-                inputs=inputs,
-                baselines=baselines,
+                inputs=formatted_inputs,
+                baselines=formatted_baselines,
                 target=target,
                 additional_forward_args=additional_forward_args,
                 method=method,
             )
         else:
             attributions = self._attribute(
-                # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
-                #  got `TensorOrTupleOfTensorsGeneric`.
-                inputs=inputs,
-                baselines=baselines,
+                inputs=formatted_inputs,
+                baselines=formatted_baselines,
                 target=target,
                 additional_forward_args=additional_forward_args,
                 n_steps=n_steps,
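The key move in this hunk is binding the result of `_format_input_baseline` to new names instead of reassigning the `inputs`/`baselines` parameters: the parameters keep their declared generic types, while `formatted_inputs`/`formatted_baselines` carry the concrete tuple types that `_batch_attribution` and `self._attribute` expect, which is what lets the suppression comments inside the `self._attribute(...)` call go away. A minimal sketch of the pattern with toy helpers (hypothetical names, not Captum's):

from typing import Tuple, Union
import torch
from torch import Tensor

def _format(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Normalize a single tensor into a one-element tuple.
    return inputs if isinstance(inputs, tuple) else (inputs,)

def attribute(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    # Reassigning `inputs = _format(inputs)` would change the variable's type
    # and force a pyre-fixme[9]; a fresh name keeps both annotations accurate.
    formatted_inputs = _format(inputs)
    return formatted_inputs[0].shape[0]

print(attribute(torch.randn(4, 3)))  # 4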
@@ -344,8 +328,7 @@ def _attribute(
         inputs: Tuple[Tensor, ...],
         baselines: Tuple[Union[Tensor, int, float], ...],
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         n_steps: int = 50,
         method: str = "gausslegendre",
         step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
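The same `Any` -> `object` swap appears here as in the public overloads: `additional_forward_args` is only forwarded, never inspected, so the opaque `object` type is enough, and unlike `Any` it is permitted under `# pyre-strict`, which reports `Any` parameter annotations as error [2]. A tiny illustration (hypothetical function, not Captum code):

def forward_with_extras(x: float, additional_forward_args: object = None) -> float:
    # The extra arguments are passed through untouched; attribute access on
    # an `object`-typed value would itself be a type error.
    print("extra args:", additional_forward_args)
    return x * 2.0

print(forward_with_extras(1.5, ("attention_mask", 7)))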
