Skip to content

Commit 87fb8ea

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix pyre errors in InputXGradient (#1397)
Summary: Initial work on fixing Pyre errors in InputXGradient Reviewed By: csauper Differential Revision: D64677348
1 parent 73fb4f2 commit 87fb8ea

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

captum/attr/_core/input_x_gradient.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable
4+
from typing import Callable
55

66
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
77
from captum._utils.gradient import (
@@ -11,6 +11,7 @@
1111
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
1212
from captum.attr._utils.attribution import GradientAttribution
1313
from captum.log import log_usage
14+
from torch import Tensor
1415

1516

1617
class InputXGradient(GradientAttribution):
@@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution):
2021
https://arxiv.org/abs/1605.01713
2122
"""
2223

23-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
24-
def __init__(self, forward_func: Callable) -> None:
24+
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
2525
r"""
2626
Args:
2727
@@ -35,8 +35,7 @@ def attribute(
3535
self,
3636
inputs: TensorOrTupleOfTensorsGeneric,
3737
target: TargetType = None,
38-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
39-
additional_forward_args: Any = None,
38+
additional_forward_args: object = None,
4039
) -> TensorOrTupleOfTensorsGeneric:
4140
r"""
4241
Args:
@@ -113,28 +112,20 @@ def attribute(
113112
"""
114113
# Keeps track whether original input is a tuple or not before
115114
# converting it into a tuple.
116-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
117-
# `TensorOrTupleOfTensorsGeneric`.
118115
is_inputs_tuple = _is_tuple(inputs)
119116

120-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
121-
# `Tuple[Tensor, ...]`.
122-
inputs = _format_tensor_into_tuples(inputs)
123-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
124-
# `TensorOrTupleOfTensorsGeneric`.
125-
gradient_mask = apply_gradient_requirements(inputs)
117+
inputs_tuple = _format_tensor_into_tuples(inputs)
118+
gradient_mask = apply_gradient_requirements(inputs_tuple)
126119

127120
gradients = self.gradient_func(
128-
self.forward_func, inputs, target, additional_forward_args
121+
self.forward_func, inputs_tuple, target, additional_forward_args
129122
)
130123

131124
attributions = tuple(
132-
input * gradient for input, gradient in zip(inputs, gradients)
125+
input * gradient for input, gradient in zip(inputs_tuple, gradients)
133126
)
134127

135-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
136-
# `TensorOrTupleOfTensorsGeneric`.
137-
undo_gradient_requirements(inputs, gradient_mask)
128+
undo_gradient_requirements(inputs_tuple, gradient_mask)
138129
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
139130
# `Tuple[Tensor, ...]`.
140131
return _format_output(is_inputs_tuple, attributions)

0 commit comments

Comments
 (0)