
Commit 36a3549

csauper authored and facebook-github-bot committed

stop stripping first character from output string (#1351)

Summary: Pull Request resolved: #1351. Previously the first character was stripped as an SOS token, but that does not actually seem to be the case with current LLMs. Keep all tokens.

Differential Revision: D62775617

1 parent 70619a6 commit 36a3549
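
For orientation, a hedged usage sketch of the change from the caller's side: attribute no longer drops the first encoded token of the target implicitly, and callers that still want special tokens removed now pass skip_tokens explicitly. The model and tokenizer names below are illustrative assumptions, not part of this commit.

# Illustrative sketch only; the model/tokenizer names are assumptions,
# not part of this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTemplateInput("{} lives in {}", ["Dave", "Palm Coast"])

# Before this change, the first encoded token of `target` was silently
# dropped as an assumed SOS token; now all tokens are kept unless they
# are listed in skip_tokens.
res = llm_attr.attribute(
    inp,
    target="FL, and he works as a lifeguard",
    skip_tokens=[tokenizer.bos_token_id],  # drop BOS explicitly, if desired
)
print(res.seq_attr.shape)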

3 files changed: 133 additions and 16 deletions

captum/attr/_core/llm_attr.py
Lines changed: 42 additions & 6 deletions

@@ -383,6 +383,7 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
         num_trials: int = 1,
         gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
@@ -397,6 +398,12 @@ def attribute(
                 which attributions are computed. If None, it uses the model
                 to generate the target based on the input and gen_args.
                 Default: None
+            skip_tokens (List[int] or List[str], optional): the tokens to skip in
+                the output's interpretable representation. Use this argument to define
+                uninteresting tokens, commonly special tokens, e.g., sos and unk.
+                It can be a list of strings of the tokens or a list of integers of the
+                token ids.
+                Default: None
             num_trials (int, optional): number of trials to run. Return is the average
                 attribibutions over all the trials.
                 Defaults: 1.
@@ -435,9 +442,20 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"

         if type(target) is str:
-            # exclude sos
-            target_tokens = self.tokenizer.encode(target)[1:]
-            target_tokens = torch.tensor(target_tokens)
+            encoded = self.tokenizer.encode(target)
+
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+                assert isinstance(skip_tokens, list)
+
+                skip_token_set = set(skip_tokens)
+                encoded = [
+                    token for token in encoded if token not in skip_token_set
+                ]
+
+            target_tokens = torch.tensor(encoded)
         elif type(target) is torch.Tensor:
             target_tokens = target
         else:
@@ -562,6 +580,7 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
        gen_args: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> LLMAttributionResult:
@@ -572,6 +591,12 @@ def attribute(
                 which attributions are computed. If None, it uses the model
                 to generate the target based on the input and gen_args.
                 Default: None
+            skip_tokens (List[int] or List[str], optional): the tokens to skip in
+                the output's interpretable representation. Use this argument to define
+                uninteresting tokens, commonly special tokens, e.g., sos and unk.
+                It can be a list of strings of the tokens or a list of integers of the
+                token ids.
+                Default: None
             gen_args (dict, optional): arguments for generating the target. Only used if
                 target is not given. When None, the default arguments are used,
                 {"max_new_tokens": 25, "do_sample": False,
@@ -607,9 +632,20 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"

         if type(target) is str:
-            # exclude sos
-            target_tokens = self.tokenizer.encode(target)[1:]
-            target_tokens = torch.tensor(target_tokens)
+            encoded = self.tokenizer.encode(target)
+
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+                assert isinstance(skip_tokens, list)
+
+                skip_token_set = set(skip_tokens)
+                encoded = [
+                    token for token in encoded if token not in skip_token_set
+                ]
+
+            target_tokens = torch.tensor(encoded)
         elif type(target) is torch.Tensor:
             target_tokens = target
         else:
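
The core of the change above is the replacement of the unconditional [1:] slice with an explicit, optional filter over the encoded target. A minimal self-contained sketch of that filtering step, assuming only a tokenizer that exposes encode and convert_tokens_to_ids (the helper name encode_target is ours, not Captum's):

from typing import List, Union

import torch


def encode_target(
    tokenizer,
    target: str,
    skip_tokens: Union[List[int], List[str], None] = None,
) -> torch.Tensor:
    # Keep every token of the target by default (no implicit SOS stripping).
    encoded = tokenizer.encode(target)

    if skip_tokens:
        # Token strings are mapped to ids before filtering.
        if isinstance(skip_tokens[0], str):
            skip_tokens = tokenizer.convert_tokens_to_ids(skip_tokens)
        skip_token_set = set(skip_tokens)
        encoded = [tok for tok in encoded if tok not in skip_token_set]

    return torch.tensor(encoded)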

tests/attr/test_llm_attr.py
Lines changed: 50 additions & 4 deletions

@@ -95,7 +95,10 @@ def convert_tokens_to_ids(
         raise NotImplementedError

     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        if isinstance(tokens, list):
+            tokens = " ".join(tokens)
+        return tokens


 class Result(NamedTuple):
@@ -271,6 +274,7 @@ def test_llm_attr(
         res = llm_attr.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             #  for 4th positional argument, expected
@@ -330,7 +334,10 @@ def test_llm_attr_fa_log_prob(self) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         # With FeatureAblation, the seq attr in log_prob
@@ -385,6 +392,7 @@ def test_llm_attr_without_token(
         res = llm_fa.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             #  for 4th positional argument, expected
@@ -448,7 +456,7 @@ def test_llm_attr(
         )

         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -523,7 +531,7 @@ def test_llm_attr_with_skip_tokens(
         )

         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
@@ -537,3 +545,41 @@ def test_llm_attr_with_skip_tokens(
         self.assertEqual(res.seq_attr.device.type, self.device)
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr_with_no_skip_tokens(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+
+        # 6 output tokens including sos, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertEqual(token_attr.device.type, self.device)  # type: ignore
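
The DummyTokenizer.decode tweak in these tests handles the fact that convert_ids_to_tokens may return either a single token string (for a scalar id) or a list of strings, so joining is only valid in the list case. A standalone sketch of that pattern, with a hypothetical stand-in class (not the test's actual DummyTokenizer):

from typing import List, Union

import torch


class StubTokenizer:
    # Hypothetical stand-in; vocabulary chosen only for illustration.
    vocab = {0: "<sos>", 1: "a", 2: "b", 3: "c"}

    def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
        if isinstance(ids, int):
            return self.vocab[ids]  # scalar id -> single token string
        return [self.vocab[i] for i in ids]  # list of ids -> list of tokens

    def decode(self, token_ids: torch.Tensor) -> str:
        tokens = self.convert_ids_to_tokens(token_ids.tolist())
        if isinstance(tokens, list):  # join only when we actually got a list
            tokens = " ".join(tokens)
        return tokens


print(StubTokenizer().decode(torch.tensor([0, 1, 2])))  # "<sos> a b"
print(StubTokenizer().decode(torch.tensor(2)))  # "b"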

tests/attr/test_llm_attr_gpu.py
Lines changed: 41 additions & 6 deletions

@@ -84,7 +84,10 @@ def convert_tokens_to_ids(
         raise NotImplementedError

     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        if isinstance(tokens, list):
+            tokens = " ".join(tokens)
+        return tokens


 class Result(NamedTuple):
@@ -195,7 +198,10 @@ def test_llm_attr_gpu(self, AttrClass: Type[PerturbationAttribution]) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_attr.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -234,7 +240,10 @@ def test_llm_attr_fa_log_prob_gpu(self) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         # With FeatureAblation, the seq attr in log_prob
@@ -253,7 +262,10 @@ def test_llm_attr_without_token_gpu(

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         self.assertEqual(res.seq_attr.shape, (4,))
@@ -280,7 +292,7 @@ def test_llm_attr(self) -> None:
         llm_attr = LLMGradientAttribution(attr, tokenizer)

         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
         assert res.token_attr is not None  # make pyre/mypy happy
@@ -324,7 +336,7 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         llm_attr = LLMGradientAttribution(attr, tokenizer)

         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
@@ -338,3 +350,26 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         self.assertEqual(res.seq_attr.device.type, self.device)
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+
+    def test_llm_attr_with_no_skip_tokens(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = LayerIntegratedGradients(llm, llm.emb)
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(inp, "m n o p q")
+
+        # 6 output tokens including sos, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertEqual(token_attr.device.type, self.device)  # type: ignore
