17 changes: 11 additions & 6 deletions src/transformers/cache_utils.py
```diff
@@ -331,16 +331,21 @@ def update(
         cache_position = (
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )

+        k_out = self.keys
+        v_out = self.values
+        batch_size = key_states.shape[0]
+        if k_out.shape[0] != batch_size:
```
**@mobicham** (Contributor) · Nov 29, 2025:

I guess `k_out.shape[0] >= batch_size` is better

**Contributor:**

When debugging the torch.compile stuff, can you check this:

```python
assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"
assert v_out.data_ptr() == v_out[:batch_size].data_ptr(), "invalid v_out data copy()!"
```

If there's no copy, I don't see why CUDA Graphs would break with Whisper. What error do you get exactly, btw?
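In eager mode, the view property behind these asserts is easy to demonstrate. A minimal sketch (shapes are illustrative, chosen to match the Whisper trace below; not taken from the PR):

```python
import torch

# Sketch: a leading-dimension slice of a preallocated cache tensor is a view,
# so it shares storage with the original tensor and no copy occurs.
max_batch_size, num_heads, seq_len, head_dim = 8, 6, 32, 64  # illustrative sizes
k_out = torch.zeros(max_batch_size, num_heads, seq_len, head_dim)

batch_size = 4
k_view = k_out[:batch_size]

# Same starting address: the slice did not allocate new memory.
same_ptr = k_view.data_ptr() == k_out.data_ptr()

# Writes through the view land in the original cache tensor.
k_view.fill_(1.0)
wrote_through = bool((k_out[:batch_size] == 1.0).all()) and bool((k_out[batch_size:] == 0.0).all())
```

So in eager mode the check passes trivially; the interesting question is what happens once Dynamo traces it.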

**Contributor Author:**

> I guess `k_out.shape[0] <= batch_size` is better

Wait, should it be `<` or `>`, considering `k_out` will be larger than `batch_size`?

> assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"

Under torch.compile, this check itself fails:

```
Builtin `operator.*` comparison with constant `self` failed
  Explanation: Failed to compare DataPtrVariable() with DataPtrVariable(), because DataPtrVariable() is not a Python constant or its mutation check fails.
```

About the actual torch.compile error I'm getting: I'm trying `max_batch_size = 8` with the batch sizes being 8, 4, 2, 1. On 4 it crashes with:

```
Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(s72, 6, 1, 64), dtype=torch.float16,
           grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
           grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
           grad_fn=<Error>)), **{'attn_mask': None, 'dropout_p': 0.0, 'scale': 1.0, 'is_causal': False}): got RuntimeError('Attempting to broadcast a dimension of length 8 at -2! Mismatching argument at index 1 had [8, 6]; but expected shape should be broadcastable to [s72, 6]')

from user code:
   File "/home/vedth/stuhdy/z.py", line 21, in decoder_forward
    out = model.model.decoder(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 865, in forward
    layer_outputs = decoder_layer(
  File "/home/vedth/stuhdy/transformers/src/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 501, in forward
    hidden_states, cross_attn_weights = self.encoder_attn(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 347, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/vedth/stuhdy/transformers/src/transformers/integrations/sdpa_attention.py", line 92, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
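The broadcast failure in the trace can be reproduced in isolation. A minimal sketch with the shapes from the error message (a live batch of 4 queries against cross-attention key/value tensors still sized for `max_batch_size = 8`; float32 instead of the trace's float16 so it runs on CPU):

```python
import torch
import torch.nn.functional as F

# Sketch: query sliced to the live batch (4) while the cached key/value
# tensors keep the full max_batch_size (8) -- the batch dims cannot broadcast.
q = torch.randn(4, 6, 1, 64)
k = torch.randn(8, 6, 32, 64)
v = torch.randn(8, 6, 32, 64)

try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    mismatched_batch_raised = False
except RuntimeError:
    mismatched_batch_raised = True  # batch dims 4 vs 8 fail to broadcast

# Slicing the cache down to the live batch makes the shapes consistent.
out = F.scaled_dot_product_attention(q, k[:4], v[:4])
```

This suggests the Whisper failure is a stale cross-attention cache batch rather than the slicing logic itself, though that is an inference from the shapes in the trace, not something the thread confirms.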

**Contributor:**

> Wait, should it be `<` or `>`, considering `k_out` will be larger than `batch_size`?

Oh sorry, you're right. I meant `current_batch_size <= max_batch_size`.

> assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"

I meant run it without torch.compile, just to see if it performs any copy.

> Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace

I see, I will try to debug next week too 👍

```diff
+            k_out = k_out[:batch_size]
+            v_out = v_out[:batch_size]
```
Comment on lines +336 to +339:
**Member:**

Can you run benchmarks and check whether it creates excessive cudagraph breaks, as per the last comment from mobicham? In any case, a small benchmark run will be needed before merging the PR.

**Contributor Author:**

I verified the fix with GPT-2 using `torch.compile(mode='reduce-overhead', fullgraph=True)`. For some reason it keeps failing with Whisper models, and I can't really figure out why.

**Member:**

Can you add the bench script to the PR description, please?

**Contributor Author:**

Yep, I'll add it.
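For reference, a micro-benchmark along the lines the reviewers are asking for could time the sliced cache update across batch sizes. A hypothetical sketch (not the PR's actual bench script; shapes and iteration count are illustrative):

```python
import time
import torch

# Hypothetical micro-benchmark sketch: times the in-place update of a sliced
# static cache for batches smaller than max_batch_size.
max_batch_size, num_heads, max_seq_len, head_dim = 8, 6, 32, 64
keys = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)
values = torch.zeros_like(keys)
cache_position = torch.tensor([0])

timings = {}
for batch_size in (8, 4, 2, 1):
    key_states = torch.randn(batch_size, num_heads, 1, head_dim)
    value_states = torch.randn(batch_size, num_heads, 1, head_dim)
    start = time.perf_counter()
    for _ in range(100):
        # Slice to the live batch, then write the new states in place.
        k_out = keys[:batch_size]
        v_out = values[:batch_size]
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
    timings[batch_size] = (time.perf_counter() - start) / 100
```

A real run for the PR would additionally wrap the decode step in `torch.compile` and compare against the pre-PR baseline.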

```diff
         # Update the cache
         try:
-            self.keys.index_copy_(2, cache_position, key_states)
-            self.values.index_copy_(2, cache_position, value_states)
+            k_out.index_copy_(2, cache_position, key_states)
+            v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
             # Fallback for devices like MPS where index_copy_ might not be supported.
-            self.keys[:, :, cache_position] = key_states
-            self.values[:, :, cache_position] = value_states
-        return self.keys, self.values
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
+        return k_out, v_out
```

```python
    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
```
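Putting the hunks together, the new update path can be sketched as a standalone function (a hypothetical helper mirroring the diff, not the actual transformers API):

```python
import torch

# Sketch of the updated cache logic: slice the preallocated cache down to the
# live batch, then write the new key/value states in place through the views.
def update_cache(keys, values, key_states, value_states, cache_position):
    k_out, v_out = keys, values
    batch_size = key_states.shape[0]
    if k_out.shape[0] != batch_size:
        k_out = k_out[:batch_size]
        v_out = v_out[:batch_size]
    try:
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
    except NotImplementedError:
        # Fallback for devices like MPS where index_copy_ might not be supported.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
    return k_out, v_out

# Usage: a batch of 4 against a cache preallocated for 8.
keys = torch.zeros(8, 6, 32, 64)
values = torch.zeros(8, 6, 32, 64)
new_k = torch.ones(4, 6, 1, 64)
new_v = torch.full((4, 6, 1, 64), 2.0)
k_out, v_out = update_cache(keys, values, new_k, new_v, torch.tensor([5]))
```

Because the slice is a view, the writes land in the original preallocated tensors, which is what keeps the update copy-free for smaller batches.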