
Commit fefcb5b

vivekmig authored and facebook-github-bot committed
Use past_key_values to speed up multi-token target LLM Attribution (#1224)
Summary: Default generation in transformers uses past_key_values to cache the key/value tensors of previous tokens, which speeds up forward passes for subsequent tokens. This change adds a flag and uses the corresponding helpers from the transformers generation utilities to follow the same caching approach. Enabling the flag yields roughly a 10x speedup with 10 target tokens, and the improvement appears to scale with the number of target tokens.

Pull Request resolved: #1224
Reviewed By: aobo-y
Differential Revision: D52240469
Pulled By: vivekmig
fbshipit-source-id: e643458529091fb5540b0b0a374ceb0c2c25e394
1 parent 68d88cf commit fefcb5b
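
For orientation, here is a minimal caller-side sketch of the new flag, not part of this commit's diff: it assumes the captum.attr public API of the 0.7 release, and "gpt2" is only a placeholder Hugging Face causal LM.

from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

# Placeholder model/tokenizer; any causal LM that returns past_key_values should work.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTemplateInput("{} lives in {} and studies {}.", ["Dave", "Palm Coast", "medicine"])

# use_cached_outputs (added in this commit, default True) reuses past_key_values
# across the per-token forward passes when attributing a multi-token target.
res = llm_attr.attribute(inp, target="the hospital", use_cached_outputs=True)
print(res.seq_attr.shape)  # one attribution score per template value

Passing use_cached_outputs=False falls back to re-running the full sequence for every target token; the updated tests below exercise both modes.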

2 files changed: 64 additions, 13 deletions

captum/attr/_core/llm_attr.py

Lines changed: 25 additions & 6 deletions
@@ -232,19 +232,32 @@ def _forward_func(
         perturbed_tensor,
         inp,
         target_tokens,
-        _inspect_forward,
+        use_cached_outputs=False,
+        _inspect_forward=None,
     ):
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
         model_inp = init_model_inp
+        model_kwargs = {"attention_mask": torch.tensor([[1] * model_inp.shape[1]])}
 
         log_prob_list = []
+        outputs = None
         for target_token in target_tokens:
-            output_logits = self.model.forward(
-                model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
-            )
-            new_token_logits = output_logits.logits[:, -1]
+            if use_cached_outputs:
+                if outputs is not None:
+                    model_kwargs = self.model._update_model_kwargs_for_generation(
+                        outputs, model_kwargs
+                    )
+                model_inputs = self.model.prepare_inputs_for_generation(
+                    model_inp, **model_kwargs
+                )
+                outputs = self.model.forward(**model_inputs)
+            else:
+                outputs = self.model.forward(
+                    model_inp, attention_mask=torch.tensor([[1] * model_inp.shape[1]])
+                )
+            new_token_logits = outputs.logits[:, -1]
             log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
 
             log_prob_list.append(log_probs[0][target_token].detach())

@@ -292,6 +305,7 @@ def attribute(
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
         gen_args: Optional[Dict] = None,
+        use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
         _inspect_forward: Optional[Callable] = None,
         **kwargs,

@@ -360,7 +374,12 @@
 
             cur_attr = self.attr_method.attribute(
                 attr_input,
-                additional_forward_args=(inp, target_tokens, _inspect_forward),
+                additional_forward_args=(
+                    inp,
+                    target_tokens,
+                    use_cached_outputs,
+                    _inspect_forward,
+                ),
                 **kwargs,
             )
 
tests/attr/test_llm_attr.py

Lines changed: 39 additions & 7 deletions
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import copy
 from collections import namedtuple
 from typing import cast, List, Optional, Union
 

@@ -53,9 +54,11 @@ def __init__(self):
 
     def forward(self, input_ids, *args, **kwargs):
         emb = self.emb(input_ids)
+        if "past_key_values" in kwargs:
+            emb = torch.cat((kwargs["past_key_values"], emb), dim=1)
         logits = self.linear(self.trans(emb))
-        Result = namedtuple("Result", ["logits"])
-        return Result(logits=logits)
+        Result = namedtuple("Result", ["logits", "past_key_values"])
+        return Result(logits=logits, past_key_values=emb)
 
     def generate(self, input_ids, *args, mock_response=None, **kwargs):
         assert mock_response, "must mock response to use DummyLLM to geenrate"

@@ -64,16 +67,35 @@ def generate(self, input_ids, *args, mock_response=None, **kwargs):
             [input_ids, torch.tensor([response], device=self.device)], dim=1
         )
 
+    def _update_model_kwargs_for_generation(self, outputs, model_kwargs):
+        new_kwargs = copy.deepcopy(model_kwargs)
+        if hasattr(outputs, "past_key_values"):
+            new_kwargs["past_key_values"] = outputs.past_key_values
+        return new_kwargs
+
+    def prepare_inputs_for_generation(self, model_inp, **model_kwargs):
+        if "past_key_values" in model_kwargs:
+            emb_len = model_kwargs["past_key_values"].shape[1]
+            return {
+                "input_ids": model_inp[:, emb_len:],
+                "past_key_values": model_kwargs["past_key_values"],
+            }
+        return {"input_ids": model_inp}
+
     @property
     def device(self):
         return next(self.parameters()).device
 
 
 @parameterized_class(
-    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+    ("device", "use_cached_outputs"),
+    [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
+    if torch.cuda.is_available()
+    else [("cpu", True), ("cpu", False)],
 )
 class TestLLMAttr(BaseTest):
     device: str
+    use_cached_outputs: bool
 
     @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
     def test_llm_attr(self, AttrClass) -> None:

@@ -83,7 +105,9 @@ def test_llm_attr(self, AttrClass) -> None:
         llm_attr = LLMAttribution(AttrClass(llm), tokenizer)
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )
 
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))

@@ -100,7 +124,11 @@ def test_llm_attr_without_target(self) -> None:
         llm_fa = LLMAttribution(fa, tokenizer)
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_fa.attribute(
+            inp,
+            gen_args={"mock_response": "x y z"},
+            use_cached_outputs=self.use_cached_outputs,
+        )
 
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))

@@ -117,7 +145,9 @@ def test_llm_attr_fa_log_prob(self) -> None:
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )
 
         # With FeatureAblation, the seq attr in log_prob
         # equals to the sum of each token attr

@@ -132,7 +162,9 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )
 
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(res.seq_attr.device.type, self.device)
