
Commit 3cc4618

craymichael authored and facebook-github-bot committed
Support grad_kwargs in LayerIntegratedGradients (#1435)
Summary: Add support for passing grad_kwargs to torch.autograd.grad through LayerIntegratedGradients.attribute.

Reviewed By: cyrjano

Differential Revision: D65610295
1 parent e8b6d98 commit 3cc4618
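
For reference, a minimal usage sketch of the new argument, mirroring the test added below. The BasicModel_MultiLayer helper and its import path are taken from Captum's test suite and are assumed here, not defined by this diff:

    import torch
    from captum.attr import LayerIntegratedGradients
    # assumed import path for the test helper model
    from tests.helpers.basic_models import BasicModel_MultiLayer

    model = BasicModel_MultiLayer(multi_input_module=True)
    test_input = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)

    # Attribute with respect to two layers; grad_kwargs is forwarded verbatim to
    # torch.autograd.grad, so materialize_grads=True (torch >= 2.1) yields zero
    # gradients for a layer the selected target never uses, instead of None.
    layer_ig = LayerIntegratedGradients(model, [model.linear1, model.relu])
    attributions = layer_ig.attribute(
        inputs=test_input, target=0, grad_kwargs={"materialize_grads": True}
    )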

File tree

captum/attr/_core/layer/layer_integrated_gradients.py
tests/attr/layer/test_layer_integrated_gradients.py

2 files changed: +48 −4 lines changed


captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 24 additions & 3 deletions
@@ -3,7 +3,18 @@
 # pyre-strict
 import functools
 import warnings
-from typing import Callable, cast, List, Literal, Optional, overload, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    overload,
+    Tuple,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -113,6 +124,7 @@ def _make_gradient_func(
         self,
         num_outputs_cumsum: Tensor,
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Callable[..., Tuple[Tensor, ...]]:
 
         def _gradient_func(
@@ -220,7 +232,9 @@ def layer_forward_hook(
                 )
                 # torch.unbind(forward_out) is a list of scalar tensor tuples and
                 # contains batch_size * #steps elements
-                grads = torch.autograd.grad(torch.unbind(output), inputs)
+                grads = torch.autograd.grad(
+                    torch.unbind(output), inputs, **grad_kwargs or {}
+                )
             return grads
 
         return _gradient_func
@@ -237,6 +251,7 @@ def attribute(
         internal_batch_size: Union[None, int],
         return_convergence_delta: Literal[False],
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...
 
     @overload
@@ -251,6 +266,7 @@ def attribute(  # type: ignore
         internal_batch_size: Union[None, int],
         return_convergence_delta: Literal[True],
         attribute_to_layer_input: bool,
+        grad_kwargs: Optional[Dict[str, Any]],
     ) -> Tuple[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tensor,
@@ -270,6 +286,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tuple[
@@ -292,6 +309,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
         attribute_to_layer_input: bool = False,
+        grad_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[
         Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
         Tuple[
@@ -427,6 +445,9 @@ def attribute(
                         attribute to the input or output, is a single tensor.
                         Support for multiple tensors will be added later.
                        Default: False
+            grad_kwargs (Dict[str, Any], optional): Additional keyword
+                        arguments for torch.autograd.grad.
+                        Default: None
 
         Returns:
             **attributions** or 2-element tuple of **attributions**, **delta**:
@@ -523,7 +544,7 @@ def flatten_tuple(tup):
         # inputs -> these inputs are scaled
 
         self.ig.gradient_func = self._make_gradient_func(
-            num_outputs_cumsum, attribute_to_layer_input
+            num_outputs_cumsum, attribute_to_layer_input, grad_kwargs
         )
         all_inputs = (
             (inps + additional_forward_args)
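
One note on the gradient-function change above: because grad_kwargs defaults to None, the expression **grad_kwargs or {} unpacks an empty dict when no extra arguments are given, so the default call to torch.autograd.grad is unchanged. A standalone sketch of the same forwarding pattern, with hypothetical names and explicit parentheses for clarity:

    from typing import Any, Dict, Optional, Tuple

    import torch


    def compute_grads(
        output: torch.Tensor,
        inputs: torch.Tensor,
        grad_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, ...]:
        # "grad_kwargs or {}" turns None into an empty dict, so **-unpacking is
        # always safe and callers that pass nothing get the previous behavior.
        return torch.autograd.grad(output, inputs, **(grad_kwargs or {}))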

tests/attr/layer/test_layer_integrated_gradients.py

Lines changed: 24 additions & 1 deletion
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-
+import unittest
 from typing import Any, cast, List, Tuple, Union
 
 import torch
@@ -13,6 +13,7 @@
     configure_interpretable_embedding_layer,
     remove_interpretable_embedding_layer,
 )
+from packaging import version
 from tests.helpers.basic import (
     assertTensorAlmostEqual,
     assertTensorTuplesAlmostEqual,
@@ -229,6 +230,28 @@ def test_multiple_tensors_compare_with_exp_wo_mult_by_inputs(self) -> None:
             attributions,
         )
 
+    def test_simple_multi_gradient_activation_with_unused_layer(self) -> None:
+        if version.parse(torch.__version__) < version.parse("2.1.0"):
+            raise unittest.SkipTest(
+                "Skipping unused layer gradient test since it is not supported "
+                "by torch version < 2.1"
+            )
+
+        model = BasicModel_MultiLayer(multi_input_module=True)
+        test_input = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)
+        # pyre-fixme[6]: For 2nd argument expected `ModuleOrModuleList` but got
+        #  `List[Union[ReLU, Linear]]`.
+        layer_ig = LayerIntegratedGradients(model, [model.linear1, model.relu])
+        attributions = cast(
+            List[Tensor],
+            layer_ig.attribute(
+                inputs=test_input, target=0, grad_kwargs={"materialize_grads": True}
+            ),
+        )
+        self.assertEqual(len(attributions), 2)
+        self.assertEqual(list(attributions[0].shape), [1, 4])
+        self.assertEqual(list(attributions[1].shape), [1, 4])
+
     def _assert_compare_with_layer_conductance(
         self, model: Module, input: Tensor, attribute_to_layer_input: bool = False
     ) -> None:
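
The grad_kwargs value exercised by this test, materialize_grads=True, asks torch.autograd.grad to return zero-filled tensors instead of None for inputs that do not contribute to the output (here, the layer the target never uses); the option is only available in PyTorch 2.1+, which is why older versions are skipped. A standalone sketch of that behavior, with illustrative variable names:

    import torch

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    unused = torch.tensor([3.0], requires_grad=True)
    y = (x * x).sum()  # "unused" plays no part in y

    # Without materialize_grads this call would need allow_unused=True and would
    # return None for "unused"; with it (torch >= 2.1) a zero tensor comes back.
    grads = torch.autograd.grad(y, (x, unused), materialize_grads=True)
    print(grads[1])  # tensor([0.])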
