diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index d692873328..be9d793f5e 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -232,19 +232,32 @@ def _forward_func(
         perturbed_tensor,
         inp,
         target_tokens,
-        _inspect_forward,
+        use_cached_outputs=False,
+        _inspect_forward=None,
     ):
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
         model_inp = init_model_inp
+        model_kwargs = {"attention_mask": torch.tensor([[1] * model_inp.shape[1]])}
 
         log_prob_list = []
+        outputs = None
         for target_token in target_tokens:
-            output_logits = self.model.forward(
-                model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
-            )
-            new_token_logits = output_logits.logits[:, -1]
+            if use_cached_outputs:
+                if outputs is not None:
+                    model_kwargs = self.model._update_model_kwargs_for_generation(
+                        outputs, model_kwargs
+                    )
+                model_inputs = self.model.prepare_inputs_for_generation(
+                    model_inp, **model_kwargs
+                )
+                outputs = self.model.forward(**model_inputs)
+            else:
+                outputs = self.model.forward(
+                    model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
+                )
+            new_token_logits = outputs.logits[:, -1]
             log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
 
             log_prob_list.append(log_probs[0][target_token].detach())
@@ -292,6 +305,7 @@ def attribute(
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
         gen_args: Optional[Dict] = None,
+        use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
         _inspect_forward: Optional[Callable] = None,
         **kwargs,
@@ -360,7 +374,12 @@
 
             cur_attr = self.attr_method.attribute(
                 attr_input,
-                additional_forward_args=(inp, target_tokens, _inspect_forward),
+                additional_forward_args=(
+                    inp,
+                    target_tokens,
+                    use_cached_outputs,
+                    _inspect_forward,
+                ),
                 **kwargs,
             )
 
diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py
index 17aca630a9..87a47e68a4 100644
--- a/tests/attr/test_llm_attr.py
+++ b/tests/attr/test_llm_attr.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import copy
 from collections import namedtuple
 from typing import cast, List, Optional, Union
 
@@ -53,9 +54,11 @@ def __init__(self):
 
     def forward(self, input_ids, *args, **kwargs):
         emb = self.emb(input_ids)
+        if "past_key_values" in kwargs:
+            emb = torch.cat((kwargs["past_key_values"], emb), dim=1)
         logits = self.linear(self.trans(emb))
-        Result = namedtuple("Result", ["logits"])
-        return Result(logits=logits)
+        Result = namedtuple("Result", ["logits", "past_key_values"])
+        return Result(logits=logits, past_key_values=emb)
 
     def generate(self, input_ids, *args, mock_response=None, **kwargs):
         assert mock_response, "must mock response to use DummyLLM to geenrate"
@@ -64,16 +67,35 @@ def generate(self, input_ids, *args, mock_response=None, **kwargs):
         response = self.tokenizer.encode(mock_response)[1:]
         return torch.cat(
             [input_ids, torch.tensor([response], device=self.device)], dim=1
         )
 
+    def _update_model_kwargs_for_generation(self, outputs, model_kwargs):
+        new_kwargs = copy.deepcopy(model_kwargs)
+        if hasattr(outputs, "past_key_values"):
+            new_kwargs["past_key_values"] = outputs.past_key_values
+        return new_kwargs
+
+    def prepare_inputs_for_generation(self, model_inp, **model_kwargs):
+        if "past_key_values" in model_kwargs:
+            emb_len = model_kwargs["past_key_values"].shape[1]
+            return {
+                "input_ids": model_inp[:, emb_len:],
+                "past_key_values": model_kwargs["past_key_values"],
+            }
+        return {"input_ids": model_inp}
+
     @property
     def device(self):
         return next(self.parameters()).device
 
 
 @parameterized_class(
-    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+    ("device", "use_cached_outputs"),
+    [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
+    if torch.cuda.is_available()
+    else [("cpu", True), ("cpu", False)],
 )
 class TestLLMAttr(BaseTest):
     device: str
+    use_cached_outputs: bool
 
     @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
     def test_llm_attr(self, AttrClass) -> None:
@@ -83,7 +105,9 @@ def test_llm_attr(self, AttrClass) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
         llm_attr = LLMAttribution(AttrClass(llm), tokenizer)
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )
 
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -100,7 +124,11 @@ def test_llm_attr_without_target(self) -> None:
         llm_fa = LLMAttribution(fa, tokenizer)
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_fa.attribute(
+            inp,
+            gen_args={"mock_response": "x y z"},
+            use_cached_outputs=self.use_cached_outputs,
+        )
 
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
@@ -117,7 +145,9 @@ def test_llm_attr_fa_log_prob(self) -> None:
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )
 
         # With FeatureAblation, the seq attr in log_prob
         # equals to the sum of each token attr
@@ -132,7 +162,9 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )
 
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(res.seq_attr.device.type, self.device)
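Usage sketch (not part of the patch): the snippet below mirrors the updated tests and exercises the new use_cached_outputs flag end to end. It assumes the captum repo root is on PYTHONPATH so the DummyLLM and DummyTokenizer test helpers are importable, and that the module paths match the ones used in tests/attr/test_llm_attr.py.

# Minimal illustration of the new use_cached_outputs flag, following the tests above.
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from tests.attr.test_llm_attr import DummyLLM, DummyTokenizer

llm = DummyLLM()
tokenizer = DummyTokenizer()
llm_attr = LLMAttribution(FeatureAblation(llm), tokenizer)

inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])

# use_cached_outputs=True (the new default of attribute()) walks the target tokens via
# prepare_inputs_for_generation / _update_model_kwargs_for_generation so the model can
# reuse past_key_values; use_cached_outputs=False keeps the old behavior of re-running
# the full sequence for every target token.
res = llm_attr.attribute(inp, "m n o p q", use_cached_outputs=True)
print(res.seq_attr.shape)  # torch.Size([4]), one attribution per template feature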