-
Notifications
You must be signed in to change notification settings - Fork 30.9k
[whisper] static kv cache #31166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[whisper] static kv cache #31166
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice overall.cc @zhenglongjiepheonix I reviewed this one instead of #30949 because it had less changes, sorry that work got duplicated here!
You can reference my PR #30949 for tests failing part, it passes all the tests that the current main branch passes and will save you a lot of time debugging @sanchit-gandhi |
Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Arthur <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy with the PR 🔥🔥 Let's goooo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would maybe just run the slow tests?
logger.warning_once( | ||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " | ||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. " | ||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💌
Thanks both for the reviews! Confirming that the slow tests pass on the DGX A100. Going to merge this one to enable static kv cache for:
We'll need a follow-up PR to enable:
|
and past_key_value is not None | ||
and past_key_value[0].shape[2] == key_value_states.shape[1] | ||
): | ||
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_shape
and _reshape
are not the same op, is it fine to replace?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do the transpose later to get it into the original format: https://github.com/huggingface/transformers/pull/31166/files#r1652420859
But this is a good point - we don't need to _shape
then .transpose
the q-states, we can directly get them into the correct format
Hi, I am getting some cache errors while doing generation with llama3 and fsdp.
|
Hey @SaeedNajafi - do you have a minimal reproducer you could use to open a new issue on the repo? Thanks! |
The pipeline needs more work, specifically for longer audios + the merging solution. |
Thanks. I deleted the comment once I saw the PR already in progress #31772 for this exact thing. I think it's better to wait for the merge. |
Support the `cache_position` input that was added to Hugging Face Whisper models as part of a revision of how it handles KV-caching. This is like `position_ids`, but there is no batch dimension. See huggingface/optimum#1971 and huggingface/transformers#31166.
Support the `cache_position` input that was added to Hugging Face Whisper models as part of a revision of how it handles KV-caching. This is like `position_ids`, but there is no batch dimension. See huggingface/optimum#1971 and huggingface/transformers#31166.
It works with SDPA optimization. However, flash_attention_3 did not work with flash_attention_2 and kernels. code: from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import logging
import time
import librosa
audio_path = "test.mp3"
torch._logging.set_logs(graph_breaks=True, recompiles=True)
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3.5")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3.5", attn_implementation="kernels-community/flash-attn3")
model.to(torch_device, dtype=torch_dtype)
audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(torch_device)
input_features = inputs.input_features.to(torch_dtype)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"
# compile
for i in range(2):
model.generate(input_features)
# inference
pred_ids = model.generate(input_features) Error Messages: File "/mnt/whisper-plus/.venv/lib/python3.10/site-packages/transformers/integrations/flash_attention.py", line 64, in flash_attention_forward
attn_output = _flash_attention_forward(
File "/mnt/whisper-plus/.venv/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 363, in _flash_attention_forward
if not all(k in globals() for k in ("_flash_fn", "_flash_varlen_fn", "_pad_fn", "_unpad_fn", "_is_fa3")):
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo" |
What does this PR do?
Supersedes #28931 and extends it by adding static k/v cache support for Whisper. Also improves the performance of the eager attention implementation by removing un-necessary reshapes (inspired by LlamaAttention).
Similar to #28931, we use a separate cache for the self-attention and cross-attention layers. We define a lightweight
EncoderDecoderCache
wrapper that holds these two cache classes and implements common base methods (e.g.to_legacy_cache()
) by calling the corresponding methods for each cache class.However, there is one hurdle in enabling compatibility with
torch.compile
. Namely, we have to determine whether we're in the first decoding step, or second step onwards:=> the difficulty is in detecting whether we’re in the first decoding step (1), or second step onwards (2). With eager mode, we can condition on
past_key_values.get_seq_length()
to determine the decoding step. However, fortorch.compile
this introduces a graph break. Consequently, we add a boolean flagis_updated
to theStaticCache
class, which informs us whether the cache has been updated or not. The alternative would be to employ the same logic we do in the Flax code, where we re-compute the cross-attention k/v states each time. Benchmarks show this approach is 1.4x slower than adding the CPU flag.Using the
.generate
API with Whisper medium, we get approximately 5x speed-up when generating 64 tokens using sdpa attention. Note here that we compile the forward pass only:Extended results:
Whisper large-v3
Distil-Whisper distil-large-v3
As expected, the speed-ups for Distil-Whisper are less pronounced:
Code example:
In refactoring the eager attention implementation for the cache abstraction, I managed to remove a lot of wasteful
.view
operations, generally aligning it with LLaMA and giving a performance boost even without compile (TODO: quantify speed-up).The only regression comes when using FA2 and compile, where we have to introduce a bunch of new
.transpose
operations for compatibility with the shape of our k/v cache (TODO: quantify regression). This is also a known problem in LLaMA.There are a few tidy-up points left TODO. Once we're happy with the design, I'll complete the PR with the final checklist items:
past_key_values
,cache_position
)output_attentions=True