Skip to content

Commit 4a3ba86

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix pyre errors in GuidedGradCAM (#1396)
Summary: Initial work on fixing Pyre errors in Guided GradCAM Reviewed By: craymichael Differential Revision: D64677347
1 parent ed238fc commit 4a3ba86

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

captum/attr/_core/guided_grad_cam.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
import warnings
5-
from typing import Any, List, Union
5+
from typing import List, Union
66

77
import torch
88
from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
@@ -72,8 +72,7 @@ def attribute(
7272
self,
7373
inputs: TensorOrTupleOfTensorsGeneric,
7474
target: TargetType = None,
75-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
76-
additional_forward_args: Any = None,
75+
additional_forward_args: object = None,
7776
interpolate_mode: str = "nearest",
7877
attribute_to_layer_input: bool = False,
7978
) -> TensorOrTupleOfTensorsGeneric:
@@ -181,15 +180,11 @@ def attribute(
181180
>>> # attribution size matches input size, Nx3x32x32
182181
>>> attribution = guided_gc.attribute(input, 3)
183182
"""
184-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
185-
# `TensorOrTupleOfTensorsGeneric`.
186183
is_inputs_tuple = _is_tuple(inputs)
187-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
188-
# `Tuple[Tensor, ...]`.
189-
inputs = _format_tensor_into_tuples(inputs)
184+
inputs_tuple = _format_tensor_into_tuples(inputs)
190185
grad_cam_attr = self.grad_cam.attribute.__wrapped__(
191186
self.grad_cam, # self
192-
inputs=inputs,
187+
inputs=inputs_tuple,
193188
target=target,
194189
additional_forward_args=additional_forward_args,
195190
attribute_to_layer_input=attribute_to_layer_input,
@@ -204,20 +199,18 @@ def attribute(
204199

205200
guided_backprop_attr = self.guided_backprop.attribute.__wrapped__(
206201
self.guided_backprop, # self
207-
inputs=inputs,
202+
inputs=inputs_tuple,
208203
target=target,
209204
additional_forward_args=additional_forward_args,
210205
)
211206
output_attr: List[Tensor] = []
212-
for i in range(len(inputs)):
207+
for i in range(len(inputs_tuple)):
213208
try:
214209
output_attr.append(
215210
guided_backprop_attr[i]
216211
* LayerAttribution.interpolate(
217212
grad_cam_attr,
218-
# pyre-fixme[6]: For 2nd argument expected `Union[int,
219-
# typing.Tuple[int, ...]]` but got `Size`.
220-
inputs[i].shape[2:],
213+
tuple(inputs_tuple[i].shape[2:]),
221214
interpolate_mode=interpolate_mode,
222215
)
223216
)

0 commit comments

Comments
 (0)