Generate: handle cache_position update in generate
#29467
Conversation
f5c91b9 to 572ca8e (force-push)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Alright, I think Llama is already testing this. Moving fast here
src/transformers/cache_utils.py
Outdated
alright, we are deprecating this anyways
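For context (not from this PR's diff): a common way to run such a deprecation cycle is to keep the old attribute as a property that warns and still returns the stored value. A minimal sketch, where `CacheStub` and the warning text are illustrative assumptions:

```python
import warnings


class CacheStub:
    """Illustrative stand-in for a cache class deprecating `seen_tokens`."""

    def __init__(self):
        self._seen_tokens = 0

    @property
    def seen_tokens(self):
        # The old attribute keeps working during the deprecation cycle, but warns.
        warnings.warn(
            "`seen_tokens` is deprecated; rely on the `cache_position` passed "
            "to the model instead.",
            FutureWarning,
        )
        return self._seen_tokens
```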
src/transformers/generation/utils.py
Outdated
my single worry here is a potential stride issue; adding a .contiguous() might be needed
I've double-checked, it's always (1,) 🤗 (which makes sense, since it's a 1D tensor)
Its shape will indeed be different, at least between prefill and subsequent generation
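A quick illustration of the point above (assumed values, not code from the PR): a 1D `cache_position` tensor built with `torch.arange` has stride `(1,)` both at prefill and after being advanced during decoding, so no `.contiguous()` call is needed, even though its shape changes:

```python
import torch

# Prefill: one position per prompt token.
cache_position = torch.arange(8)
print(cache_position.shape, cache_position.stride())   # torch.Size([8]) (1,)

# Decoding: a single new position; still 1D, still stride (1,).
cache_position = cache_position[-1:] + 1
print(cache_position.shape, cache_position.stride())   # torch.Size([1]) (1,)
print(cache_position.is_contiguous())                   # True
```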
src/transformers/generation/utils.py
Outdated
We should also set the dtype of the cache positions to int32, wdyt?
Our integer inputs (input_ids, attention_mask, ...) are all int64, I think we should keep a consistent type :p
we have correct long typing here!
(see int64 comment above)
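To make the dtype discussion concrete, a small sketch (the token ids below are made up): integer inputs such as `input_ids` default to int64 (`torch.long`), so the cache positions are kept int64 as well rather than int32:

```python
import torch

input_ids = torch.tensor([[1, 15043, 3186]])  # hypothetical token ids
cache_position = torch.arange(input_ids.shape[1], dtype=torch.long)

# Both tensors share the same integer dtype.
assert input_ids.dtype == cache_position.dtype == torch.int64
```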
58660e2 to 10360b3 (force-push)
(rebasing and rerunning tests, just in case 🙃)
To resolve the error `TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'` introduced by huggingface/transformers#29467
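For downstream wrappers hitting that error, a hedged sketch of the kind of fix involved (the class below is a stub, not the actual LLaVA code): `forward()` simply has to accept the new `cache_position` keyword, either explicitly or via `**kwargs`:

```python
class LlavaLlamaForCausalLMStub:
    """Minimal stand-in showing the signature change; not the real model."""

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                cache_position=None, **kwargs):
        # cache_position would be forwarded to the underlying language model.
        return {"cache_position": cache_position}


model = LlavaLlamaForCausalLMStub()
model.forward(input_ids=None, cache_position=None)  # no TypeError anymore
```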
What does this PR do?
Updates `cache_position` in `generate`, and makes it the primary source for the input position in the models that support it, `llama` and `gemma` (as opposed to relying on `past_key_values.seen_tokens`). The PR also adds the following related changes:
- `StaticCache` now supports `get_seq_length()`. This was drawn from "Static Cache: no mandatory `cache_positions` input" (#29221), and is needed for `.prepare_inputs_for_generation()` retrocompatibility;
- the `seen_tokens` attribute enters a deprecation cycle, as it is redundant with `cache_positions` (and doesn't work with compilation).

This PR is drawn from the diff in #29374, i.e. it is a requirement for `generate` compilation with `fullgraph=True` 🙌
👉 Llama, Gemma, and Cache slow tests ran, no new failures
👉 FWD compilation benchmarks ran, no throughput change
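As a rough illustration of the behaviour described above (a simplified sketch, not the exact code in `generation/utils.py`): `cache_position` covers the prompt positions at prefill and is advanced by one each decoding step, instead of being derived from `past_key_values.seen_tokens`:

```python
import torch


def init_cache_position(input_ids: torch.Tensor) -> torch.Tensor:
    # Prefill: one position per prompt token, int64 like the other integer inputs.
    return torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device)


def advance_cache_position(cache_position: torch.Tensor) -> torch.Tensor:
    # Decoding: a single new position, one past the last position seen.
    return cache_position[-1:] + 1


prompt = torch.tensor([[1, 2, 3, 4]])
cache_position = init_cache_position(prompt)               # tensor([0, 1, 2, 3])
cache_position = advance_cache_position(cache_position)    # tensor([4])
```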