Use past_key_values to speed up multi-token target LLM Attribution #1224
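For context, a minimal usage sketch of the new flag from the caller's side (illustrative only: the model name, template, and target string are not part of this PR, and a Hugging Face causal LM with the standard generation helpers is assumed):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

# Illustrative model; any causal LM exposing prepare_inputs_for_generation
# and _update_model_kwargs_for_generation should work.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTemplateInput("{} lives in {} and works as a {}", ["Dave", "Palm Coast", "lawyer"])

# use_cached_outputs=True reuses past_key_values across the multi-token target,
# so each step feeds only the newest token through the model instead of
# re-encoding the whole prefix.
res = llm_attr.attribute(inp, target="He enjoys the beach", use_cached_outputs=True)
print(res.seq_attr.shape)  # one attribution score per template value
```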

Closed · wants to merge 3 commits
31 changes: 25 additions & 6 deletions captum/attr/_core/llm_attr.py
@@ -232,19 +232,32 @@ def _forward_func(
perturbed_tensor,
inp,
target_tokens,
_inspect_forward,
use_cached_outputs=False,
_inspect_forward=None,
):
perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
init_model_inp = perturbed_input

model_inp = init_model_inp
model_kwargs = {"attention_mask": torch.tensor([[1] * model_inp.shape[1]])}

log_prob_list = []
outputs = None
for target_token in target_tokens:
output_logits = self.model.forward(
model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
)
new_token_logits = output_logits.logits[:, -1]
if use_cached_outputs:
if outputs is not None:
model_kwargs = self.model._update_model_kwargs_for_generation(
outputs, model_kwargs
)
model_inputs = self.model.prepare_inputs_for_generation(
model_inp, **model_kwargs
)
outputs = self.model.forward(**model_inputs)
else:
outputs = self.model.forward(
model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
)
new_token_logits = outputs.logits[:, -1]
log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

log_prob_list.append(log_probs[0][target_token].detach())
@@ -292,6 +305,7 @@ def attribute(
target: Union[str, torch.Tensor, None] = None,
num_trials: int = 1,
gen_args: Optional[Dict] = None,
use_cached_outputs: bool = True,
# internal callback hook can be used for logging
_inspect_forward: Optional[Callable] = None,
**kwargs,
@@ -360,7 +374,12 @@

cur_attr = self.attr_method.attribute(
attr_input,
additional_forward_args=(inp, target_tokens, _inspect_forward),
additional_forward_args=(
inp,
target_tokens,
use_cached_outputs,
_inspect_forward,
),
**kwargs,
)

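The cached branch above relies on two helpers from transformers' GenerationMixin, `prepare_inputs_for_generation` and `_update_model_kwargs_for_generation`, which the DummyLLM in the tests below stubs out. A rough, consolidated sketch of the loop it implements (conceptual only; variable names follow the diff, and a transformers-style causal LM is assumed):

```python
import torch

model_kwargs = {"attention_mask": torch.ones_like(model_inp)}
outputs = None
for target_token in target_tokens:
    if outputs is not None:
        # Carry past_key_values (and the extended attention mask) forward.
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs)
    # Keep only the tokens not already covered by the cache.
    model_inputs = model.prepare_inputs_for_generation(model_inp, **model_kwargs)
    outputs = model(**model_inputs)
    log_prob = torch.log_softmax(outputs.logits[:, -1], dim=-1)[0, target_token]
    # Append the scored target token before the next step.
    new_tok = torch.tensor([[target_token]], device=model_inp.device)
    model_inp = torch.cat([model_inp, new_tok], dim=1)
```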
46 changes: 39 additions & 7 deletions tests/attr/test_llm_attr.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import copy
from collections import namedtuple
from typing import cast, List, Optional, Union

@@ -53,9 +54,11 @@ def __init__(self):

def forward(self, input_ids, *args, **kwargs):
emb = self.emb(input_ids)
if "past_key_values" in kwargs:
emb = torch.cat((kwargs["past_key_values"], emb), dim=1)
logits = self.linear(self.trans(emb))
Result = namedtuple("Result", ["logits"])
return Result(logits=logits)
Result = namedtuple("Result", ["logits", "past_key_values"])
return Result(logits=logits, past_key_values=emb)

def generate(self, input_ids, *args, mock_response=None, **kwargs):
assert mock_response, "must mock response to use DummyLLM to generate"
@@ -64,16 +67,35 @@ def generate(self, input_ids, *args, mock_response=None, **kwargs):
[input_ids, torch.tensor([response], device=self.device)], dim=1
)

def _update_model_kwargs_for_generation(self, outputs, model_kwargs):
new_kwargs = copy.deepcopy(model_kwargs)
if hasattr(outputs, "past_key_values"):
new_kwargs["past_key_values"] = outputs.past_key_values
return new_kwargs

def prepare_inputs_for_generation(self, model_inp, **model_kwargs):
if "past_key_values" in model_kwargs:
emb_len = model_kwargs["past_key_values"].shape[1]
return {
"input_ids": model_inp[:, emb_len:],
"past_key_values": model_kwargs["past_key_values"],
}
return {"input_ids": model_inp}

@property
def device(self):
return next(self.parameters()).device


@parameterized_class(
("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
("device", "use_cached_outputs"),
[("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
if torch.cuda.is_available()
else [("cpu", True), ("cpu", False)],
)
class TestLLMAttr(BaseTest):
device: str
use_cached_outputs: bool

@parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
def test_llm_attr(self, AttrClass) -> None:
@@ -83,7 +105,9 @@ def test_llm_attr(self, AttrClass) -> None:
llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_attr.attribute(inp, "m n o p q")
res = llm_attr.attribute(
inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
)

self.assertEqual(res.seq_attr.shape, (4,))
self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -100,7 +124,11 @@ def test_llm_attr_without_target(self) -> None:
llm_fa = LLMAttribution(fa, tokenizer)

inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_fa.attribute(inp, gen_args={"mock_response": "x y z"})
res = llm_fa.attribute(
inp,
gen_args={"mock_response": "x y z"},
use_cached_outputs=self.use_cached_outputs,
)

self.assertEqual(res.seq_attr.shape, (4,))
self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
@@ -117,7 +145,9 @@ def test_llm_attr_fa_log_prob(self) -> None:
llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")

inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_fa.attribute(inp, "m n o p q")
res = llm_fa.attribute(
inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
)

# With FeatureAblation, the seq attr in log_prob
# equals the sum of each token attr
@@ -132,7 +162,9 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")

inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_fa.attribute(inp, "m n o p q")
res = llm_fa.attribute(
inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
)

self.assertEqual(res.seq_attr.shape, (4,))
self.assertEqual(res.seq_attr.device.type, self.device)