88 cast ,
99 Dict ,
1010 List ,
11+ Literal ,
1112 NamedTuple ,
1213 Optional ,
1314 overload ,
1819
1920import torch
2021from captum ._utils .models .linear_model import SkLearnLasso
21- from captum ._utils .typing import Literal
2222from captum .attr ._core .feature_ablation import FeatureAblation
2323from captum .attr ._core .kernel_shap import KernelShap
2424from captum .attr ._core .layer .layer_gradient_shap import LayerGradientShap
@@ -44,9 +44,6 @@ class DummyTokenizer:
4444 @overload
4545 def encode (self , text : str , return_tensors : None = None ) -> List [int ]: ...
4646 @overload
47- # pyre-fixme[43]: Incompatible overload. The implementation of
48- # `DummyTokenizer.encode` does not accept all possible arguments of overload.
49- # pyre-ignore[11]: Annotation `pt` is not defined as a type
5047 def encode (self , text : str , return_tensors : Literal ["pt" ]) -> Tensor : ...
5148
5249 def encode (
@@ -393,9 +390,6 @@ def test_llm_attr_without_token(
393390 "m n o p q" ,
394391 skip_tokens = [0 ],
395392 use_cached_outputs = self .use_cached_outputs ,
396- # pyre-fixme[6]: In call `LLMAttribution.attribute`,
397- # for 4th positional argument, expected
398- # `Optional[typing.Callable[..., typing.Any]]` but got `int`.
399393 ** attr_kws , # type: ignore
400394 )
401395
@@ -439,10 +433,10 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
439433
440434 # 5 output tokens, 4 input tokens including sos
441435 self .assertEqual (res .seq_attr .shape , (4 ,))
442- assert res .token_attr is not None # make pyre/mypy happy
436+ assert res .token_attr is not None
443437 self .assertIsNotNone (res .token_attr )
444438 token_attr = res .token_attr
445- self .assertEqual (token_attr .shape , (6 , 4 )) # type: ignore
439+ self .assertEqual (token_attr .shape , (6 , 4 ))
446440 self .assertEqual (res .input_tokens , ["<sos>" , "a" , "b" , "c" ])
447441 self .assertEqual (res .output_tokens , ["<sos>" , "m" , "n" , "o" , "p" , "q" ])
448442
@@ -462,18 +456,17 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
462456
463457 # 5 output tokens, 4 input tokens including sos
464458 self .assertEqual (res .seq_attr .shape , (4 ,))
465- assert res .token_attr is not None # make pyre/mypy happy
459+ assert res .token_attr is not None
466460 self .assertIsNotNone (res .token_attr )
467461 token_attr = res .token_attr
468- self .assertEqual (token_attr .shape , (5 , 4 )) # type: ignore
462+ self .assertEqual (token_attr .shape , (5 , 4 ))
469463 self .assertEqual (res .input_tokens , ["<sos>" , "a" , "b" , "c" ])
470464 self .assertEqual (res .output_tokens , ["m" , "n" , "o" , "p" , "q" ])
471465
472466
473467@parameterized_class (
474468 ("device" ,), [("cpu" ,), ("cuda" ,)] if torch .cuda .is_available () else [("cpu" ,)]
475469)
476- # pyre-fixme[13]: Attribute `device` is never initialized.
477470class TestLLMGradAttr (BaseTest ):
478471 # pyre-fixme[13]: Attribute `device` is never initialized.
479472 device : str
@@ -505,16 +498,16 @@ def test_llm_attr(
505498
506499 # 5 output tokens, 4 input tokens including sos
507500 self .assertEqual (res .seq_attr .shape , (4 ,))
508- assert res .token_attr is not None # make pyre/mypy happy
501+ assert res .token_attr is not None
509502 self .assertIsNotNone (res .token_attr )
510503 token_attr = res .token_attr
511- self .assertEqual (token_attr .shape , (5 , 4 )) # type: ignore
504+ self .assertEqual (token_attr .shape , (5 , 4 ))
512505 self .assertEqual (res .input_tokens , ["<sos>" , "a" , "b" , "c" ])
513506 self .assertEqual (res .output_tokens , ["m" , "n" , "o" , "p" , "q" ])
514507
515508 self .assertEqual (res .seq_attr .device .type , self .device )
516- assert res .token_attr is not None # make pyre/mypy happy
517- self .assertEqual (token_attr .device .type , self .device ) # type: ignore
509+ assert res .token_attr is not None
510+ self .assertEqual (token_attr .device .type , self .device )
518511
519512 @parameterized .expand (
520513 [
@@ -542,16 +535,16 @@ def test_llm_attr_without_target(
542535 res = llm_attr .attribute (inp , gen_args = {"mock_response" : "x y z" }, ** attr_kws )
543536
544537 self .assertEqual (res .seq_attr .shape , (4 ,))
545- assert res .token_attr is not None # make pyre/mypy happy
538+ assert res .token_attr is not None
546539 self .assertIsNotNone (res .token_attr )
547540 token_attr = res .token_attr
548- self .assertEqual (token_attr .shape , (3 , 4 )) # type: ignore
541+ self .assertEqual (token_attr .shape , (3 , 4 ))
549542 self .assertEqual (res .input_tokens , ["<sos>" , "a" , "b" , "c" ])
550543 self .assertEqual (res .output_tokens , ["x" , "y" , "z" ])
551544
552545 self .assertEqual (res .seq_attr .device .type , self .device )
553- assert res .token_attr is not None # make pyre/mypy happy
554- self .assertEqual (token_attr .device .type , self .device ) # type: ignore
546+ assert res .token_attr is not None
547+ self .assertEqual (token_attr .device .type , self .device )
555548
556549 @parameterized .expand (
557550 [
@@ -580,16 +573,16 @@ def test_llm_attr_with_skip_tokens(
580573
581574 # 5 output tokens, 4 input tokens including sos
582575 self .assertEqual (res .seq_attr .shape , (3 ,))
583- assert res .token_attr is not None # make pyre/mypy happy
576+ assert res .token_attr is not None
584577 self .assertIsNotNone (res .token_attr )
585578 token_attr = res .token_attr
586- self .assertEqual (token_attr .shape , (5 , 3 )) # type: ignore
579+ self .assertEqual (token_attr .shape , (5 , 3 ))
587580 self .assertEqual (res .input_tokens , ["a" , "b" , "c" ])
588581 self .assertEqual (res .output_tokens , ["m" , "n" , "o" , "p" , "q" ])
589582
590583 self .assertEqual (res .seq_attr .device .type , self .device )
591- assert res .token_attr is not None # make pyre/mypy happy
592- self .assertEqual (token_attr .device .type , self .device ) # type: ignore
584+ assert res .token_attr is not None
585+ self .assertEqual (token_attr .device .type , self .device )
593586
594587 def test_llm_attr_with_no_skip_tokens (self ) -> None :
595588 llm = DummyLLM ()
@@ -602,12 +595,12 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
602595 inp = TextTokenInput ("a b c" , tokenizer )
603596 res = llm_attr .attribute (inp , "m n o p q" , ** attr_kws )
604597
605- # 5 output tokens, 4 input tokens including sos
598+ # 6 output tokens, 4 input tokens including sos
606599 self .assertEqual (res .seq_attr .shape , (4 ,))
607- assert res .token_attr is not None # make pyre/mypy happy
600+ assert res .token_attr is not None
608601 self .assertIsNotNone (res .token_attr )
609602 token_attr = res .token_attr
610- self .assertEqual (token_attr .shape , (6 , 4 )) # type: ignore
603+ self .assertEqual (token_attr .shape , (6 , 4 ))
611604 self .assertEqual (res .input_tokens , ["<sos>" , "a" , "b" , "c" ])
612605 self .assertEqual (res .output_tokens , ["<sos>" , "m" , "n" , "o" , "p" , "q" ])
613606
@@ -629,9 +622,9 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
629622
630623 # 5 output tokens, 4 input tokens including sos
631624 self .assertEqual (res .seq_attr .shape , (4 ,))
632- assert res .token_attr is not None # make pyre/mypy happy
625+ assert res .token_attr is not None
633626 self .assertIsNotNone (res .token_attr )
634627 token_attr = res .token_attr
635- self .assertEqual (token_attr .shape , (5 , 4 )) # type: ignore
628+ self .assertEqual (token_attr .shape , (5 , 4 ))
636629 self .assertEqual (res .input_tokens , ["<sos>" , "a" , "b" , "c" ])
637630 self .assertEqual (res .output_tokens , ["m" , "n" , "o" , "p" , "q" ])
0 commit comments