From a242c87c0c8dd35c43389b047a084834201ffee6 Mon Sep 17 00:00:00 2001 From: Christy Sauper Date: Wed, 25 Sep 2024 10:57:10 -0700 Subject: [PATCH] Fix more pyre errors for llm_attr and tests [2/n] (#1359) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1359 Get rid of some more pyre errors by adding better typing Reviewed By: uberblah Differential Revision: D63365945 --- captum/attr/_core/llm_attr.py | 30 ++++++++---------- tests/attr/test_llm_attr.py | 51 ++++++++++++++----------------- tests/attr/test_llm_attr_gpu.py | 54 ++++++++++++++++----------------- 3 files changed, 61 insertions(+), 74 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index de33e0472d..70be740175 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -79,7 +79,7 @@ def plot_token_attr( "token_attr is None (no token-level attribution was performed), please " "use plot_seq_attr instead for the sequence-level attribution plot" ) - token_attr = self.token_attr.cpu() # type: ignore + token_attr = self.token_attr.cpu() # maximum absolute attribution value # used as the boundary of normalization @@ -343,7 +343,7 @@ def _forward_func( caching=use_cached_outputs, ) - log_prob_list = [] + log_prob_list: List[Tensor] = [] outputs = None for target_token in target_tokens: if use_cached_outputs: @@ -382,17 +382,15 @@ def _forward_func( (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1 ) - # pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return - # Tensor - total_log_prob: Tensor = sum(log_prob_list) # type: ignore + total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0) # 1st element is the total prob, rest are the target tokens # add a leading dim for batch even we only support single instance for now if self.include_per_token_attr: target_log_probs = torch.stack( - [total_log_prob, *log_prob_list], dim=0 # type: ignore + [total_log_prob, *log_prob_list], dim=0 ).unsqueeze(0) else: - target_log_probs = total_log_prob # type: ignore + target_log_probs = total_log_prob target_probs = torch.exp(target_log_probs) if _inspect_forward: @@ -412,9 +410,9 @@ def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor: """ # return tensor(1, n_tokens) if isinstance(model_input, str): - return self.tokenizer.encode( # type: ignore - model_input, return_tensors="pt" - ).to(self.device) + return self.tokenizer.encode(model_input, return_tensors="pt").to( + self.device + ) return model_input.to(self.device) def attribute( @@ -544,8 +542,7 @@ def attribute( _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer), ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future(self) -> Callable[[], LLMAttributionResult]: r""" This method is not implemented for LLMAttribution. """ @@ -612,9 +609,9 @@ def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor: Convert str to tokenized tensor """ if isinstance(model_input, str): - return self.tokenizer.encode( # type: ignore - model_input, return_tensors="pt" - ).to(self.device) + return self.tokenizer.encode(model_input, return_tensors="pt").to( + self.device + ) return model_input.to(self.device) def attribute( @@ -745,8 +742,7 @@ def attribute( _convert_ids_to_pretty_tokens(target_tokens, self.tokenizer), ) - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. 
- def attribute_future(self) -> Callable: + def attribute_future(self) -> Callable[[], LLMAttributionResult]: r""" This method is not implemented for LLMGradientAttribution. """ diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 3e9c500cbc..d22bef384b 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -8,6 +8,7 @@ cast, Dict, List, + Literal, NamedTuple, Optional, overload, @@ -18,7 +19,6 @@ import torch from captum._utils.models.linear_model import SkLearnLasso -from captum._utils.typing import Literal from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.kernel_shap import KernelShap from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap @@ -44,9 +44,6 @@ class DummyTokenizer: @overload def encode(self, text: str, return_tensors: None = None) -> List[int]: ... @overload - # pyre-fixme[43]: Incompatible overload. The implementation of - # `DummyTokenizer.encode` does not accept all possible arguments of overload. - # pyre-ignore[11]: Annotation `pt` is not defined as a type def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ... def encode( @@ -393,9 +390,6 @@ def test_llm_attr_without_token( "m n o p q", skip_tokens=[0], use_cached_outputs=self.use_cached_outputs, - # pyre-fixme[6]: In call `LLMAttribution.attribute`, - # for 4th positional argument, expected - # `Optional[typing.Callable[..., typing.Any]]` but got `int`. **attr_kws, # type: ignore ) @@ -439,10 +433,10 @@ def test_llm_attr_with_no_skip_tokens(self) -> None: # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (6, 4)) # type: ignore + self.assertEqual(token_attr.shape, (6, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["", "m", "n", "o", "p", "q"]) @@ -462,10 +456,10 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (5, 4)) # type: ignore + self.assertEqual(token_attr.shape, (5, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) @@ -473,7 +467,6 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: @parameterized_class( ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)] ) -# pyre-fixme[13]: Attribute `device` is never initialized. class TestLLMGradAttr(BaseTest): # pyre-fixme[13]: Attribute `device` is never initialized. 
device: str @@ -505,16 +498,16 @@ def test_llm_attr( # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (5, 4)) # type: ignore + self.assertEqual(token_attr.shape, (5, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - assert res.token_attr is not None # make pyre/mypy happy - self.assertEqual(token_attr.device.type, self.device) # type: ignore + assert res.token_attr is not None + self.assertEqual(token_attr.device.type, self.device) @parameterized.expand( [ @@ -542,16 +535,16 @@ def test_llm_attr_without_target( res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws) self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (3, 4)) # type: ignore + self.assertEqual(token_attr.shape, (3, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["x", "y", "z"]) self.assertEqual(res.seq_attr.device.type, self.device) - assert res.token_attr is not None # make pyre/mypy happy - self.assertEqual(token_attr.device.type, self.device) # type: ignore + assert res.token_attr is not None + self.assertEqual(token_attr.device.type, self.device) @parameterized.expand( [ @@ -580,16 +573,16 @@ def test_llm_attr_with_skip_tokens( # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (3,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (5, 3)) # type: ignore + self.assertEqual(token_attr.shape, (5, 3)) self.assertEqual(res.input_tokens, ["a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - assert res.token_attr is not None # make pyre/mypy happy - self.assertEqual(token_attr.device.type, self.device) # type: ignore + assert res.token_attr is not None + self.assertEqual(token_attr.device.type, self.device) def test_llm_attr_with_no_skip_tokens(self) -> None: llm = DummyLLM() @@ -602,12 +595,12 @@ def test_llm_attr_with_no_skip_tokens(self) -> None: inp = TextTokenInput("a b c", tokenizer) res = llm_attr.attribute(inp, "m n o p q", **attr_kws) - # 5 output tokens, 4 input tokens including sos + # 6 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (6, 4)) # type: ignore + self.assertEqual(token_attr.shape, (6, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["", "m", "n", "o", "p", "q"]) @@ -629,9 +622,9 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None 
self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (5, 4)) # type: ignore + self.assertEqual(token_attr.shape, (5, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) diff --git a/tests/attr/test_llm_attr_gpu.py b/tests/attr/test_llm_attr_gpu.py index f6b05aa187..a66fb7fff5 100644 --- a/tests/attr/test_llm_attr_gpu.py +++ b/tests/attr/test_llm_attr_gpu.py @@ -3,11 +3,21 @@ # pyre-strict import copy -from typing import Any, cast, Dict, List, NamedTuple, Optional, overload, Type, Union +from typing import ( + Any, + cast, + Dict, + List, + Literal, + NamedTuple, + Optional, + overload, + Type, + Union, +) import torch -from captum._utils.typing import Literal from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.kernel_shap import KernelShap from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients @@ -32,9 +42,6 @@ class DummyTokenizer: @overload def encode(self, text: str, return_tensors: None = None) -> List[int]: ... @overload - # pyre-fixme[43]: Incompatible overload. The implementation of - # `DummyTokenizer.encode` does not accept all possible arguments of overload. - # pyre-ignore[11]: Annotation `pt` is not defined as a type def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ... def encode( @@ -122,9 +129,6 @@ def generate( assert mock_response, "must mock response to use DummyLLM to geenrate" response = self.tokenizer.encode(mock_response)[1:] return torch.cat( - # pyre-fixme[6]: In call `torch._C._VariableFunctions.cat`, - # for 1st positional argument, expected `Union[List[Tensor], - # typing.Tuple[Tensor, ...]]` but got `List[Union[List[int], Tensor]]`. [input_ids, torch.tensor([response], device=self.device)], # type: ignore dim=1, ) @@ -178,10 +182,6 @@ def device(self) -> torch._C.device: else [("cpu", True), ("cpu", False)] ), ) -# pyre-fixme[13]: Attribute `device` is declared in class `TestLlmAttrGpu` -# to have type `str` but is never initialized. -# pyre-fixme[13]: Attribute `use_cached_outputs` is declared in class `TestLlmAttrGpu` -# to have type `bool` but is never initialized. class TestLlmAttrGpu(BaseTest): # pyre-fixme[13]: Attribute `device` is never initialized. device: str @@ -277,8 +277,6 @@ def test_llm_attr_without_token_gpu( @parameterized_class( ("device",), [("cuda",)] if torch.cuda.is_available() else [("cpu",)] ) -# pyre-fixme[13]: Attribute `device` is declared in class `TestLLMGradAttrGPU` -# to have type `str` but is never initialized. class TestLLMGradAttrGPU(BaseTest): # pyre-fixme[13]: Attribute `device` is never initialized. 
device: str @@ -294,16 +292,16 @@ def test_llm_attr(self) -> None: res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0]) # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (5, 4)) # type: ignore + self.assertEqual(token_attr.shape, (5, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - assert res.token_attr is not None # make pyre/mypy happy - self.assertEqual(token_attr.device.type, self.device) # type: ignore + assert res.token_attr is not None + self.assertEqual(token_attr.device.type, self.device) def test_llm_attr_without_target(self) -> None: llm = DummyLLM() @@ -316,16 +314,16 @@ def test_llm_attr_without_target(self) -> None: res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}) self.assertEqual(res.seq_attr.shape, (4,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (3, 4)) # type: ignore + self.assertEqual(token_attr.shape, (3, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["x", "y", "z"]) - self.assertEqual(res.seq_attr.device.type, self.device) # type: ignore - assert res.token_attr is not None # make pyre/mypy happy - self.assertEqual(token_attr.device.type, self.device) # type: ignore + self.assertEqual(res.seq_attr.device.type, self.device) + assert res.token_attr is not None + self.assertEqual(token_attr.device.type, self.device) def test_llm_attr_with_skip_tokens(self) -> None: llm = DummyLLM() @@ -337,15 +335,15 @@ def test_llm_attr_with_skip_tokens(self) -> None: inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0]) res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0]) - # 5 output tokens, 4 input tokens including sos + # 5 output tokens, 3 input tokens including sos self.assertEqual(res.seq_attr.shape, (3,)) - assert res.token_attr is not None # make pyre/mypy happy + assert res.token_attr is not None self.assertIsNotNone(res.token_attr) token_attr = res.token_attr - self.assertEqual(token_attr.shape, (5, 3)) # type: ignore + self.assertEqual(token_attr.shape, (5, 3)) self.assertEqual(res.input_tokens, ["a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) self.assertEqual(res.seq_attr.device.type, self.device) - assert res.token_attr is not None # make pyre/mypy happy - self.assertEqual(token_attr.device.type, self.device) # type: ignore + assert res.token_attr is not None + self.assertEqual(token_attr.device.type, self.device)
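
Note on the typing patterns used above (not part of the commit): the diff replaces the builtin `sum` over a `List[Tensor]` with `torch.sum(torch.stack(...))`, and narrows bare `Callable` annotations to `Callable[[], LLMAttributionResult]`. Below is a minimal, self-contained Python sketch of both patterns; the function names (`total_log_prob`, `make_future`) and the use of `Tensor` as the future's result type are illustrative assumptions, not Captum APIs.

# Standalone sketch (illustrative only) of the two typing fixes in this commit.
from typing import Callable, List

import torch
from torch import Tensor


def total_log_prob(log_prob_list: List[Tensor]) -> Tensor:
    # Builtin sum() starts from the int 0, so pyre/mypy infer the result may
    # be an int; stacking the tensors and reducing with torch.sum keeps the
    # statically inferred type as Tensor, with no ignore comment needed.
    return torch.sum(torch.stack(log_prob_list), dim=0)


def make_future(result: Tensor) -> Callable[[], Tensor]:
    # A zero-argument callable with an explicit return type, mirroring the
    # Callable[[], LLMAttributionResult] annotation on attribute_future.
    def _future() -> Tensor:
        return result

    return _future


if __name__ == "__main__":
    log_probs = [torch.log(torch.tensor([0.5])), torch.log(torch.tensor([0.25]))]
    total = total_log_prob(log_probs)
    # log(0.5) + log(0.25) == log(0.125), so exp(total) is ~0.125
    print(total, torch.exp(total))
    print(make_future(total)())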