
Commit 8b93cf0

craymichael authored and facebook-github-bot committed
Support Cache Class for New Versions of Transformers Library (#1341)
Summary:
Pull Request resolved: #1341

Fixes D62210529 (now reverted by D62262760). The transformers library is now an optional dependency: we do not depend on it directly, but we do have logic here for `transformers` models, and the library is only imported if the model's environment already provides it. This TARGETS configuration prevents transformers version conflicts such as the one that caused T200877742.

Adds support for the new transformers Cache objects. This may need changes in the future, since LLMs handle caching differently: some handle caching themselves, some do not, and some do not support Cache objects yet. Llama models expose a `_supports_cache_class` flag indicating whether the new Cache object is supported; if a model is not marked as supporting it, we assume it takes the legacy format (a tuple of past key/value tensors). Multiple checks are added to ensure compatibility.

(minor) Also changed the defaults for LLM generation to dismiss warnings (this does not change generation behavior).

Differential Revision: D62408520
1 parent d8ceaa8 commit 8b93cf0
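For context on the compatibility logic described above: the legacy `past_key_values` format is a tuple with one `(key, value)` tensor pair per layer, while newer transformers releases represent the same state as a `Cache` object. Below is a minimal sketch of the conversion, assuming a transformers version that ships `transformers.cache_utils` (roughly v4.36+); the tensor shapes are illustrative only.

import torch
from transformers.cache_utils import Cache, DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch, num_heads, seq_len, head_dim).
legacy_past = tuple(
    (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)
)

# Models that advertise `_supports_cache_class` (e.g. Llama) accept the new
# Cache object, so the legacy tuples are wrapped into a DynamicCache.
cache = DynamicCache.from_legacy_cache(legacy_past)
assert isinstance(cache, Cache)
print(cache.get_seq_length())  # 4: the cached sequence length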

File tree

2 files changed: +52 additions, −7 deletions

captum/_utils/transformers_typing.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+
+# pyre-strict
+
+from typing import Optional, Protocol, Tuple, Type
+
+import torch
+
+
+class CacheLike(Protocol): ...
+
+
+class DynamicCacheLike(CacheLike):
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
+    ) -> "DynamicCacheLike": ...
+
+
+Cache: Optional[Type[CacheLike]]
+DynamicCache: Optional[Type[DynamicCacheLike]]
+
+try:
+    # pyre-ignore[21]: Could not find a module corresponding to import
+    # `transformers.cache_utils`
+    from transformers.cache_utils import Cache, DynamicCache
+except ImportError:
+    Cache = DynamicCache = None
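A note on the design: the `CacheLike`/`DynamicCacheLike` Protocols exist only so this module and its callers can be type-checked without transformers installed; at runtime the exported `Cache` and `DynamicCache` are either the real transformers classes or `None`. Downstream code therefore has to treat them as optional, along these lines (an illustrative sketch, not part of this diff):

from captum._utils.transformers_typing import Cache, DynamicCache

# The shim resolves to None sentinels when transformers (or a version with
# cache_utils) is unavailable, so check availability before any isinstance() use.
if Cache is not None and DynamicCache is not None:
    print("transformers Cache support available")
else:
    print("falling back to legacy tuple past_key_values")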

captum/attr/_core/llm_attr.py

Lines changed: 24 additions & 7 deletions
@@ -7,6 +7,7 @@
 import numpy as np
 
 import torch
+from captum._utils.transformers_typing import Cache, DynamicCache
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
@@ -27,8 +28,12 @@
 )
 from torch import nn, Tensor
 
-
-DEFAULT_GEN_ARGS = {"max_new_tokens": 25, "do_sample": False}
+DEFAULT_GEN_ARGS: Dict[str, Any] = {
+    "max_new_tokens": 25,
+    "do_sample": False,
+    "temperature": None,
+    "top_p": None,
+}
 
 
 class LLMAttributionResult:
@@ -258,15 +263,25 @@ def _forward_func(
         init_model_inp = perturbed_input
 
         model_inp = init_model_inp
-        attention_mask = torch.tensor([[1] * model_inp.shape[1]])
-        attention_mask = attention_mask.to(model_inp.device)
+        attention_mask = torch.ones(
+            [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
+        )
         model_kwargs = {"attention_mask": attention_mask}
 
         log_prob_list = []
         outputs = None
         for target_token in target_tokens:
             if use_cached_outputs:
                 if outputs is not None:
+                    if (
+                        Cache is not None
+                        and DynamicCache is not None
+                        and getattr(self.model, "_supports_cache_class", False)
+                        and not isinstance(outputs.past_key_values, Cache)
+                    ):
+                        outputs.past_key_values = DynamicCache.from_legacy_cache(
+                            outputs.past_key_values
+                        )
                     model_kwargs = self.model._update_model_kwargs_for_generation(
                         outputs, model_kwargs
                     )
@@ -275,7 +290,7 @@ def _forward_func(
                 )
                 outputs = self.model.forward(**model_inputs)
             else:
-                outputs = self.model.forward(model_inp, attention_mask=attention_mask)
+                outputs = self.model.forward(model_inp, **model_kwargs)
             new_token_logits = outputs.logits[:, -1]
             log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
 
@@ -345,7 +360,8 @@ def attribute(
                         Defaults: 1.
             gen_args (dict, optional): arguments for generating the target. Only used if
                     target is not given. When None, the default arguments are used,
-                    {"max_length": 25, "do_sample": False}
+                    {"max_new_tokens": 25, "do_sample": False,
+                    "temperature": None, "top_p": None}
                     Defaults: None
             **kwargs (Any): any extra keyword arguments passed to the call of the
                     underlying attribute function of the given attribution instance
@@ -516,7 +532,8 @@ def attribute(
                     Default: None
             gen_args (dict, optional): arguments for generating the target. Only used if
                     target is not given. When None, the default arguments are used,
-                    {"max_length": 25, "do_sample": False}
+                    {"max_new_tokens": 25, "do_sample": False,
+                    "temperature": None, "top_p": None}
                     Defaults: None
             **kwargs (Any): any extra keyword arguments passed to the call of the
                     underlying attribute function of the given attribution instance
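The `temperature`/`top_p` entries added to `DEFAULT_GEN_ARGS` only suppress the sampling warnings that newer transformers versions emit when `do_sample=False`; they do not change what is generated. As a usage sketch against Captum's LLM attribution API (the exact call below is illustrative and assumes a standard HuggingFace causal LM):

from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("The capital of France is", tokenizer)

# With target=None, the wrapper generates the target itself using
# DEFAULT_GEN_ARGS; pass gen_args to override any of those defaults.
result = llm_attr.attribute(inp, gen_args={"max_new_tokens": 10, "do_sample": False})
print(result.seq_attr.shape)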
