27 changes: 24 additions & 3 deletions captum/attr/_core/layer/layer_integrated_gradients.py
@@ -3,7 +3,18 @@
# pyre-strict
import functools
import warnings
from typing import Callable, cast, List, Literal, Optional, overload, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Literal,
Optional,
overload,
Tuple,
Union,
)

import torch
from captum._utils.common import (
@@ -113,6 +124,7 @@ def _make_gradient_func(
self,
num_outputs_cumsum: Tensor,
attribute_to_layer_input: bool,
grad_kwargs: Optional[Dict[str, Any]],
) -> Callable[..., Tuple[Tensor, ...]]:

def _gradient_func(
@@ -220,7 +232,9 @@ def layer_forward_hook(
)
# torch.unbind(forward_out) is a list of scalar tensor tuples and
# contains batch_size * #steps elements
grads = torch.autograd.grad(torch.unbind(output), inputs)
grads = torch.autograd.grad(
torch.unbind(output), inputs, **grad_kwargs or {}
)
return grads

return _gradient_func
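The net effect of this hunk is that any user-supplied keyword arguments are closed over by the returned gradient function and forwarded verbatim to `torch.autograd.grad`. A minimal standalone sketch of that pattern (the `make_gradient_func` and `forward_fn` names below are illustrative, not Captum's internals):

```python
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import Tensor


def make_gradient_func(
    forward_fn: Callable[[Tensor], Tensor],
    grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Callable[[Tensor], Tuple[Tensor, ...]]:
    # The returned closure captures grad_kwargs, so callers never have to
    # thread the extra autograd options through themselves.
    def gradient_func(inputs: Tensor) -> Tuple[Tensor, ...]:
        output = forward_fn(inputs)
        # `grad_kwargs or {}` keeps the call valid when no kwargs were given.
        return torch.autograd.grad(
            torch.unbind(output), inputs, **(grad_kwargs or {})
        )

    return gradient_func


# Example: forward extra options such as retain_graph to torch.autograd.grad.
x = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)
grad_fn = make_gradient_func(lambda t: (t * 2.0).sum(dim=1), {"retain_graph": True})
print(grad_fn(x))  # (tensor([[2., 2., 2.]]),)
```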
@@ -237,6 +251,7 @@ def attribute(
internal_batch_size: Union[None, int],
return_convergence_delta: Literal[False],
attribute_to_layer_input: bool,
grad_kwargs: Optional[Dict[str, Any]],
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ...

@overload
@@ -251,6 +266,7 @@ def attribute( # type: ignore
internal_batch_size: Union[None, int],
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool,
grad_kwargs: Optional[Dict[str, Any]],
) -> Tuple[
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
Tensor,
@@ -270,6 +286,7 @@ def attribute(
internal_batch_size: Union[None, int] = None,
return_convergence_delta: bool = False,
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
Tuple[
@@ -292,6 +309,7 @@ def attribute(
internal_batch_size: Union[None, int] = None,
return_convergence_delta: bool = False,
attribute_to_layer_input: bool = False,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]],
Tuple[
@@ -427,6 +445,9 @@ def attribute(
attribute to the input or output, is a single tensor.
Support for multiple tensors will be added later.
Default: False
grad_kwargs (Dict[str, Any], optional): Additional keyword
arguments for torch.autograd.grad.
Default: None

Returns:
**attributions** or 2-element tuple of **attributions**, **delta**:
Expand Down Expand Up @@ -523,7 +544,7 @@ def flatten_tuple(tup):
# inputs -> these inputs are scaled

self.ig.gradient_func = self._make_gradient_func(
num_outputs_cumsum, attribute_to_layer_input
num_outputs_cumsum, attribute_to_layer_input, grad_kwargs
)
all_inputs = (
(inps + additional_forward_args)
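From the caller's side, the new parameter simply rides along with `attribute`. A hedged usage sketch under assumed names (the toy `TinyModel` below is made up for illustration; `grad_kwargs` itself is the parameter this diff adds):

```python
import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel()
lig = LayerIntegratedGradients(model, model.linear)
inputs = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)

# grad_kwargs is forwarded verbatim to torch.autograd.grad; materialize_grads
# (PyTorch >= 2.1) yields zero tensors instead of None for unused inputs.
attributions = lig.attribute(
    inputs, target=0, grad_kwargs={"materialize_grads": True}
)
print(attributions.shape)  # one attribution score per element of the layer output
```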
25 changes: 24 additions & 1 deletion tests/attr/layer/test_layer_integrated_gradients.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict

import unittest
from typing import Any, cast, List, Tuple, Union

import torch
@@ -13,6 +13,7 @@
configure_interpretable_embedding_layer,
remove_interpretable_embedding_layer,
)
from packaging import version
from tests.helpers.basic import (
assertTensorAlmostEqual,
assertTensorTuplesAlmostEqual,
@@ -229,6 +230,28 @@ def test_multiple_tensors_compare_with_exp_wo_mult_by_inputs(self) -> None:
attributions,
)

def test_simple_multi_gradient_activation_with_unused_layer(self) -> None:
if version.parse(torch.__version__) < version.parse("2.1.0"):
raise unittest.SkipTest(
"Skipping unused layed gradient test since it is not supported "
"by torch version < 2.1"
)

model = BasicModel_MultiLayer(multi_input_module=True)
test_input = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True)
# pyre-fixme[6]: For 2nd argument expected `ModuleOrModuleList` but got
# `List[Union[ReLU, Linear]]`.
layer_ig = LayerIntegratedGradients(model, [model.linear1, model.relu])
attributions = cast(
List[Tensor],
layer_ig.attribute(
inputs=test_input, target=0, grad_kwargs={"materialize_grads": True}
),
)
self.assertEqual(len(attributions), 2)
self.assertEqual(list(attributions[0].shape), [1, 4])
self.assertEqual(list(attributions[1].shape), [1, 4])
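For context on the `materialize_grads` flag exercised above: when one of the attributed layers does not feed into the selected target, `torch.autograd.grad` would otherwise return `None` for it (or raise without `allow_unused`). A standalone sketch of that behavior, independent of Captum and assuming PyTorch >= 2.1:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
unused = torch.tensor(2.0, requires_grad=True)
out = x * 3.0  # `unused` never enters the graph

# With materialize_grads=True (PyTorch >= 2.1), the gradient for the unused
# input is materialized as a zero tensor instead of None, so downstream code
# can treat all gradients uniformly.
grads = torch.autograd.grad(out, (x, unused), materialize_grads=True)
print(grads)  # (tensor(3.), tensor(0.))
```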

def _assert_compare_with_layer_conductance(
self, model: Module, input: Tensor, attribute_to_layer_input: bool = False
) -> None: