diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index b4f5738e92..3832ff4a95 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -574,9 +574,10 @@ def attribute(
             if not gen_args:
                 gen_args = DEFAULT_GEN_ARGS
 
-            model_inp = self._format_model_input(inp.to_model_input())
-            output_tokens = self.model.generate(model_inp, **gen_args)
-            target_tokens = output_tokens[0][model_inp.size(1) :]
+            with torch.no_grad():
+                model_inp = self._format_model_input(inp.to_model_input())
+                output_tokens = self.model.generate(model_inp, **gen_args)
+                target_tokens = output_tokens[0][model_inp.size(1) :]
         else:
             assert gen_args is None, "gen_args must be None when target is given"
 
@@ -605,7 +606,7 @@ def attribute(
                     cur_target_idx,
                 ),
                 **kwargs,
-            )
+            ).detach()
             attr = cast(Tensor, attr)
 
             # will have the attr for previous output tokens
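
For context, this patch targets two sources of unnecessary memory growth during LLM attribution: the forward passes inside model.generate() do not need gradients, and keeping each per-token attribution tensor attached to its autograd graph prevents that graph from being freed. A minimal standalone sketch of the same pattern (assuming a PyTorch model with a Hugging Face-style generate(); the attribute_one_token helper is hypothetical, not Captum's API):

    import torch

    def generate_target_tokens(model, model_inp, **gen_args):
        # Generation is pure inference: running it under torch.no_grad()
        # skips building the autograd graph, so intermediate activations
        # are released as soon as each forward pass finishes.
        with torch.no_grad():
            output_tokens = model.generate(model_inp, **gen_args)
        # Keep only the continuation that follows the prompt.
        return output_tokens[0][model_inp.size(1):]

    def collect_attributions(num_target_tokens, attribute_one_token):
        attrs = []
        for cur_target_idx in range(num_target_tokens):
            attr = attribute_one_token(cur_target_idx)  # hypothetical; returns a Tensor
            # .detach() returns a tensor sharing the same storage but cut off
            # from the autograd graph, so the graph built for this step can be
            # garbage-collected instead of accumulating across output tokens.
            attrs.append(attr.detach())
        return attrs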