11#!/usr/bin/env python3
22
33# pyre-strict
4- from typing import Any , Callable
4+ from typing import Callable
55
66from captum ._utils .common import _format_output , _format_tensor_into_tuples , _is_tuple
77from captum ._utils .gradient import (
1111from captum ._utils .typing import TargetType , TensorOrTupleOfTensorsGeneric
1212from captum .attr ._utils .attribution import GradientAttribution
1313from captum .log import log_usage
14+ from torch import Tensor
1415
1516
1617class InputXGradient (GradientAttribution ):
@@ -20,8 +21,7 @@ class InputXGradient(GradientAttribution):
2021 https://arxiv.org/abs/1605.01713
2122 """
2223
23- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
24- def __init__ (self , forward_func : Callable ) -> None :
24+ def __init__ (self , forward_func : Callable [..., Tensor ]) -> None :
2525 r"""
2626 Args:
2727
@@ -35,8 +35,7 @@ def attribute(
3535 self ,
3636 inputs : TensorOrTupleOfTensorsGeneric ,
3737 target : TargetType = None ,
38- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
39- additional_forward_args : Any = None ,
38+ additional_forward_args : object = None ,
4039 ) -> TensorOrTupleOfTensorsGeneric :
4140 r"""
4241 Args:
@@ -113,28 +112,20 @@ def attribute(
113112 """
114113 # Keeps track whether original input is a tuple or not before
115114 # converting it into a tuple.
116- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
117- # `TensorOrTupleOfTensorsGeneric`.
118115 is_inputs_tuple = _is_tuple (inputs )
119116
120- # pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
121- # `Tuple[Tensor, ...]`.
122- inputs = _format_tensor_into_tuples (inputs )
123- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
124- # `TensorOrTupleOfTensorsGeneric`.
125- gradient_mask = apply_gradient_requirements (inputs )
117+ inputs_tuple = _format_tensor_into_tuples (inputs )
118+ gradient_mask = apply_gradient_requirements (inputs_tuple )
126119
127120 gradients = self .gradient_func (
128- self .forward_func , inputs , target , additional_forward_args
121+ self .forward_func , inputs_tuple , target , additional_forward_args
129122 )
130123
131124 attributions = tuple (
132- input * gradient for input , gradient in zip (inputs , gradients )
125+ input * gradient for input , gradient in zip (inputs_tuple , gradients )
133126 )
134127
135- # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
136- # `TensorOrTupleOfTensorsGeneric`.
137- undo_gradient_requirements (inputs , gradient_mask )
128+ undo_gradient_requirements (inputs_tuple , gradient_mask )
138129 # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
139130 # `Tuple[Tensor, ...]`.
140131 return _format_output (is_inputs_tuple , attributions )
0 commit comments