
Conversation

@gante (Member) commented Mar 5, 2024

What does this PR do?

Updates cache_position in generate and makes it the primary source of the input position in the models that support it, Llama and Gemma (as opposed to relying on past_key_values.seen_tokens).

The PR also adds the following related changes:

  1. StaticCache now supports get_seq_length(). This was drawn from Static Cache: no mandatory cache_positions input #29221, and is needed for .prepare_inputs_for_generation() backward compatibility;
  2. The seen_tokens attribute enters a deprecation cycle, as it is redundant with cache_position (and doesn't work with compilation).

This PR is drawn from the diff in #29374, i.e. it is a requirement for generate compilation with fullgraph=True 🙌

👉 Llama, Gemma, and Cache slow tests ran, no new failures
👉 FWD compilation benchmarks ran, no throughput change
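
A minimal sketch of the flow described above, with simplified names that are not the exact code in this diff: `cache_position` starts from whatever the cache already holds and is advanced by one position per generated token, instead of being read back from `past_key_values.seen_tokens`.

```python
import torch

# Hedged sketch, not the PR's exact implementation: `past_length` stands in for
# whatever the cache reports (e.g. via get_seq_length()); the real wiring lives
# in `generate` / `prepare_inputs_for_generation`.

def initial_cache_position(past_length: int, num_new_tokens: int) -> torch.Tensor:
    # Prefill: one position per prompt token being fed this step.
    return torch.arange(past_length, past_length + num_new_tokens, dtype=torch.long)

cache_position = initial_cache_position(past_length=0, num_new_tokens=5)
print(cache_position)          # tensor([0, 1, 2, 3, 4])

# Decoding: advance by one position per newly generated token.
for _ in range(3):
    cache_position = cache_position[-1:] + 1
    print(cache_position)      # tensor([5]), tensor([6]), tensor([7])
```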

@gante gante force-pushed the update_cache_position branch from f5c91b9 to 572ca8e on March 6, 2024, 15:01
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante marked this pull request as ready for review March 6, 2024 16:19
@gante gante requested a review from ArthurZucker March 6, 2024 16:19
@ArthurZucker (Collaborator) left a comment

Alright, I think Llama is already testing this. Moving fast here

Comment on lines +418 to +420

Collaborator:

Alright, we are deprecating this anyway.

Collaborator:

My single worry here is potential stride; adding a .contiguous() might be needed.

Member Author:

I've double-checked, the stride is always (1,) 🤗 (which makes sense, since it's a 1D tensor). Its shape will indeed differ, at least between prefill and subsequent generation.
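
For illustration, a standalone check (not code from the PR) of why a 1D `cache_position` built this way keeps stride `(1,)` in both phases, even though its shape changes:

```python
import torch

# Prefill: a fresh 1D arange is contiguous by construction.
prefill = torch.arange(0, 5, dtype=torch.long)
print(prefill.shape, prefill.stride(), prefill.is_contiguous())
# torch.Size([5]) (1,) True

# Decoding: slicing the last element and adding 1 materialises a new 1D tensor,
# so the stride stays (1,) while the shape shrinks to [1].
decode = prefill[-1:] + 1
print(decode.shape, decode.stride(), decode.is_contiguous())
# torch.Size([1]) (1,) True
```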

Collaborator:

We should also set the dtype of the cache positions to int32, wdyt?

Member Author:

Our integer inputs (input_ids, attention_mask, ...) are all int64; I think we should keep a consistent type :p
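
As an illustration of the consistency argument (assumed example inputs, not the PR's code), int64 is also the default integer dtype PyTorch gives those inputs and the dtype index operations expect:

```python
import torch

input_ids = torch.tensor([[1, 2, 3, 4]])       # int64 by default
attention_mask = torch.ones_like(input_ids)    # inherits int64

# Matching dtype means `cache_position` can be used directly for indexing
# without casts.
cache_position = torch.arange(input_ids.shape[1], dtype=torch.long)
print(input_ids.dtype, attention_mask.dtype, cache_position.dtype)
# torch.int64 torch.int64 torch.int64
```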

Collaborator:

we have correct long typing here!

Member Author:

(see int64 comment above)

@gante gante force-pushed the update_cache_position branch from 58660e2 to 10360b3 on March 14, 2024, 16:09
@gante (Member Author) commented Mar 14, 2024

(rebasing and rerunning tests, just in case 🙃)

@gante gante merged commit 23db187 into huggingface:main Mar 14, 2024
@gante gante deleted the update_cache_position branch March 14, 2024 16:35
itsdotscience added a commit to itsdotscience/LLaVA that referenced this pull request Mar 22, 2024
To resolve the error `TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'` introduced by huggingface/transformers#29467.
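
The general shape of such a downstream fix is to accept the new keyword in the wrapper's forward signature; a hypothetical sketch follows (the actual LLaVA commit may differ):

```python
# Hypothetical sketch, not the actual LLaVA commit: accept `cache_position` so
# the keyword passed by newer `generate` implementations no longer raises a
# TypeError, and either forward it to the wrapped language model or drop it.
class LlavaLlamaForCausalLMPatched:
    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                inputs_embeds=None, labels=None, cache_position=None, **kwargs):
        # Pass `cache_position` through only if the underlying model supports it.
        ...
```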