From cd685655754234813c22dced936e0392d668aa9b Mon Sep 17 00:00:00 2001 From: Zach Carmichael Date: Fri, 13 Sep 2024 19:03:13 -0700 Subject: [PATCH] Ensure autograd graphs are freed between attribute calls Summary: Some attributions returned by gradient-based methods still have a `grad_fn` from autograd (e.g. `LayerGradientXActivation`). This diff ensures that the autograd graph is freed between attribute calls within `LLMGradientAttribution` to eliminate this as a potential source of VRAM accumulation. Also wrapped `model.generate` with a `no_grad` context to avoid unecessary memory usage. Differential Revision: D62671994 --- captum/attr/_core/llm_attr.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index b4f5738e92..3832ff4a95 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -574,9 +574,10 @@ def attribute( if not gen_args: gen_args = DEFAULT_GEN_ARGS - model_inp = self._format_model_input(inp.to_model_input()) - output_tokens = self.model.generate(model_inp, **gen_args) - target_tokens = output_tokens[0][model_inp.size(1) :] + with torch.no_grad(): + model_inp = self._format_model_input(inp.to_model_input()) + output_tokens = self.model.generate(model_inp, **gen_args) + target_tokens = output_tokens[0][model_inp.size(1) :] else: assert gen_args is None, "gen_args must be None when target is given" @@ -605,7 +606,7 @@ def attribute( cur_target_idx, ), **kwargs, - ) + ).detach() attr = cast(Tensor, attr) # will have the attr for previous output tokens