diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py
index 1f02895110..6f28acc685 100644
--- a/captum/_utils/gradient.py
+++ b/captum/_utils/gradient.py
@@ -55,13 +55,6 @@ def apply_gradient_requirements(inputs: Tuple[Tensor, ...]) -> List[bool]:
                 "required_grads has been set automatically." % index
             )
             input.requires_grad_()
-        if input.grad is not None:
-            if torch.sum(torch.abs(input.grad)).item() > 1e-7:
-                warnings.warn(
-                    "Input Tensor %d had a non-zero gradient tensor, "
-                    "which is being reset to 0." % index
-                )
-            input.grad.zero_()
     return grad_required
 
 
@@ -84,9 +77,6 @@ def undo_gradient_requirements(
     ), "Input tuple length should match gradient mask."
     for index, input in enumerate(inputs):
         assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
-        if input.grad is not None:
-            input.grad.detach_()
-            input.grad.zero_()
         if not grad_required[index]:
             input.requires_grad_(False)
 
diff --git a/tests/attr/test_saliency.py b/tests/attr/test_saliency.py
index 12b2a7687f..38992f05c6 100644
--- a/tests/attr/test_saliency.py
+++ b/tests/attr/test_saliency.py
@@ -136,6 +136,13 @@ def test_saliency_classification_smoothgrad(self) -> None:
     def test_saliency_classification_vargrad(self) -> None:
         self._saliency_classification_assert(nt_type="vargrad")
 
+    def test_saliency_grad_unchanged(self) -> None:
+        model, inp, grads, add_args = _get_basic_config()
+        inp.grad = torch.randn_like(inp)
+        grad = inp.grad.detach().clone()
+        self._saliency_base_assert(model, inp, grads, add_args)
+        assertTensorTuplesAlmostEqual(self, inp.grad, grad, delta=0.0)
+
     def _saliency_base_assert(
         self,
         model: Module,
diff --git a/tests/utils/test_gradient.py b/tests/utils/test_gradient.py
index 192894bd93..ceff1616cb 100644
--- a/tests/utils/test_gradient.py
+++ b/tests/utils/test_gradient.py
@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 
-from typing import List, Tuple, cast
+from typing import List, Tuple
 
 import torch
-from torch import Tensor
 
 from captum._utils.gradient import (
     apply_gradient_requirements,
@@ -32,10 +31,6 @@ def test_apply_gradient_reqs(self) -> None:
         for i in range(len(test_tensor_tuple)):
             self.assertTrue(test_tensor_tuple[i].requires_grad)
             self.assertEqual(out_mask[i], initial_grads[i])
-            if test_tensor_tuple[i].grad is not None:
-                self.assertAlmostEqual(
-                    torch.sum(cast(Tensor, test_tensor_tuple[i].grad)).item(), 0.0
-                )
 
     def test_undo_gradient_reqs(self) -> None:
         initial_grads = [False, True, False]
@@ -49,22 +44,24 @@ def test_undo_gradient_reqs(self) -> None:
         undo_gradient_requirements(test_tensor_tuple, initial_grads)
         for i in range(len(test_tensor_tuple)):
             self.assertEqual(test_tensor_tuple[i].requires_grad, initial_grads[i])
-            if test_tensor_tuple[i].grad is not None:
-                self.assertAlmostEqual(
-                    torch.sum(cast(Tensor, test_tensor_tuple[i].grad)).item(), 0.0
-                )
 
     def test_gradient_basic(self) -> None:
         model = BasicModel()
         input = torch.tensor([[5.0]], requires_grad=True)
+        input.grad = torch.tensor([[9.0]])
         grads = compute_gradients(model, input)[0]
         assertArraysAlmostEqual(grads.squeeze(0).tolist(), [0.0], delta=0.01)
+        # Verify grad attribute is not altered
+        assertArraysAlmostEqual(input.grad.squeeze(0).tolist(), [9.0], delta=0.0)
 
     def test_gradient_basic_2(self) -> None:
         model = BasicModel()
         input = torch.tensor([[-3.0]], requires_grad=True)
+        input.grad = torch.tensor([[14.0]])
         grads = compute_gradients(model, input)[0]
         assertArraysAlmostEqual(grads.squeeze(0).tolist(), [1.0], delta=0.01)
+        # Verify grad attribute is not altered
+        assertArraysAlmostEqual(input.grad.squeeze(0).tolist(), [14.0], delta=0.0)
 
     def test_gradient_multiinput(self) -> None:
         model = BasicModel6_MultiTensor()