
Commit 9e412b8

craymichael authored and facebook-github-bot committed
Support Cache Class for New Versions of Transformers Library
Summary: Fixes D62210529 (now reverted by D62262760).

The transformers library is now an optional dependency. Captum does not depend on it, but it contains some logic specific to `transformers` models; the library is only imported if the model's environment already provides it. This TARGETS configuration prevents transformers version conflicts, which e.g. caused T200877742.

Add support for the new transformers Cache objects. This may need changes in the future, as LLMs appear to handle caching differently: some manage the cache themselves, others do not, and some do not support Cache objects yet. Llama models expose a `_supports_cache_class` flag indicating whether the new Cache object is supported; if a model is not marked as supporting it, we assume it takes the legacy format (tuples of past key/value tensors). Multiple checks were added to ensure compatibility.

(minor) Also changed the defaults for LLM generation to dismiss warnings (does not change generation behavior).

Differential Revision: D62408520
1 parent d8ceaa8 commit 9e412b8
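A condensed sketch of the compatibility rule described above, illustrative rather than the actual patched code: `normalize_past_key_values`, `model`, and `past` are placeholders standing for any Hugging Face causal LM and a cache returned by a previous forward pass.

try:
    from transformers.cache_utils import Cache, DynamicCache
except ImportError:  # older transformers, or transformers not installed
    Cache = DynamicCache = None


def normalize_past_key_values(model, past):
    # Models that declare support for the new Cache class get legacy
    # (key, value) tuples wrapped in a DynamicCache; everything else is
    # assumed to expect the legacy tuple format unchanged.
    if (
        Cache is not None
        and getattr(model, "_supports_cache_class", False)
        and not isinstance(past, Cache)
    ):
        return DynamicCache.from_legacy_cache(past)
    return past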

File tree

2 files changed: 33 additions and 7 deletions

captum/_utils/transformers_typing.py (new file)

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+# pyre-strict
+
+try:
+    # pyre-ignore[21]: Could not find a module corresponding to import
+    # `transformers.cache_utils`
+    from transformers.cache_utils import Cache, DynamicCache
+except ImportError:
+    Cache = DynamicCache = None
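As a usage sketch (not part of the commit): because this typing shim falls back to None when transformers is missing, downstream code can feature-test the Cache API without importing transformers unconditionally. The printed messages below are illustrative.

from captum._utils.transformers_typing import Cache, DynamicCache

if Cache is None:
    # transformers (or its cache_utils module) is unavailable in this environment;
    # models are assumed to use the legacy tuple-of-(key, value) cache format.
    print("new Cache API not available")
else:
    cache = DynamicCache()            # empty cache, to be filled by model.forward(...)
    print(isinstance(cache, Cache))   # True: DynamicCache subclasses Cache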

captum/attr/_core/llm_attr.py

Lines changed: 23 additions & 7 deletions
@@ -7,6 +7,7 @@
 import numpy as np

 import torch
+from captum._utils.transformers_typing import Cache, DynamicCache
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
@@ -27,8 +28,12 @@
 )
 from torch import nn, Tensor

-
-DEFAULT_GEN_ARGS = {"max_new_tokens": 25, "do_sample": False}
+DEFAULT_GEN_ARGS: Dict[str, Any] = {
+    "max_new_tokens": 25,
+    "do_sample": False,
+    "temperature": None,
+    "top_p": None,
+}


 class LLMAttributionResult:
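As an illustration of why the sampling parameters are set to None (a sketch under the assumption of a Hugging Face causal LM; the "gpt2" checkpoint is arbitrary and not part of the commit): many checkpoints ship a generation config with temperature/top_p preset, and recent transformers versions warn that those values are ignored when do_sample=False. Passing them explicitly as None dismisses the warning without changing the greedy-decoding output.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_GEN_ARGS = {
    "max_new_tokens": 25,
    "do_sample": False,   # greedy decoding, as before
    "temperature": None,  # explicitly unset so transformers does not warn that
    "top_p": None,        # sampling parameters are ignored with do_sample=False
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The inventor of the telephone is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, **DEFAULT_GEN_ARGS)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))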
@@ -258,15 +263,24 @@ def _forward_func(
         init_model_inp = perturbed_input

         model_inp = init_model_inp
-        attention_mask = torch.tensor([[1] * model_inp.shape[1]])
-        attention_mask = attention_mask.to(model_inp.device)
+        attention_mask = torch.ones(
+            [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
+        )
         model_kwargs = {"attention_mask": attention_mask}

         log_prob_list = []
         outputs = None
         for target_token in target_tokens:
             if use_cached_outputs:
                 if outputs is not None:
+                    if (
+                        Cache is not None
+                        and getattr(self.model, "_supports_cache_class", False)
+                        and not isinstance(outputs.past_key_values, Cache)
+                    ):
+                        outputs.past_key_values = DynamicCache.from_legacy_cache(
+                            outputs.past_key_values
+                        )
                     model_kwargs = self.model._update_model_kwargs_for_generation(
                         outputs, model_kwargs
                     )
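For context, a standalone sketch (not from the commit, and requiring a transformers version that ships transformers.cache_utils) of the conversion the check above performs: the legacy past_key_values format is a tuple with one (key, value) pair per layer, and DynamicCache.from_legacy_cache wraps it in the new Cache object that models with _supports_cache_class expect. Tensor shapes below are arbitrary.

import torch
from transformers.cache_utils import Cache, DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch, num_heads, seq_len, head_dim).
legacy_past = tuple(
    (torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy_past)
assert isinstance(cache, Cache)  # DynamicCache is a Cache subclass

# Models that do not support the Cache class still take the legacy tuples,
# which can be recovered from the Cache object.
legacy_again = cache.to_legacy_cache()
assert torch.equal(legacy_again[0][0], legacy_past[0][0])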
@@ -275,7 +289,7 @@ def _forward_func(
                 )
                 outputs = self.model.forward(**model_inputs)
             else:
-                outputs = self.model.forward(model_inp, attention_mask=attention_mask)
+                outputs = self.model.forward(model_inp, **model_kwargs)
             new_token_logits = outputs.logits[:, -1]
             log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)

@@ -345,7 +359,8 @@ def attribute(
                     Defaults: 1.
             gen_args (dict, optional): arguments for generating the target. Only used if
                     target is not given. When None, the default arguments are used,
-                    {"max_length": 25, "do_sample": False}
+                    {"max_new_tokens": 25, "do_sample": False,
+                    "temperature": None, "top_p": None}
                     Defaults: None
             **kwargs (Any): any extra keyword arguments passed to the call of the
                     underlying attribute function of the given attribution instance
@@ -516,7 +531,8 @@ def attribute(
                     Default: None
             gen_args (dict, optional): arguments for generating the target. Only used if
                     target is not given. When None, the default arguments are used,
-                    {"max_length": 25, "do_sample": False}
+                    {"max_new_tokens": 25, "do_sample": False,
+                    "temperature": None, "top_p": None}
                     Defaults: None
             **kwargs (Any): any extra keyword arguments passed to the call of the
                     underlying attribute function of the given attribution instance
