30 changes: 13 additions & 17 deletions captum/attr/_core/llm_attr.py
@@ -79,7 +79,7 @@ def plot_token_attr(
"token_attr is None (no token-level attribution was performed), please "
"use plot_seq_attr instead for the sequence-level attribution plot"
)
token_attr = self.token_attr.cpu() # type: ignore
token_attr = self.token_attr.cpu()

# maximum absolute attribution value
# used as the boundary of normalization
@@ -343,7 +343,7 @@ def _forward_func(
caching=use_cached_outputs,
)

log_prob_list = []
log_prob_list: List[Tensor] = []
outputs = None
for target_token in target_tokens:
if use_cached_outputs:
@@ -382,17 +382,15 @@ def _forward_func(
(model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
)

# pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
# Tensor
total_log_prob: Tensor = sum(log_prob_list) # type: ignore
total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
# 1st element is the total prob, rest are the target tokens
# add a leading dim for batch even we only support single instance for now
if self.include_per_token_attr:
target_log_probs = torch.stack(
[total_log_prob, *log_prob_list], dim=0 # type: ignore
[total_log_prob, *log_prob_list], dim=0
).unsqueeze(0)
else:
target_log_probs = total_log_prob # type: ignore
target_log_probs = total_log_prob
target_probs = torch.exp(target_log_probs)

if _inspect_forward:
@@ -412,9 +410,9 @@ def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
"""
# return tensor(1, n_tokens)
if isinstance(model_input, str):
return self.tokenizer.encode( # type: ignore
model_input, return_tensors="pt"
).to(self.device)
return self.tokenizer.encode(model_input, return_tensors="pt").to(
self.device
)
return model_input.to(self.device)

def attribute(
@@ -544,8 +542,7 @@ def attribute(
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
)

# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
def attribute_future(self) -> Callable:
def attribute_future(self) -> Callable[[], LLMAttributionResult]:
r"""
This method is not implemented for LLMAttribution.
"""
@@ -612,9 +609,9 @@ def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor:
Convert str to tokenized tensor
"""
if isinstance(model_input, str):
return self.tokenizer.encode( # type: ignore
model_input, return_tensors="pt"
).to(self.device)
return self.tokenizer.encode(model_input, return_tensors="pt").to(
self.device
)
return model_input.to(self.device)

def attribute(
@@ -745,8 +742,7 @@ def attribute(
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
)

# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
def attribute_future(self) -> Callable:
def attribute_future(self) -> Callable[[], LLMAttributionResult]:
r"""
This method is not implemented for LLMGradientAttribution.
"""
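The substantive change in llm_attr.py above replaces the builtin sum over a list of tensors with an explicit stack-and-reduce, which removes the pyre/mypy ignores because the result is statically a Tensor. A minimal standalone sketch of the difference (illustrative values, not Captum code):

import torch

log_prob_list = [torch.tensor([-0.5]), torch.tensor([-1.2]), torch.tensor([-0.3])]

# Builtin sum starts from the int 0, so static checkers infer int for the
# result even though it is a Tensor at runtime, hence the old ignore comments.
total_builtin = sum(log_prob_list)

# Stacking first keeps everything typed as Tensor end to end.
total_stacked = torch.sum(torch.stack(log_prob_list), dim=0)

assert torch.allclose(total_builtin, total_stacked)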
51 changes: 22 additions & 29 deletions tests/attr/test_llm_attr.py
@@ -8,6 +8,7 @@
cast,
Dict,
List,
Literal,
NamedTuple,
Optional,
overload,
@@ -18,7 +19,6 @@

import torch
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.typing import Literal
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -44,9 +44,6 @@ class DummyTokenizer:
@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
@overload
# pyre-fixme[43]: Incompatible overload. The implementation of
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
# pyre-ignore[11]: Annotation `pt` is not defined as a type
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

def encode(
@@ -393,9 +390,6 @@ def test_llm_attr_without_token(
"m n o p q",
skip_tokens=[0],
use_cached_outputs=self.use_cached_outputs,
# pyre-fixme[6]: In call `LLMAttribution.attribute`,
# for 4th positional argument, expected
# `Optional[typing.Callable[..., typing.Any]]` but got `int`.
**attr_kws, # type: ignore
)

@@ -439,10 +433,10 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:

# 5 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (6, 4)) # type: ignore
self.assertEqual(token_attr.shape, (6, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])

@@ -462,18 +456,17 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:

# 5 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
self.assertEqual(token_attr.shape, (5, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])


@parameterized_class(
("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
# pyre-fixme[13]: Attribute `device` is never initialized.
class TestLLMGradAttr(BaseTest):
# pyre-fixme[13]: Attribute `device` is never initialized.
device: str
@@ -505,16 +498,16 @@ def test_llm_attr(

# 5 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
self.assertEqual(token_attr.shape, (5, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

self.assertEqual(res.seq_attr.device.type, self.device)
assert res.token_attr is not None # make pyre/mypy happy
self.assertEqual(token_attr.device.type, self.device) # type: ignore
assert res.token_attr is not None
self.assertEqual(token_attr.device.type, self.device)

@parameterized.expand(
[
@@ -542,16 +535,16 @@ def test_llm_attr_without_target(
res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws)

self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (3, 4)) # type: ignore
self.assertEqual(token_attr.shape, (3, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["x", "y", "z"])

self.assertEqual(res.seq_attr.device.type, self.device)
assert res.token_attr is not None # make pyre/mypy happy
self.assertEqual(token_attr.device.type, self.device) # type: ignore
assert res.token_attr is not None
self.assertEqual(token_attr.device.type, self.device)

@parameterized.expand(
[
@@ -580,16 +573,16 @@ def test_llm_attr_with_skip_tokens(

# 5 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (3,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (5, 3)) # type: ignore
self.assertEqual(token_attr.shape, (5, 3))
self.assertEqual(res.input_tokens, ["a", "b", "c"])
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

self.assertEqual(res.seq_attr.device.type, self.device)
assert res.token_attr is not None # make pyre/mypy happy
self.assertEqual(token_attr.device.type, self.device) # type: ignore
assert res.token_attr is not None
self.assertEqual(token_attr.device.type, self.device)

def test_llm_attr_with_no_skip_tokens(self) -> None:
llm = DummyLLM()
@@ -602,12 +595,12 @@ def test_llm_attr_with_no_skip_tokens(self) -> None:
inp = TextTokenInput("a b c", tokenizer)
res = llm_attr.attribute(inp, "m n o p q", **attr_kws)

# 5 output tokens, 4 input tokens including sos
# 6 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (6, 4)) # type: ignore
self.assertEqual(token_attr.shape, (6, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["<sos>", "m", "n", "o", "p", "q"])

@@ -629,9 +622,9 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:

# 5 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
self.assertEqual(token_attr.shape, (5, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
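Most of the test-file churn above comes from importing Literal from the standard typing module instead of captum._utils.typing, which lets the DummyTokenizer.encode overloads type-check without pyre-fixme/pyre-ignore comments. A self-contained sketch of the same overload pattern (TinyTokenizer and its vocabulary are illustrative, not the actual fixture):

from typing import List, Literal, Optional, Union, overload

import torch
from torch import Tensor


class TinyTokenizer:
    vocab = {"<sos>": 0, "a": 1, "b": 2, "c": 3}

    @overload
    def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
    @overload
    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [self.vocab[t] for t in ["<sos>", *text.split()]]
        # With Literal imported from typing, the "pt" overload narrows the
        # return type to Tensor and needs no ignore comment.
        if return_tensors == "pt":
            return torch.tensor([ids])
        return ids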
54 changes: 26 additions & 28 deletions tests/attr/test_llm_attr_gpu.py
@@ -3,11 +3,21 @@
# pyre-strict

import copy
from typing import Any, cast, Dict, List, NamedTuple, Optional, overload, Type, Union
from typing import (
Any,
cast,
Dict,
List,
Literal,
NamedTuple,
Optional,
overload,
Type,
Union,
)

import torch

from captum._utils.typing import Literal
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
@@ -32,9 +42,6 @@ class DummyTokenizer:
@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
@overload
# pyre-fixme[43]: Incompatible overload. The implementation of
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
# pyre-ignore[11]: Annotation `pt` is not defined as a type
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

def encode(
@@ -122,9 +129,6 @@ def generate(
assert mock_response, "must mock response to use DummyLLM to geenrate"
response = self.tokenizer.encode(mock_response)[1:]
return torch.cat(
# pyre-fixme[6]: In call `torch._C._VariableFunctions.cat`,
# for 1st positional argument, expected `Union[List[Tensor],
# typing.Tuple[Tensor, ...]]` but got `List[Union[List[int], Tensor]]`.
[input_ids, torch.tensor([response], device=self.device)], # type: ignore
dim=1,
)
@@ -178,10 +182,6 @@ def device(self) -> torch._C.device:
else [("cpu", True), ("cpu", False)]
),
)
# pyre-fixme[13]: Attribute `device` is declared in class `TestLlmAttrGpu`
# to have type `str` but is never initialized.
# pyre-fixme[13]: Attribute `use_cached_outputs` is declared in class `TestLlmAttrGpu`
# to have type `bool` but is never initialized.
class TestLlmAttrGpu(BaseTest):
# pyre-fixme[13]: Attribute `device` is never initialized.
device: str
@@ -277,8 +277,6 @@ def test_llm_attr_without_token_gpu(
@parameterized_class(
("device",), [("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
# pyre-fixme[13]: Attribute `device` is declared in class `TestLLMGradAttrGPU`
# to have type `str` but is never initialized.
class TestLLMGradAttrGPU(BaseTest):
# pyre-fixme[13]: Attribute `device` is never initialized.
device: str
@@ -294,16 +292,16 @@ def test_llm_attr(self) -> None:
res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])
# 5 output tokens, 4 input tokens including sos
self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (5, 4)) # type: ignore
self.assertEqual(token_attr.shape, (5, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

self.assertEqual(res.seq_attr.device.type, self.device)
assert res.token_attr is not None # make pyre/mypy happy
self.assertEqual(token_attr.device.type, self.device) # type: ignore
assert res.token_attr is not None
self.assertEqual(token_attr.device.type, self.device)

def test_llm_attr_without_target(self) -> None:
llm = DummyLLM()
@@ -316,16 +314,16 @@ def test_llm_attr_without_target(self) -> None:
res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"})

self.assertEqual(res.seq_attr.shape, (4,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (3, 4)) # type: ignore
self.assertEqual(token_attr.shape, (3, 4))
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
self.assertEqual(res.output_tokens, ["x", "y", "z"])

self.assertEqual(res.seq_attr.device.type, self.device) # type: ignore
assert res.token_attr is not None # make pyre/mypy happy
self.assertEqual(token_attr.device.type, self.device) # type: ignore
self.assertEqual(res.seq_attr.device.type, self.device)
assert res.token_attr is not None
self.assertEqual(token_attr.device.type, self.device)

def test_llm_attr_with_skip_tokens(self) -> None:
llm = DummyLLM()
@@ -337,15 +335,15 @@ def test_llm_attr_with_skip_tokens(self) -> None:
inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
res = llm_attr.attribute(inp, "m n o p q", skip_tokens=[0])

# 5 output tokens, 4 input tokens including sos
# 5 output tokens, 3 input tokens including sos
self.assertEqual(res.seq_attr.shape, (3,))
assert res.token_attr is not None # make pyre/mypy happy
assert res.token_attr is not None
self.assertIsNotNone(res.token_attr)
token_attr = res.token_attr
self.assertEqual(token_attr.shape, (5, 3)) # type: ignore
self.assertEqual(token_attr.shape, (5, 3))
self.assertEqual(res.input_tokens, ["a", "b", "c"])
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

self.assertEqual(res.seq_attr.device.type, self.device)
assert res.token_attr is not None # make pyre/mypy happy
self.assertEqual(token_attr.device.type, self.device) # type: ignore
assert res.token_attr is not None
self.assertEqual(token_attr.device.type, self.device)
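The GPU test classes keep the explicit device: str annotation (and the remaining pyre-fixme[13] notes) because parameterized_class injects the attribute onto generated subclasses at collection time, which static checkers cannot see. A hedged sketch of that pattern, assuming the parameterized package; the test class and its body are illustrative, not the actual suite:

import unittest

import torch
from parameterized import parameterized_class


@parameterized_class(
    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
class TestDevicePlacement(unittest.TestCase):
    # Declared but never assigned here: parameterized_class sets it on each
    # generated subclass, which is why the bare annotation looks "never
    # initialized" to pyre.
    device: str

    def test_tensor_lands_on_device(self) -> None:
        t = torch.zeros(2, 2, device=self.device)
        self.assertEqual(t.device.type, self.device)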