
Commit 3db2465

csauper authored and facebook-github-bot committed
stop stripping first character from output string (meta-pytorch#1351)
Summary:
Pull Request resolved: meta-pytorch#1351

Previously, the first token of the encoded output string was stripped as an assumed SOS token, but that does not appear to be the behavior of current LLM tokenizers. Keep all tokens.

Differential Revision: D62775617
1 parent 70619a6 commit 3db2465
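
The practical effect of the change: the encoded target string is no longer truncated by one token, and callers opt into dropping special tokens via the new `skip_tokens` argument. Below is a minimal usage sketch, not part of the commit; the model name, prompt, and target are placeholders, assuming a Hugging Face causal LM wrapped with Captum's `LLMAttribution`:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

    # Placeholder model/tokenizer; any causal LM with a compatible tokenizer should work.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
    inp = TextTemplateInput("{} lives in {} and likes {}", ["Dave", "Palm Coast", "fishing"])

    # Previously the first target token was always dropped as an assumed SOS token.
    # Now every encoded token is kept unless its id (or token string) is listed in skip_tokens.
    special_ids = [i for i in (tokenizer.bos_token_id, tokenizer.eos_token_id) if i is not None]
    res = llm_attr.attribute(inp, target="He enjoys the beach.", skip_tokens=special_ids)
    print(res.seq_attr)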

3 files changed: +131 −16 lines

captum/attr/_core/llm_attr.py

Lines changed: 42 additions & 6 deletions
@@ -383,6 +383,7 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
         num_trials: int = 1,
         gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
@@ -397,6 +398,12 @@ def attribute(
                 which attributions are computed. If None, it uses the model
                 to generate the target based on the input and gen_args.
                 Default: None
+            skip_tokens (List[int] or List[str], optional): the tokens to skip in the
+                the output's interpretable representation. Use this argument to
+                define uninterested tokens, commonly like special tokens, e.g.,
+                sos, and unk. It can be a list of strings of the tokens or a list
+                of integers of the token ids.
+                Default: None
             num_trials (int, optional): number of trials to run. Return is the average
                 attribibutions over all the trials.
                 Defaults: 1.
@@ -435,9 +442,20 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"

         if type(target) is str:
-            # exclude sos
-            target_tokens = self.tokenizer.encode(target)[1:]
-            target_tokens = torch.tensor(target_tokens)
+            encoded = self.tokenizer.encode(target)
+
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+                assert isinstance(skip_tokens, list)
+
+                skip_token_set = set(skip_tokens)
+                encoded = [
+                    token for token in encoded if token not in skip_token_set
+                ]
+
+            target_tokens = torch.tensor(encoded)
         elif type(target) is torch.Tensor:
             target_tokens = target
         else:
@@ -562,6 +580,7 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
         gen_args: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> LLMAttributionResult:
@@ -572,6 +591,12 @@ def attribute(
                 which attributions are computed. If None, it uses the model
                 to generate the target based on the input and gen_args.
                 Default: None
+            skip_tokens (List[int] or List[str], optional): the tokens to skip in the
+                the output's interpretable representation. Use this argument to
+                define uninterested tokens, commonly like special tokens, e.g.,
+                sos, and unk. It can be a list of strings of the tokens or a list
+                of integers of the token ids.
+                Default: None
             gen_args (dict, optional): arguments for generating the target. Only used if
                 target is not given. When None, the default arguments are used,
                 {"max_new_tokens": 25, "do_sample": False,
@@ -607,9 +632,20 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"

         if type(target) is str:
-            # exclude sos
-            target_tokens = self.tokenizer.encode(target)[1:]
-            target_tokens = torch.tensor(target_tokens)
+            encoded = self.tokenizer.encode(target)
+
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+                assert isinstance(skip_tokens, list)
+
+                skip_token_set = set(skip_tokens)
+                encoded = [
+                    token for token in encoded if token not in skip_token_set
+                ]
+
+            target_tokens = torch.tensor(encoded)
         elif type(target) is torch.Tensor:
             target_tokens = target
         else:
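
For reference, the replaced branch can be read in isolation. Below is a self-contained sketch of the new encoding logic using a hypothetical `TinyTokenizer` stub (mirroring the tests' `DummyTokenizer`, not Captum's real classes), showing that all tokens are now kept unless explicitly skipped:

    from typing import List, Union, cast

    import torch


    class TinyTokenizer:
        # Hypothetical stand-in: only encode/convert_tokens_to_ids are needed here.
        vocab = {"<sos>": 0, "m": 1, "n": 2, "o": 3, "p": 4, "q": 5}

        def encode(self, text: str) -> List[int]:
            # Pretend the tokenizer prepends an SOS token, like the tests' DummyTokenizer.
            return [0] + [self.vocab[t] for t in text.split()]

        def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
            return [self.vocab[t] for t in tokens]


    def encode_target(
        tokenizer: TinyTokenizer,
        target: str,
        skip_tokens: Union[List[int], List[str], None] = None,
    ) -> torch.Tensor:
        encoded = tokenizer.encode(target)
        if skip_tokens:
            # Token strings are resolved to ids, then filtered out of the encoding.
            if isinstance(skip_tokens[0], str):
                skip_tokens = tokenizer.convert_tokens_to_ids(cast(List[str], skip_tokens))
            skip_token_set = set(cast(List[int], skip_tokens))
            encoded = [token for token in encoded if token not in skip_token_set]
        return torch.tensor(encoded)


    tok = TinyTokenizer()
    print(encode_target(tok, "m n o p q"))                        # keeps all tokens: [0, 1, 2, 3, 4, 5]
    print(encode_target(tok, "m n o p q", skip_tokens=[0]))       # drops the SOS id: [1, 2, 3, 4, 5]
    print(encode_target(tok, "m n o p q", skip_tokens=["<sos>"])) # same, by token string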

tests/attr/test_llm_attr.py

Lines changed: 49 additions & 4 deletions
@@ -95,7 +95,9 @@ def convert_tokens_to_ids(
         raise NotImplementedError

     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
+        return tokens if isinstance(tokens, str) else " ".join(tokens)


 class Result(NamedTuple):
@@ -271,6 +273,7 @@ def test_llm_attr(
         res = llm_attr.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             #  for 4th positional argument, expected
@@ -330,7 +333,10 @@ def test_llm_attr_fa_log_prob(self) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         # With FeatureAblation, the seq attr in log_prob
@@ -385,6 +391,7 @@ def test_llm_attr_without_token(
         res = llm_fa.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             #  for 4th positional argument, expected
@@ -448,7 +455,7 @@ def test_llm_attr(
         )

         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -523,7 +530,7 @@ def test_llm_attr_with_skip_tokens(
         )

         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
@@ -537,3 +544,41 @@ def test_llm_attr_with_skip_tokens(
         self.assertEqual(res.seq_attr.device.type, self.device)
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr_with_no_skip_tokens(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertEqual(token_attr.device.type, self.device)  # type: ignore

tests/attr/test_llm_attr_gpu.py

Lines changed: 40 additions & 6 deletions
@@ -84,7 +84,9 @@ def convert_tokens_to_ids(
         raise NotImplementedError

     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
+        return tokens if isinstance(tokens, str) else " ".join(tokens)


 class Result(NamedTuple):
@@ -195,7 +197,10 @@ def test_llm_attr_gpu(self, AttrClass: Type[PerturbationAttribution]) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_attr.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )
         self.assertEqual(res.seq_attr.shape, (4,))
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -234,7 +239,10 @@ def test_llm_attr_fa_log_prob_gpu(self) -> None:

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         # With FeatureAblation, the seq attr in log_prob
@@ -253,7 +261,10 @@ def test_llm_attr_without_token_gpu(

         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )

         self.assertEqual(res.seq_attr.shape, (4,))
@@ -280,7 +291,7 @@ def test_llm_attr(self) -> None:
         llm_attr = LLMGradientAttribution(attr, tokenizer)

         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
         assert res.token_attr is not None  # make pyre/mypy happy
@@ -324,7 +335,7 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         llm_attr = LLMGradientAttribution(attr, tokenizer)

         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])

         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
@@ -338,3 +349,26 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         self.assertEqual(res.seq_attr.device.type, self.device)
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+
+    def test_llm_attr_with_no_skip_tokens(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = LayerIntegratedGradients(llm, llm.emb)
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(inp, "m n o p q")
+
+        # 6 output tokens including sos, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertEqual(token_attr.device.type, self.device)  # type: ignore
