22
33# pyre-strict
44import warnings
5- from typing import Any , List , Union
5+ from typing import List , Union
66
77import torch
88from captum ._utils .common import _format_output , _format_tensor_into_tuples , _is_tuple
@@ -72,8 +72,7 @@ def attribute(
7272 self ,
7373 inputs : TensorOrTupleOfTensorsGeneric ,
7474 target : TargetType = None ,
75- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
76- additional_forward_args : Any = None ,
75+ additional_forward_args : object = None ,
7776 interpolate_mode : str = "nearest" ,
7877 attribute_to_layer_input : bool = False ,
7978 ) -> TensorOrTupleOfTensorsGeneric :
@@ -181,15 +180,11 @@ def attribute(
181180 >>> # attribution size matches input size, Nx3x32x32
182181 >>> attribution = guided_gc.attribute(input, 3)
183182 """
184- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
185- # `TensorOrTupleOfTensorsGeneric`.
186183 is_inputs_tuple = _is_tuple (inputs )
187- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
188- # `Tuple[Tensor, ...]`.
189- inputs = _format_tensor_into_tuples (inputs )
184+ inputs_tuple = _format_tensor_into_tuples (inputs )
190185 grad_cam_attr = self .grad_cam .attribute .__wrapped__ (
191186 self .grad_cam , # self
192- inputs = inputs ,
187+ inputs = inputs_tuple ,
193188 target = target ,
194189 additional_forward_args = additional_forward_args ,
195190 attribute_to_layer_input = attribute_to_layer_input ,
@@ -204,20 +199,18 @@ def attribute(
204199
205200 guided_backprop_attr = self .guided_backprop .attribute .__wrapped__ (
206201 self .guided_backprop , # self
207- inputs = inputs ,
202+ inputs = inputs_tuple ,
208203 target = target ,
209204 additional_forward_args = additional_forward_args ,
210205 )
211206 output_attr : List [Tensor ] = []
212- for i in range (len (inputs )):
207+ for i in range (len (inputs_tuple )):
213208 try :
214209 output_attr .append (
215210 guided_backprop_attr [i ]
216211 * LayerAttribution .interpolate (
217212 grad_cam_attr ,
218- # pyre-fixme[6]: For 2nd argument expected `Union[int,
219- # typing.Tuple[int, ...]]` but got `Size`.
220- inputs [i ].shape [2 :],
213+ tuple (inputs_tuple [i ].shape [2 :]),
221214 interpolate_mode = interpolate_mode ,
222215 )
223216 )
0 commit comments