32 changes: 32 additions & 0 deletions captum/_utils/transformers_typing.py
@@ -0,0 +1,32 @@
#!/usr/bin/env python3

# pyre-strict

from typing import Optional, Protocol, Tuple, Type

import torch


class CacheLike(Protocol):
    """Protocol for cache-like objects."""


class DynamicCacheLike(CacheLike, Protocol):
    """Protocol for dynamic cache-like objects."""

    @classmethod
    def from_legacy_cache(
        cls: Type["DynamicCacheLike"],
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    ) -> "DynamicCacheLike": ...


try:
    # pyre-ignore[21]: Could not find a module corresponding to import
    # `transformers.cache_utils`
    from transformers.cache_utils import Cache as _Cache, DynamicCache as _DynamicCache
except ImportError:
    _Cache = _DynamicCache = None

Cache: Optional[Type[CacheLike]] = _Cache
DynamicCache: Optional[Type[DynamicCacheLike]] = _DynamicCache
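Note (not part of the diff): both exports are intended to be None-checked before use; the protocols only describe the interface, and the concrete classes are absent on transformers versions without `cache_utils`. A minimal usage sketch, where the helper name is illustrative rather than anything defined in this PR:

```python
from captum._utils.transformers_typing import Cache, DynamicCache


def maybe_upgrade_cache(past_key_values):
    # Hypothetical helper: wrap a legacy tuple-of-tuples cache in a DynamicCache,
    # but only when transformers.cache_utils was importable (otherwise both
    # Cache and DynamicCache are None and the input is returned unchanged).
    if (
        Cache is not None
        and DynamicCache is not None
        and past_key_values is not None
        and not isinstance(past_key_values, Cache)
    ):
        return DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values
```

This mirrors the guard added to `llm_attr.py` below.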
31 changes: 24 additions & 7 deletions captum/attr/_core/llm_attr.py
@@ -7,6 +7,7 @@
import numpy as np

import torch
from captum._utils.transformers_typing import Cache, DynamicCache
from captum._utils.typing import TokenizerLike
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
@@ -27,8 +28,12 @@
)
from torch import nn, Tensor


DEFAULT_GEN_ARGS = {"max_new_tokens": 25, "do_sample": False}
DEFAULT_GEN_ARGS: Dict[str, Any] = {
    "max_new_tokens": 25,
    "do_sample": False,
    "temperature": None,
    "top_p": None,
}


class LLMAttributionResult:
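Note (not part of the diff): adding "temperature" and "top_p" as explicit None entries alongside do_sample=False appears intended to unset the sampling parameters so greedy generation does not inherit them from a checkpoint's default generation config; treat that rationale as an assumption. A rough sketch of how the defaults would be consumed when no target is supplied (the exact call site lives elsewhere in this file):

```python
# Sketch only: fall back to DEFAULT_GEN_ARGS, forward them to generate(),
# and keep just the newly generated tokens as the attribution target.
gen_args = gen_args or DEFAULT_GEN_ARGS
output_ids = self.model.generate(model_inp, **gen_args)
target_tokens = output_ids[0][model_inp.shape[1]:]
```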
@@ -258,15 +263,25 @@ def _forward_func(
        init_model_inp = perturbed_input

        model_inp = init_model_inp
        attention_mask = torch.tensor([[1] * model_inp.shape[1]])
        attention_mask = attention_mask.to(model_inp.device)
        attention_mask = torch.ones(
            [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
        )
        model_kwargs = {"attention_mask": attention_mask}

        log_prob_list = []
        outputs = None
        for target_token in target_tokens:
            if use_cached_outputs:
                if outputs is not None:
                    if (
                        Cache is not None
                        and DynamicCache is not None
                        and getattr(self.model, "_supports_cache_class", False)
                        and not isinstance(outputs.past_key_values, Cache)
                    ):
                        outputs.past_key_values = DynamicCache.from_legacy_cache(
                            outputs.past_key_values
                        )
                    model_kwargs = self.model._update_model_kwargs_for_generation(
                        outputs, model_kwargs
                    )
@@ -275,7 +290,7 @@ def _forward_func(
                )
                outputs = self.model.forward(**model_inputs)
            else:
                outputs = self.model.forward(model_inp, attention_mask=attention_mask)
                outputs = self.model.forward(model_inp, **model_kwargs)
            new_token_logits = outputs.logits[:, -1]
            log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

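Note (not part of the diff): the new guard only matters on the cached path. Recent transformers releases expect `past_key_values` to be a `Cache` instance when a model advertises `_supports_cache_class`, while a plain `forward()` call can still hand back the legacy tuple format, so the tuple is re-wrapped with `DynamicCache.from_legacy_cache` before `_update_model_kwargs_for_generation` reuses it. A standalone sketch of that incremental scoring loop, with placeholder model/tokenizer names and no error handling:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum._utils.transformers_typing import Cache, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

model_inp = tokenizer("The capital of France is", return_tensors="pt").input_ids
target_tokens = tokenizer(" Paris", add_special_tokens=False, return_tensors="pt").input_ids[0]

model_kwargs = {"attention_mask": torch.ones_like(model_inp)}
outputs, total_log_prob = None, 0.0
with torch.no_grad():
    for target_token in target_tokens:
        if outputs is not None:
            if (
                Cache is not None
                and DynamicCache is not None
                and getattr(model, "_supports_cache_class", False)
                and not isinstance(outputs.past_key_values, Cache)
            ):
                # Normalize a legacy tuple cache before reusing it.
                outputs.past_key_values = DynamicCache.from_legacy_cache(
                    outputs.past_key_values
                )
            model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs)
        model_inputs = model.prepare_inputs_for_generation(model_inp, **model_kwargs)
        outputs = model(**model_inputs)
        log_probs = torch.log_softmax(outputs.logits[:, -1], dim=-1)
        total_log_prob += log_probs[0, target_token].item()
        model_inp = torch.cat([model_inp, target_token.view(1, 1)], dim=1)
```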
@@ -345,7 +360,8 @@ def attribute(
                    Defaults: 1.
            gen_args (dict, optional): arguments for generating the target. Only used if
                    target is not given. When None, the default arguments are used,
                    {"max_length": 25, "do_sample": False}
                    {"max_new_tokens": 25, "do_sample": False,
                    "temperature": None, "top_p": None}
                    Defaults: None
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance
@@ -516,7 +532,8 @@ def attribute(
                    Default: None
            gen_args (dict, optional): arguments for generating the target. Only used if
                    target is not given. When None, the default arguments are used,
                    {"max_length": 25, "do_sample": False}
                    {"max_new_tokens": 25, "do_sample": False,
                    "temperature": None, "top_p": None}
                    Defaults: None
            **kwargs (Any): any extra keyword arguments passed to the call of the
                    underlying attribute function of the given attribution instance