diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py
index 01067ca0db..406acef968 100644
--- a/captum/attr/_core/layer/layer_integrated_gradients.py
+++ b/captum/attr/_core/layer/layer_integrated_gradients.py
@@ -3,7 +3,18 @@
 # pyre-strict
 import functools
 import warnings
-from typing import Callable, cast, List, Literal, Optional, overload, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    overload,
+    Tuple,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -113,6 +124,7 @@ def _make_gradient_func(
         self,
         num_outputs_cumsum: Tensor,
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Callable[..., Tuple[Tensor, ...]]:
 
         def _gradient_func(
@@ -220,7 +232,9 @@ def layer_forward_hook(
             )
             # torch.unbind(forward_out) is a list of scalar tensor tuples and
             # contains batch_size * #steps elements
-            grads = torch.autograd.grad(torch.unbind(output), inputs)
+            grads = torch.autograd.grad(
+                torch.unbind(output), inputs, **grad_kwargs or {}
+            )
             return grads
 
         return _gradient_func
@@ -237,6 +251,7 @@ def attribute(
         internal_batch_size: Union[None, int],
         return_convergence_delta: Literal[False],
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
 
     @overload
@@ -251,6 +266,7 @@ def attribute(  # type: ignore
         internal_batch_size: Union[None, int],
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Tuple[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tensor,
     ]
@@ -270,6 +286,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tuple[
@@ -292,6 +309,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tuple[
@@ -427,6 +445,9 @@ def attribute(
                         attribute to the input or output, is a single tensor.
                         Support for multiple tensors will be added later.
                         Default: False
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None
 
         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
@@ -523,7 +544,7 @@ def flatten_tuple(tup):
 
         # inputs -> these inputs are scaled
         self.ig.gradient_func = self._make_gradient_func(
-            num_outputs_cumsum, attribute_to_layer_input
+            num_outputs_cumsum, attribute_to_layer_input, grad_kwargs
         )
         all_inputs = (
             (inps + additional_forward_args)
diff --git a/tests/attr/layer/test_layer_integrated_gradients.py b/tests/attr/layer/test_layer_integrated_gradients.py
index 4a90422827..0ad158ed6d 100644
--- a/tests/attr/layer/test_layer_integrated_gradients.py
+++ b/tests/attr/layer/test_layer_integrated_gradients.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-
+import unittest
 from typing import Any, cast, List, Tuple, Union
 
 import torch
@@ -13,6 +13,7 @@
     configure_interpretable_embedding_layer,
     remove_interpretable_embedding_layer,
 )
+from packaging import version
 from tests.helpers.basic import (
     assertTensorAlmostEqual,
     assertTensorTuplesAlmostEqual,
@@ -229,6 +230,28 @@ def test_multiple_tensors_compare_with_exp_wo_mult_by_inputs(self) -> None:
             attributions,
         )
 
+    def test_simple_multi_gradient_activation_with_unused_layer(self) -> None:
+        if version.parse(torch.__version__) < version.parse("2.1.0"):
+            raise unittest.SkipTest(
+                "Skipping unused layer gradient test since it is not supported "
+                "by torch version < 2.1"
+            )
+
+        model = BasicModel_MultiLayer(multi_input_module=True)
+        test_input = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)
+        # pyre-fixme[6]: For 2nd argument expected `ModuleOrModuleList` but got
+        #  `List[Union[ReLU, Linear]]`.
+        layer_ig = LayerIntegratedGradients(model, [model.linear1, model.relu])
+        attributions = cast(
+            List[Tensor],
+            layer_ig.attribute(
+                inputs=test_input, target=0, grad_kwargs={"materialize_grads": True}
+            ),
+        )
+        self.assertEqual(len(attributions), 2)
+        self.assertEqual(list(attributions[0].shape), [1, 4])
+        self.assertEqual(list(attributions[1].shape), [1, 4])
+
     def _assert_compare_with_layer_conductance(
         self, model: Module, input: Tensor, attribute_to_layer_input: bool = False
     ) -> None:
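
For reference, a minimal usage sketch of the new grad_kwargs parameter (the model and layer names below are illustrative, not part of this patch). Any keyword accepted by torch.autograd.grad can be passed through; materialize_grads, which the new test exercises, requires PyTorch >= 2.1 and turns the gradient of a layer that never reaches the target output into a zero tensor instead of raising a RuntimeError:

import torch
import torch.nn as nn

from captum.attr import LayerIntegratedGradients


class TwoBranchModel(nn.Module):
    """Toy model whose `unused` branch never contributes to the output."""

    def __init__(self) -> None:
        super().__init__()
        self.used = nn.Linear(3, 4)
        self.unused = nn.Linear(3, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _ = self.unused(x)  # activation is captured by the layer hook,
        return self.used(x)  # but detached from the returned output


model = TwoBranchModel()
inp = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)

lig = LayerIntegratedGradients(model, [model.used, model.unused])
# grad_kwargs is forwarded verbatim to torch.autograd.grad, so the
# unused branch yields zero attributions instead of an error.
attrs = lig.attribute(inp, target=0, grad_kwargs={"materialize_grads": True})
print([tuple(a.shape) for a in attrs])  # one attribution tensor per layer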