Commit 35f7af5

vivekmig authored and facebook-github-bot committed
Leaf Warning Fix (#597)
Summary:
This removes the resetting of the grad attribute to zero, which was causing the warnings reported in #491 and #421. Based on the torch [documentation](https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad), resetting grad is only needed when using torch.autograd.backward, which accumulates results into the grad attribute of leaf nodes. Since Captum only uses torch.autograd.grad (with only_inputs always set to True), the gradients it computes are never accumulated into grad attributes, so resetting the attribute is unnecessary. The sketch below illustrates the difference.

This also adds a test confirming that the grad attribute is not altered when gradients are computed through Saliency.

Pull Request resolved: #597

Reviewed By: bilalsal

Differential Revision: D26079970

Pulled By: vivekmig

fbshipit-source-id: f7ccee02a17f66ee75e2176f1b328672b057dbfa
1 parent dfb3f65 commit 35f7af5
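The behavioral difference the summary relies on is easy to check directly. A minimal standalone sketch (illustrative values, not Captum code) of the two autograd entry points:

# Minimal sketch, not Captum code; the tensor values are illustrative.
import torch

x = torch.tensor([2.0], requires_grad=True)
x.grad = torch.tensor([9.0])  # pre-existing grad on a leaf tensor

# torch.autograd.grad (only_inputs defaults to True) returns the gradient
# directly and leaves x.grad untouched -- hence no reset is needed.
(g,) = torch.autograd.grad((x * x).sum(), x)
print(g)       # tensor([4.])
print(x.grad)  # tensor([9.]) -- unchanged

# torch.autograd.backward, by contrast, accumulates into .grad; this is
# the code path where zeroing the attribute would matter.
(x * x).sum().backward()
print(x.grad)  # tensor([13.]) -- 9 + 4 accumulated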

File tree

3 files changed: +14 −20 lines changed


captum/_utils/gradient.py

Lines changed: 0 additions & 10 deletions
@@ -55,13 +55,6 @@ def apply_gradient_requirements(inputs: Tuple[Tensor, ...]) -> List[bool]:
                 "required_grads has been set automatically." % index
             )
             input.requires_grad_()
-        if input.grad is not None:
-            if torch.sum(torch.abs(input.grad)).item() > 1e-7:
-                warnings.warn(
-                    "Input Tensor %d had a non-zero gradient tensor, "
-                    "which is being reset to 0." % index
-                )
-            input.grad.zero_()
     return grad_required

@@ -84,9 +77,6 @@ def undo_gradient_requirements(
     ), "Input tuple length should match gradient mask."
     for index, input in enumerate(inputs):
         assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
-        if input.grad is not None:
-            input.grad.detach_()
-            input.grad.zero_()
         if not grad_required[index]:
            input.requires_grad_(False)

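For context, here is a hedged usage sketch of the helper pair edited above, based only on the signatures visible in this diff (the input tensors are made up): apply_gradient_requirements records which inputs already required gradients and enables requires_grad where missing; undo_gradient_requirements restores the original flags.

# Hedged sketch; relies only on the signatures shown in this diff.
import torch
from captum._utils.gradient import (
    apply_gradient_requirements,
    undo_gradient_requirements,
)

inputs = (
    torch.tensor([[1.0, 2.0]]),                 # does not require grad
    torch.tensor([[3.0]], requires_grad=True),  # already requires grad
)
grad_required = apply_gradient_requirements(inputs)  # -> [False, True]
# ... compute gradients here, e.g. via torch.autograd.grad ...
undo_gradient_requirements(inputs, grad_required)
# inputs[0].requires_grad is restored to False; inputs[1] stays True.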
tests/attr/test_saliency.py

Lines changed: 7 additions & 0 deletions
@@ -136,6 +136,13 @@ def test_saliency_classification_smoothgrad(self) -> None:
     def test_saliency_classification_vargrad(self) -> None:
         self._saliency_classification_assert(nt_type="vargrad")

+    def test_saliency_grad_unchanged(self) -> None:
+        model, inp, grads, add_args = _get_basic_config()
+        inp.grad = torch.randn_like(inp)
+        grad = inp.grad.detach().clone()
+        self._saliency_base_assert(model, inp, grads, add_args)
+        assertTensorTuplesAlmostEqual(self, inp.grad, grad, delta=0.0)
+
     def _saliency_base_assert(
         self,
         model: Module,

tests/utils/test_gradient.py

Lines changed: 7 additions & 10 deletions
@@ -1,9 +1,8 @@
 #!/usr/bin/env python3

-from typing import List, Tuple, cast
+from typing import List, Tuple

 import torch
-from torch import Tensor

 from captum._utils.gradient import (
     apply_gradient_requirements,
@@ -32,10 +31,6 @@ def test_apply_gradient_reqs(self) -> None:
         for i in range(len(test_tensor_tuple)):
             self.assertTrue(test_tensor_tuple[i].requires_grad)
             self.assertEqual(out_mask[i], initial_grads[i])
-            if test_tensor_tuple[i].grad is not None:
-                self.assertAlmostEqual(
-                    torch.sum(cast(Tensor, test_tensor_tuple[i].grad)).item(), 0.0
-                )

     def test_undo_gradient_reqs(self) -> None:
         initial_grads = [False, True, False]
@@ -49,22 +44,24 @@ def test_undo_gradient_reqs(self) -> None:
         undo_gradient_requirements(test_tensor_tuple, initial_grads)
         for i in range(len(test_tensor_tuple)):
             self.assertEqual(test_tensor_tuple[i].requires_grad, initial_grads[i])
-            if test_tensor_tuple[i].grad is not None:
-                self.assertAlmostEqual(
-                    torch.sum(cast(Tensor, test_tensor_tuple[i].grad)).item(), 0.0
-                )

     def test_gradient_basic(self) -> None:
         model = BasicModel()
         input = torch.tensor([[5.0]], requires_grad=True)
+        input.grad = torch.tensor([[9.0]])
         grads = compute_gradients(model, input)[0]
         assertArraysAlmostEqual(grads.squeeze(0).tolist(), [0.0], delta=0.01)
+        # Verify grad attribute is not altered
+        assertArraysAlmostEqual(input.grad.squeeze(0).tolist(), [9.0], delta=0.0)

     def test_gradient_basic_2(self) -> None:
         model = BasicModel()
         input = torch.tensor([[-3.0]], requires_grad=True)
+        input.grad = torch.tensor([[14.0]])
         grads = compute_gradients(model, input)[0]
         assertArraysAlmostEqual(grads.squeeze(0).tolist(), [1.0], delta=0.01)
+        # Verify grad attribute is not altered
+        assertArraysAlmostEqual(input.grad.squeeze(0).tolist(), [14.0], delta=0.0)

     def test_gradient_multiinput(self) -> None:
         model = BasicModel6_MultiTensor()
