@@ -45,8 +45,7 @@ def attribute(
4545 self ,
4646 inputs : TensorOrTupleOfTensorsGeneric ,
4747 target : TargetType = None ,
48- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
49- additional_forward_args : Any = None ,
48+ additional_forward_args : object = None ,
5049 ) -> TensorOrTupleOfTensorsGeneric :
5150 r"""
5251 Computes attribution by overriding relu gradients. Based on constructor
@@ -58,16 +57,10 @@ def attribute(
5857
5958 # Keeps track whether original input is a tuple or not before
6059 # converting it into a tuple.
61- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
62- # `TensorOrTupleOfTensorsGeneric`.
6360 is_inputs_tuple = _is_tuple (inputs )
6461
65- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
66- # `Tuple[Tensor, ...]`.
67- inputs = _format_tensor_into_tuples (inputs )
68- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
69- # `TensorOrTupleOfTensorsGeneric`.
70- gradient_mask = apply_gradient_requirements (inputs )
62+ inputs_tuple = _format_tensor_into_tuples (inputs )
63+ gradient_mask = apply_gradient_requirements (inputs_tuple )
7164
7265 # set hooks for overriding ReLU gradients
7366 warnings .warn (
@@ -79,14 +72,12 @@ def attribute(
7972 self .model .apply (self ._register_hooks )
8073
8174 gradients = self .gradient_func (
82- self .forward_func , inputs , target , additional_forward_args
75+ self .forward_func , inputs_tuple , target , additional_forward_args
8376 )
8477 finally :
8578 self ._remove_hooks ()
8679
87- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
88- # `TensorOrTupleOfTensorsGeneric`.
89- undo_gradient_requirements (inputs , gradient_mask )
80+ undo_gradient_requirements (inputs_tuple , gradient_mask )
9081 # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
9182 # `Tuple[Tensor, ...]`.
9283 return _format_output (is_inputs_tuple , gradients )
@@ -155,8 +146,7 @@ def attribute(
155146 self ,
156147 inputs : TensorOrTupleOfTensorsGeneric ,
157148 target : TargetType = None ,
158- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
159- additional_forward_args : Any = None ,
149+ additional_forward_args : object = None ,
160150 ) -> TensorOrTupleOfTensorsGeneric :
161151 r"""
162152 Args:
@@ -265,8 +255,7 @@ def attribute(
265255 self ,
266256 inputs : TensorOrTupleOfTensorsGeneric ,
267257 target : TargetType = None ,
268- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
269- additional_forward_args : Any = None ,
258+ additional_forward_args : object = None ,
270259 ) -> TensorOrTupleOfTensorsGeneric :
271260 r"""
272261 Args:
0 commit comments