Skip to content

Commit 422c30d

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix pyre errors in Guided Backprop
Summary: Initial work on fixing Pyre errors in Guided Backprop Differential Revision: D64677346
1 parent 28c8e4f commit 422c30d

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

captum/attr/_core/guided_backprop_deconvnet.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ def attribute(
4545
self,
4646
inputs: TensorOrTupleOfTensorsGeneric,
4747
target: TargetType = None,
48-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
49-
additional_forward_args: Any = None,
48+
additional_forward_args: object = None,
5049
) -> TensorOrTupleOfTensorsGeneric:
5150
r"""
5251
Computes attribution by overriding relu gradients. Based on constructor
@@ -58,16 +57,10 @@ def attribute(
5857

5958
# Keeps track whether original input is a tuple or not before
6059
# converting it into a tuple.
61-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
62-
# `TensorOrTupleOfTensorsGeneric`.
6360
is_inputs_tuple = _is_tuple(inputs)
6461

65-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
66-
# `Tuple[Tensor, ...]`.
67-
inputs = _format_tensor_into_tuples(inputs)
68-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
69-
# `TensorOrTupleOfTensorsGeneric`.
70-
gradient_mask = apply_gradient_requirements(inputs)
62+
inputs_tuple = _format_tensor_into_tuples(inputs)
63+
gradient_mask = apply_gradient_requirements(inputs_tuple)
7164

7265
# set hooks for overriding ReLU gradients
7366
warnings.warn(
@@ -79,14 +72,12 @@ def attribute(
7972
self.model.apply(self._register_hooks)
8073

8174
gradients = self.gradient_func(
82-
self.forward_func, inputs, target, additional_forward_args
75+
self.forward_func, inputs_tuple, target, additional_forward_args
8376
)
8477
finally:
8578
self._remove_hooks()
8679

87-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
88-
# `TensorOrTupleOfTensorsGeneric`.
89-
undo_gradient_requirements(inputs, gradient_mask)
80+
undo_gradient_requirements(inputs_tuple, gradient_mask)
9081
# pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
9182
# `Tuple[Tensor, ...]`.
9283
return _format_output(is_inputs_tuple, gradients)
@@ -155,8 +146,7 @@ def attribute(
155146
self,
156147
inputs: TensorOrTupleOfTensorsGeneric,
157148
target: TargetType = None,
158-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
159-
additional_forward_args: Any = None,
149+
additional_forward_args: object = None,
160150
) -> TensorOrTupleOfTensorsGeneric:
161151
r"""
162152
Args:
@@ -265,8 +255,7 @@ def attribute(
265255
self,
266256
inputs: TensorOrTupleOfTensorsGeneric,
267257
target: TargetType = None,
268-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
269-
additional_forward_args: Any = None,
258+
additional_forward_args: object = None,
270259
) -> TensorOrTupleOfTensorsGeneric:
271260
r"""
272261
Args:

0 commit comments

Comments
 (0)