@@ -95,7 +95,9 @@ def convert_tokens_to_ids(
         raise NotImplementedError
 
     def decode(self, token_ids: Tensor) -> str:
-        return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))
+        tokens = self.convert_ids_to_tokens(token_ids.tolist())
+        # pyre-fixme[7]: Expected `str` but got `Union[List[str], str]`.
+        return tokens if isinstance(tokens, str) else " ".join(tokens)
 
 
 class Result(NamedTuple):
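Note on the decode change: as the pyre-fixme above records, this tokenizer's convert_ids_to_tokens is typed to return Union[List[str], str], so decode has to handle both shapes. A minimal sketch of the two paths, assuming only a tokenizer with that Union-typed method:

    tokens = tokenizer.convert_ids_to_tokens(token_ids.tolist())
    # the list-of-ids path yields List[str]; a single id may yield a bare str
    text = tokens if isinstance(tokens, str) else " ".join(tokens)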
@@ -271,6 +273,7 @@ def test_llm_attr(
         res = llm_attr.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             # for 4th positional argument, expected
@@ -330,7 +333,10 @@ def test_llm_attr_fa_log_prob(self) -> None:
 
         inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
         res = llm_fa.attribute(
-            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+            inp,
+            "m n o p q",
+            skip_tokens=[0],
+            use_cached_outputs=self.use_cached_outputs,
         )
 
         # With FeatureAblation, the seq attr in log_prob
@@ -385,6 +391,7 @@ def test_llm_attr_without_token(
         res = llm_fa.attribute(
             inp,
             "m n o p q",
+            skip_tokens=[0],
             use_cached_outputs=self.use_cached_outputs,
             # pyre-fixme[6]: In call `LLMAttribution.attribute`,
             # for 4th positional argument, expected
@@ -416,6 +423,52 @@ def test_futures_not_implemented(self) -> None:
         attributions = llm_fa.attribute_future()
         self.assertEqual(attributions, None)
 
+    def test_llm_attr_with_no_skip_tokens(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        fa = FeatureAblation(llm)
+        llm_fa = LLMAttribution(fa, tokenizer)
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_fa.attribute(
+            inp,
+            "m n o p q",
+            use_cached_outputs=self.use_cached_outputs,
+        )
+
+        # 6 output tokens including sos, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+    def test_llm_attr_with_skip_tensor_target(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        fa = FeatureAblation(llm)
+        llm_fa = LLMAttribution(fa, tokenizer)
+
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_fa.attribute(
+            inp,
+            torch.tensor(tokenizer.encode("m n o p q")),
+            skip_tokens=[0],
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+
 
 @parameterized_class(
     ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
@@ -448,7 +501,7 @@ def test_llm_attr(
         )
 
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -523,7 +576,7 @@ def test_llm_attr_with_skip_tokens(
         )
 
         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+        res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0], **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
@@ -537,3 +590,48 @@ def test_llm_attr_with_skip_tokens(
         self.assertEqual(res.seq_attr.device.type, self.device)
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
+
+    def test_llm_attr_with_no_skip_tokens(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = LayerIntegratedGradients(llm, llm.emb)  # type: ignore[call-arg]
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        attr_kws: Dict[str, Any] = {}
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
+
+        # 6 output tokens including sos, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (6, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])
+
+    def test_llm_attr_with_skip_tensor_target(self) -> None:
+        llm = DummyLLM()
+        llm.to(self.device)
+        tokenizer = DummyTokenizer()
+        attr = LayerIntegratedGradients(llm, llm.emb)  # type: ignore[call-arg]
+        llm_attr = LLMGradientAttribution(attr, tokenizer)
+
+        attr_kws: Dict[str, Any] = {}
+        inp = TextTokenInput("a b c", tokenizer)
+        res = llm_attr.attribute(
+            inp,
+            torch.tensor(tokenizer.encode("m n o p q")),
+            skip_tokens=[0],
+            **attr_kws,
+        )
+
+        # 5 output tokens, 4 input tokens including sos
+        self.assertEqual(res.seq_attr.shape, (4,))
+        assert res.token_attr is not None  # make pyre/mypy happy
+        self.assertIsNotNone(res.token_attr)
+        token_attr = res.token_attr
+        self.assertEqual(token_attr.shape, (5, 4))  # type: ignore
+        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
+        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
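Taken together, the new gradient tests mirror the FeatureAblation ones: a string target without skip_tokens keeps <sos> in the output, while an encoded tensor target with skip_tokens=[0] drops it. A minimal end-to-end sketch under the same assumptions (DummyLLM and DummyTokenizer are this file's test doubles; the captum.attr import path is assumed to match the rest of the suite):

    import torch
    from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

    llm = DummyLLM()
    tokenizer = DummyTokenizer()
    llm_fa = LLMAttribution(FeatureAblation(llm), tokenizer)

    inp = TextTokenInput("a b c", tokenizer)
    res = llm_fa.attribute(
        inp,
        torch.tensor(tokenizer.encode("m n o p q")),
        skip_tokens=[0],  # drop the <sos> id from the target
    )
    assert res.seq_attr.shape == (4,)                      # one score per input token
    assert res.output_tokens == ["m", "n", "o", "p", "q"]  # <sos> filtered out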