13 changes: 12 additions & 1 deletion captum/_utils/typing.py
@@ -2,7 +2,7 @@

# pyre-strict

from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
from typing import List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, Union

from torch import Tensor
from torch.nn import Module
@@ -33,3 +33,14 @@
TensorLikeList4D,
TensorLikeList5D,
]


class TokenizerLike(Protocol):
"""A protocol for tokenizer-like objects that can be used with Captum
LLM attribution methods."""

def encode(
self, text: str, return_tensors: Optional[str] = None
) -> Union[List[int], Tensor]: ...
def decode(self, token_ids: Tensor) -> str: ...
def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]: ...
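A minimal sketch of an object that structurally satisfies the new protocol; the whitespace tokenizer and its vocabulary below are illustrative assumptions, not part of this PR, while real HuggingFace tokenizers already expose these three methods:

from typing import Dict, List, Optional, Union

import torch
from torch import Tensor

from captum._utils.typing import TokenizerLike


class WhitespaceTokenizer:
    """Toy tokenizer; any object with these three methods (e.g. a HuggingFace
    tokenizer) satisfies TokenizerLike without inheriting from it."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = [self.vocab.setdefault(t, len(self.vocab)) for t in text.split()]
        if return_tensors == "pt":
            return torch.tensor([ids])  # shape (1, n_tokens), mirroring HF behaviour
        return ids

    def decode(self, token_ids: Tensor) -> str:
        return " ".join(self.convert_ids_to_tokens(token_ids))

    def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]:
        rev = {i: t for t, i in self.vocab.items()}
        return [rev[int(i)] for i in token_ids]


tok: TokenizerLike = WhitespaceTokenizer()  # passes static type checking via the Protocol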
145 changes: 67 additions & 78 deletions captum/attr/_core/llm_attr.py
@@ -1,18 +1,23 @@
# pyre-strict
from copy import copy

from typing import Any, Callable, cast, Dict, List, Optional, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np

import torch
from captum._utils.typing import TokenizerLike
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
from captum.attr._core.lime import Lime
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import Attribution
from captum.attr._utils.attribution import (
Attribution,
GradientAttribution,
PerturbationAttribution,
)
from captum.attr._utils.interpretable_input import (
InterpretableInput,
TextTemplateInput,
@@ -44,11 +49,12 @@ def __init__(
self.output_tokens = output_tokens

@property
def seq_attr_dict(self) -> Dict[str, Any]:
def seq_attr_dict(self) -> Dict[str, float]:
return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)}

# pyre-fixme[3]: Return type must be annotated.
def plot_token_attr(self, show: bool = False):
def plot_token_attr(
self, show: bool = False
) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
"""
Generate a matplotlib plot for visualising the attribution
of the output tokens.
@@ -58,7 +64,11 @@ def plot_token_attr(self, show: bool = False):
Default: False
"""

# pyre-fixme[16]: `Optional` has no attribute `cpu`.
if self.token_attr is None:
raise ValueError(
"token_attr is None (no token-level attribution was performed), please "
"use plot_seq_attr instead for the sequence-level attribution plot"
)
token_attr = self.token_attr.cpu() # type: ignore

# maximum absolute attribution value
@@ -83,7 +93,7 @@
)

# Create colorbar
cbar = ax.figure.colorbar(im, ax=ax) # type: ignore
cbar = fig.colorbar(im, ax=ax) # type: ignore
cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom")

# Show all ticks and label them with the respective list entries.
@@ -113,11 +123,13 @@ def plot_token_attr(self, show: bool = False):

if show:
plt.show()
return None # mypy wants this
else:
return fig, ax

# pyre-fixme[3]: Return type must be annotated.
def plot_seq_attr(self, show: bool = False):
def plot_seq_attr(
self, show: bool = False
) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
"""
Generate a matplotlib plot for visualising the attribution
of the output sequence.
@@ -150,6 +162,7 @@ def plot_seq_attr(self, show: bool = False):

if show:
plt.show()
return None # mypy wants this
else:
return fig, ax
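
With the annotated return type, callers can capture the figure when show=False. A hedged usage sketch (the `res` object and output file name are hypothetical; `res` would come from LLMAttribution.attribute below):

# Hypothetical: `res` is an LLMAttributionResult returned by attribute().
print(res.seq_attr_dict)                 # {input_token: attribution}, now typed Dict[str, float]

plot = res.plot_token_attr(show=False)   # raises ValueError if token_attr is None
if plot is not None:                     # None is returned only when show=True
    fig, ax = plot
    fig.savefig("token_attr.png", dpi=150)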

@@ -181,9 +194,8 @@ class LLMAttribution(Attribution):

def __init__(
self,
attr_method: Attribution,
# pyre-fixme[2]: Parameter must be annotated.
tokenizer,
attr_method: PerturbationAttribution,
tokenizer: TokenizerLike,
attr_target: str = "log_prob", # TODO: support callable attr_target
) -> None:
"""
@@ -208,24 +220,19 @@ class created with the llm model that follows huggingface style
super().__init__(attr_method.forward_func)

# shallow copy is enough to avoid modifying original instance
# pyre-fixme[4]: Attribute must be annotated.
self.attr_method = copy(attr_method)
# pyre-fixme[4]: Attribute must be annotated.
self.include_per_token_attr = isinstance(
self.attr_method: PerturbationAttribution = copy(attr_method)
self.include_per_token_attr: bool = isinstance(
attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS
)

self.attr_method.forward_func = self._forward_func

# alias, we really need a model and don't support wrapper functions
# coz we need call model.forward, model.generate, etc.
# pyre-fixme[4]: Attribute must be annotated.
self.model = cast(nn.Module, self.forward_func)
self.model: nn.Module = cast(nn.Module, self.forward_func)

# pyre-fixme[4]: Attribute must be annotated.
self.tokenizer = tokenizer
# pyre-fixme[4]: Attribute must be annotated.
self.device = (
self.tokenizer: TokenizerLike = tokenizer
self.device: torch.device = (
cast(torch.device, self.model.device)
if hasattr(self.model, "device")
else next(self.model.parameters()).device
@@ -239,15 +246,12 @@ class created with the llm model that follows huggingface style

def _forward_func(
self,
# pyre-fixme[2]: Parameter must be annotated.
perturbed_tensor,
# pyre-fixme[2]: Parameter must be annotated.
inp,
# pyre-fixme[2]: Parameter must be annotated.
target_tokens,
perturbed_tensor: Union[None, Tensor],
inp: InterpretableInput,
target_tokens: Tensor,
use_cached_outputs: bool = False,
_inspect_forward=None,
) -> Union[int, Tensor]:
_inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
) -> Tensor:
perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
init_model_inp = perturbed_input

@@ -279,7 +283,9 @@ def _forward_func(
(model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
)

total_log_prob = sum(log_prob_list)
# pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
# Tensor
total_log_prob: Tensor = sum(log_prob_list) # type: ignore
# 1st element is the total prob, rest are the target tokens
# add a leading dim for batch even we only support single instance for now
if self.include_per_token_attr:
@@ -288,8 +294,6 @@ def _forward_func(
).unsqueeze(0)
else:
target_log_probs = total_log_prob # type: ignore
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int,
# Tensor]`.
target_probs = torch.exp(target_log_probs)

if _inspect_forward:
@@ -301,35 +305,31 @@ def _forward_func(

return target_probs if self.attr_target != "log_prob" else target_log_probs

# pyre-fixme[3]: Return type must be annotated.
def _format_model_input(self, model_input: Union[str, Tensor]):
def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
"""
Convert str to tokenized tensor
to make LLMAttribution work with model inputs of both
raw text and text token tensors
"""
# return tensor(1, n_tokens)
if isinstance(model_input, str):
return self.tokenizer.encode(model_input, return_tensors="pt").to(
self.device
)
# pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
# Tensor
return self.tokenizer.encode( # type: ignore
model_input, return_tensors="pt"
).to(self.device)
return model_input.to(self.device)

def attribute(
self,
inp: InterpretableInput,
target: Union[str, torch.Tensor, None] = None,
num_trials: int = 1,
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
# errors.
gen_args: Optional[Dict] = None,
gen_args: Optional[Dict[str, Any]] = None,
use_cached_outputs: bool = True,
# internal callback hook can be used for logging
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
_inspect_forward: Optional[Callable] = None,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
_inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
**kwargs: Any,
) -> LLMAttributionResult:
"""
Args:
@@ -380,10 +380,14 @@ def attribute(
target_tokens = torch.tensor(target_tokens)
elif type(target) is torch.Tensor:
target_tokens = target
else:
raise TypeError(
"target must either be str or Tensor, but the type of target is "
"{}".format(type(target))
)

attr = torch.zeros(
[
# pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
1 + len(target_tokens) if self.include_per_token_attr else 1,
inp.n_itp_features,
],
@@ -398,8 +402,6 @@ def attribute(
attr_input,
additional_forward_args=(
inp,
# pyre-fixme[61]: `target_tokens` is undefined, or not always
# defined.
target_tokens,
use_cached_outputs,
_inspect_forward,
@@ -424,7 +426,6 @@ def attribute(
attr[1:] if self.include_per_token_attr else None
), # shape(n_output_token, n_input_features)
inp.values,
# pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
self.tokenizer.convert_ids_to_tokens(target_tokens),
)
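
The tightened constructor annotations (PerturbationAttribution, TokenizerLike) read most naturally alongside a usage sketch. The model, tokenizer, prompt, and target below are assumptions, and the imports assume the public captum.attr namespace as used in the Captum LLM attribution tutorial:

from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

# `model` is an assumed HuggingFace-style causal LM; `tokenizer` is any object
# that structurally satisfies TokenizerLike (loading both is out of scope here).
fa = FeatureAblation(model)               # a PerturbationAttribution instance
llm_attr = LLMAttribution(fa, tokenizer)  # tokenizer now annotated as TokenizerLike

inp = TextTemplateInput(
    "{} lives in {} and works as a {}.",
    values=["Dave", "Palm Coast", "lawyer"],
)
res = llm_attr.attribute(inp, target="He enjoys the beach.")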

@@ -454,14 +455,11 @@ class LLMGradientAttribution(Attribution):
SUPPORTED_METHODS = (LayerIntegratedGradients,)
SUPPORTED_INPUTS = (TextTokenInput,)

# pyre-fixme[3]: Return type must be annotated.
def __init__(
self,
# pyre-fixme[2]: Parameter must be annotated.
attr_method,
# pyre-fixme[2]: Parameter must be annotated.
tokenizer,
):
attr_method: GradientAttribution,
tokenizer: TokenizerLike,
) -> None:
"""
Args:
attr_method (Attribution): instance of a supported perturbation attribution
@@ -476,19 +474,15 @@ class created with the llm model that follows huggingface style
super().__init__(attr_method.forward_func)

# shallow copy is enough to avoid modifying original instance
# pyre-fixme[4]: Attribute must be annotated.
self.attr_method = copy(attr_method)
self.attr_method: GradientAttribution = copy(attr_method)
self.attr_method.forward_func = self._forward_func

# alias, we really need a model and don't support wrapper functions
# coz we need call model.forward, model.generate, etc.
# pyre-fixme[4]: Attribute must be annotated.
self.model = cast(nn.Module, self.forward_func)
self.model: nn.Module = cast(nn.Module, self.forward_func)

# pyre-fixme[4]: Attribute must be annotated.
self.tokenizer = tokenizer
# pyre-fixme[4]: Attribute must be annotated.
self.device = (
self.tokenizer: TokenizerLike = tokenizer
self.device: torch.device = (
cast(torch.device, self.model.device)
if hasattr(self.model, "device")
else next(self.model.parameters()).device
@@ -526,9 +520,7 @@ def _forward_func(
# the attribution target is limited to the log probability
return token_log_probs

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _format_model_input(self, model_input):
def _format_model_input(self, model_input: Tensor) -> Tensor:
"""
Convert str to tokenized tensor
"""
@@ -538,12 +530,8 @@ def attribute(
self,
inp: InterpretableInput,
target: Union[str, torch.Tensor, None] = None,
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
# errors.
gen_args: Optional[Dict] = None,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
gen_args: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMAttributionResult:
"""
Args:
@@ -590,19 +578,21 @@ def attribute(
target_tokens = torch.tensor(target_tokens)
elif type(target) is torch.Tensor:
target_tokens = target
else:
raise TypeError(
"target must either be str or Tensor, but the type of target is "
"{}".format(type(target))
)

attr_inp = inp.to_tensor().to(self.device)

attr_list = []
# pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
for cur_target_idx, _ in enumerate(target_tokens):
# attr in shape(batch_size, input+output_len, emb_dim)
attr = self.attr_method.attribute(
attr_inp,
additional_forward_args=(
inp,
# pyre-fixme[61]: `target_tokens` is undefined, or not always
# defined.
target_tokens,
cur_target_idx,
),
@@ -629,7 +619,7 @@ def attribute(
# it attributes to all the elements of the output of the specified layer
# so we need special handling for the inp type which don't care all the elements
if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
itp_mask = inp.itp_mask.to(self.device)
itp_mask = inp.itp_mask.to(attr.device)
itp_mask = itp_mask.expand_as(attr)
attr = attr[itp_mask].view(attr.size(0), -1)

@@ -642,7 +632,6 @@ def attribute(
seq_attr,
attr, # shape(n_output_token, n_input_features)
inp.values,
# pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
self.tokenizer.convert_ids_to_tokens(target_tokens),
)
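
For the gradient path, a parallel sketch under the same assumptions (the model and tokenizer are placeholders; the embedding-layer handle and the skip_tokens value are model-specific assumptions, not prescribed by this PR):

from captum.attr import (
    LayerIntegratedGradients,
    LLMGradientAttribution,
    TextTokenInput,
)

# `model` / `tokenizer` as in the sketch above; get_input_embeddings() is the
# usual HuggingFace accessor for the embedding layer that
# LayerIntegratedGradients attributes against.
lig = LayerIntegratedGradients(model, model.get_input_embeddings())
llm_ga = LLMGradientAttribution(lig, tokenizer)   # attr_method: GradientAttribution

inp = TextTokenInput("Dave lives in Palm Coast", tokenizer, skip_tokens=[1])
res = llm_ga.attribute(inp, target="He enjoys the beach.")
res.plot_seq_attr(show=True)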
