Commit 7b80c5b

csauper authored and facebook-github-bot committed
stop stripping first character from output string (#1351)
Summary:
Pull Request resolved: #1351

Previously the first character of the target was stripped as an assumed SOS token, but that does not actually seem to be the case with current LLMs. Keep all tokens; callers can instead exclude special tokens explicitly via the new skip_tokens argument.

Reviewed By: craymichael

Differential Revision: D62775617

fbshipit-source-id: 9a0edd83318ac39654a1526020111991005fa4f3
1 parent 70619a6 commit 7b80c5b

3 files changed: +165 −22 lines
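To make the change concrete, here is a minimal usage sketch of the skip_tokens argument introduced below. It is not part of this commit: the gpt2 checkpoint, the prompt, and the target string are illustrative assumptions, and it presumes FeatureAblation, LLMAttribution, and TextTokenInput are importable from captum.attr as in recent Captum releases.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

    # Illustrative setup: any causal LM should do; gpt2 is just a small example choice.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
    inp = TextTokenInput("The capital of France is", tokenizer)

    # Before this commit the first target token was silently dropped as an
    # assumed SOS token. Now every target token is kept; special tokens must
    # be excluded explicitly via skip_tokens (token ids or token strings).
    res = llm_attr.attribute(
        inp,
        target=" Paris",
        skip_tokens=[tokenizer.eos_token_id],  # gpt2 has no SOS; shown only for illustration
    )
    print(res.seq_attr, res.output_tokens)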

captum/attr/_core/llm_attr.py

Lines changed: 46 additions & 12 deletions
@@ -383,6 +383,7 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
         num_trials: int = 1,
         gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
@@ -397,6 +398,12 @@ def attribute(
                 which attributions are computed. If None, it uses the model
                 to generate the target based on the input and gen_args.
                 Default: None
+            skip_tokens (List[int] or List[str], optional): the tokens to skip in the
+                the output's interpretable representation. Use this argument to
+                define uninterested tokens, commonly like special tokens, e.g.,
+                sos, and unk. It can be a list of strings of the tokens or a list
+                of integers of the token ids.
+                Default: None
             num_trials (int, optional): number of trials to run. Return is the average
                 attribibutions over all the trials.
                 Defaults: 1.
@@ -433,13 +440,23 @@ def attribute(
             target_tokens = output_tokens[0][model_inp.size(1) :]
         else:
             assert gen_args is None, "gen_args must be None when target is given"
+            # Encode skip tokens
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+            else:
+                skip_tokens = []

-            if type(target) is str:
-                # exclude sos
-                target_tokens = self.tokenizer.encode(target)[1:]
-                target_tokens = torch.tensor(target_tokens)
-            elif type(target) is torch.Tensor:
-                target_tokens = target
+            if isinstance(target, str):
+                encoded = self.tokenizer.encode(target)
+                target_tokens = torch.tensor(
+                    [token for token in encoded if token not in skip_tokens]
+                )
+            elif isinstance(target, torch.Tensor):
+                target_tokens = target[
+                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
+                ]
             else:
                 raise TypeError(
                     "target must either be str or Tensor, but the type of target is "
@@ -562,6 +579,7 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
         gen_args: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> LLMAttributionResult:
@@ -572,6 +590,12 @@ def attribute(
                 which attributions are computed. If None, it uses the model
                 to generate the target based on the input and gen_args.
                 Default: None
+            skip_tokens (List[int] or List[str], optional): the tokens to skip in the
+                the output's interpretable representation. Use this argument to
+                define uninterested tokens, commonly like special tokens, e.g.,
+                sos, and unk. It can be a list of strings of the tokens or a list
+                of integers of the token ids.
+                Default: None
             gen_args (dict, optional): arguments for generating the target. Only used if
                 target is not given. When None, the default arguments are used,
                 {"max_new_tokens": 25, "do_sample": False,
@@ -605,13 +629,23 @@ def attribute(
             target_tokens = output_tokens[0][model_inp.size(1) :]
         else:
             assert gen_args is None, "gen_args must be None when target is given"
+            # Encode skip tokens
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+            else:
+                skip_tokens = []

-            if type(target) is str:
-                # exclude sos
-                target_tokens = self.tokenizer.encode(target)[1:]
-                target_tokens = torch.tensor(target_tokens)
-            elif type(target) is torch.Tensor:
-                target_tokens = target
+            if isinstance(target, str):
+                encoded = self.tokenizer.encode(target)
+                target_tokens = torch.tensor(
+                    [token for token in encoded if token not in skip_tokens]
+                )
+            elif isinstance(target, torch.Tensor):
+                target_tokens = target[
+                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
+                ]
             else:
                 raise TypeError(
                     "target must either be str or Tensor, but the type of target is "

tests/attr/test_llm_attr.py

Lines changed: 102 additions & 4 deletions
@@ -95,7 +95,9 @@ def convert_tokens_to_ids(
         raise NotImplementedError

     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
+        return tokens if isinstance(tokens, str) else " ".join(tokens)


 class Result(NamedTuple):
@@ -271,6 +273,7 @@ def test_llm_attr(
         res = llm_attr.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             # for 4th positional argument, expected
@@ -330,7 +333,10 @@ def test_llm_attr_fa_log_prob(self) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         # With FeatureAblation, the seq attr in log_prob
@@ -385,6 +391,7 @@ def test_llm_attr_without_token(
         res = llm_fa.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             # for 4th positional argument, expected
@@ -416,6 +423,52 @@ def test_futures_not_implemented(self) -> None:
         attributions = llm_fa.attribute_future()
         self.assertEqual(attributions, None)

+    def test_llm_attr_with_no_skip_tokens(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        fa = FeatureAblation(llm)
+        llm_fa = LLMAttribution(fa, tokenizer)
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_fa.attribute(
+            inp,
+            "m n o p q",
+            use_cached_outputs=self.use_cached_outputs,
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+    def test_llm_attr_with_skip_tensor_target(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        fa = FeatureAblation(llm)
+        llm_fa = LLMAttribution(fa, tokenizer)
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_fa.attribute(
+            inp,
+            torch.tensor(tokenizer.encode("m n o p q")),
+            skip_tokens=[0],
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+

 @parameterized_class(
     ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
@@ -448,7 +501,7 @@ def test_llm_attr(
         )

         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -523,7 +576,7 @@ def test_llm_attr_with_skip_tokens(
         )

         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
@@ -537,3 +590,48 @@ def test_llm_attr_with_skip_tokens(
         self.assertEqual(res.seq_attr.device.type, self.device)
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+
+    def test_llm_attr_with_no_skip_tokens(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = LayerIntegratedGradients(llm, llm.emb)  # type: ignore[call-arg]
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        attr_kws: Dict[str, Any] = {}
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+    def test_llm_attr_with_skip_tensor_target(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = LayerIntegratedGradients(llm, llm.emb)  # type: ignore[call-arg]
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        attr_kws: Dict[str, Any] = {}
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(
+            inp,
+            torch.tensor(tokenizer.encode("m n o p q")),
+            skip_tokens=[0],
+            **attr_kws,
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

tests/attr/test_llm_attr_gpu.py

Lines changed: 17 additions & 6 deletions
@@ -84,7 +84,9 @@ def convert_tokens_to_ids(
         raise NotImplementedError

     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
+        return tokens if isinstance(tokens, str) else " ".join(tokens)


 class Result(NamedTuple):
@@ -195,7 +197,10 @@ def test_llm_attr_gpu(self, AttrClass: Type[PerturbationAttribution]) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_attr.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -234,7 +239,10 @@ def test_llm_attr_fa_log_prob_gpu(self) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         # With FeatureAblation, the seq attr in log_prob
@@ -253,7 +261,10 @@ def test_llm_attr_without_token_gpu(

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         self.assertEqual(res.seq_attr.shape, (4,))
@@ -280,7 +291,7 @@ def test_llm_attr(self) -> None:
         llm_attr = LLMGradientAttribution(attr, tokenizer)

         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
         assert res.token_attr is not None  # make pyre/mypy happy
@@ -324,7 +335,7 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         llm_attr = LLMGradientAttribution(attr, tokenizer)

         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
