@@ -3,7 +3,18 @@
 # pyre-strict
 import functools
 import warnings
-from typing import Callable, cast, List, Literal, Optional, overload, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    overload,
+    Tuple,
+    Union,
+)

 import torch
 from captum._utils.common import (
@@ -113,6 +124,7 @@ def _make_gradient_func(
         self,
         num_outputs_cumsum: Tensor,
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Callable[..., Tuple[Tensor, ...]]:

         def _gradient_func(
@@ -220,7 +232,9 @@ def layer_forward_hook(
                 )
                 # torch.unbind(forward_out) is a list of scalar tensor tuples and
                 # contains batch_size * #steps elements
-                grads = torch.autograd.grad(torch.unbind(output), inputs)
+                grads = torch.autograd.grad(
+                    torch.unbind(output), inputs, **grad_kwargs or {}
+                )
             return grads

         return _gradient_func
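
A quick sketch of what the new call does, outside of Captum. Since `**` applies to the whole expression, `**grad_kwargs or {}` parses as `**(grad_kwargs or {})`: a `None` value falls back to an empty dict, so the call stays equivalent to the old `torch.autograd.grad(torch.unbind(output), inputs)`. The tensors and the `retain_graph` option below are illustrative, not part of this patch.

```python
import torch

x = torch.randn(4, requires_grad=True)
y = (x * x).sum()

# A real torch.autograd.grad option threaded through the new parameter:
# retain_graph=True keeps the graph alive for a second grad call.
grad_kwargs = {"retain_graph": True}
(g1,) = torch.autograd.grad((y,), (x,), **grad_kwargs or {})

# With grad_kwargs=None the `or {}` fallback kicks in and the call
# behaves exactly as it did before this change.
grad_kwargs = None
(g2,) = torch.autograd.grad((y,), (x,), **grad_kwargs or {})
assert torch.equal(g1, g2)  # both are 2 * x
```
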
@@ -237,6 +251,7 @@ def attribute(
         internal_batch_size: Union[None, int],
         return_convergence_delta: Literal[False],
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...

     @overload
@@ -251,6 +266,7 @@ def attribute(  # type: ignore
         internal_batch_size: Union[None, int],
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Tuple[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tensor,
@@ -270,6 +286,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tuple[
@@ -292,6 +309,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tuple[
@@ -427,6 +445,9 @@ def attribute(
                 attribute to the input or output, is a single tensor.
                 Support for multiple tensors will be added later.
                 Default: False
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                arguments for torch.autograd.grad.
+                Default: None

         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
@@ -523,7 +544,7 @@ def flatten_tuple(tup):
         # inputs -> these inputs are scaled

         self.ig.gradient_func = self._make_gradient_func(
-            num_outputs_cumsum, attribute_to_layer_input
+            num_outputs_cumsum, attribute_to_layer_input, grad_kwargs
         )
         all_inputs = (
             (inps + additional_forward_args)
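
For context, a minimal end-to-end sketch of how the new parameter is meant to be used. The toy model, layer choice, and inputs are assumptions for illustration; only `grad_kwargs` itself comes from this patch.

```python
import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

# Toy setup (illustrative): attribute to the first linear layer.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
lig = LayerIntegratedGradients(model, model[0])

inputs = torch.randn(3, 8)
# Extra options for torch.autograd.grad now flow through attribute();
# retain_graph=True is one option torch.autograd.grad actually accepts.
attributions = lig.attribute(inputs, target=0, grad_kwargs={"retain_graph": True})
print(attributions.shape)  # matches the layer's output shape: (3, 4)
```
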