[Cache] Support compilable cache reuse with smaller batch sizes #37394
Conversation
```diff
-Cache for mamba model which does not have attention mechanism and key value states.
+Cache for mamba model which does not have attention mechanism and key value states. At initialization, the cache
+is preallocated to its maximum possible shape. Contrarily to other caches, `max_batch_size` must match the
+batch size used at inference time.
```
Note: adding the feature to MambaCache would imply model-level changes, and it wouldn't work with the fast kernels.
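As a hedged illustration of the constraint documented in the diff above (all class and method names below are toy assumptions, not the real transformers API): a Mamba-style cache preallocates its state buffers for a fixed batch size, so a mismatched batch size at inference time has to be rejected rather than sliced.

```python
# Toy sketch: a MambaCache-like cache whose state is preallocated to its
# maximum possible shape, and which (unlike other caches) requires the
# inference batch size to match max_batch_size exactly.
class ToyMambaCache:
    def __init__(self, max_batch_size: int, state_size: int = 4):
        self.max_batch_size = max_batch_size
        # Preallocated at initialization to its maximum possible shape.
        self.conv_states = [[0.0] * state_size for _ in range(max_batch_size)]

    def update(self, batch_size: int):
        # Contrarily to other caches, the batch size must match exactly.
        if batch_size != self.max_batch_size:
            raise ValueError(
                f"batch_size must equal max_batch_size "
                f"({batch_size} != {self.max_batch_size})"
            )
        return self.conv_states


cache = ToyMambaCache(max_batch_size=2)
states = cache.update(2)  # OK: matches the preallocated batch size
try:
    cache.update(1)       # a smaller batch is not supported here
except ValueError as exc:
    print("rejected:", exc)
```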
```python
        self._cache.reset()
        return self._cache

    def _supports_default_dynamic_cache(self) -> bool:
```
The rest of the diff in this file is to make mamba + compile work again (`MambaIntegrationTests::test_compile_mamba_cache` was red).
In general, models with unique caches are messy to use with `generate` and need some work. A model should be able to tell `generate`: "hey, I can only use this cache class".
```python
@require_torch_accelerator
@slow
```
Not all tests here are `@slow`; only tests that take >1s kept the decorator.
```python
            ("sdpa", "static"),
        ]
    )
    def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
```
Generation is not meant to work well with right-padding; no need to spend resources testing it.
zucchini-nlp left a comment:
Great, thanks for adding support on this and happy to see it can be done with minimal changes!
```python
        "layer_device_map": layer_device_map,
    }
    cache_signature = inspect.signature(cache_cls.__init__)
    cache_kwargs = {k: v for k, v in all_possible_cache_kwargs.items() if k in cache_signature.parameters}
```
we didn't change the signature, why is this needed?
MambaCache + mamba was broken ☠️ This is needed to fix it
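The kwarg-filtering pattern in the diff above can be sketched in isolation. A minimal, stdlib-only sketch follows; the two toy cache classes are assumptions standing in for cache classes with different `__init__` signatures (e.g. MambaCache accepts fewer arguments than StaticCache):

```python
import inspect

# Toy stand-ins for cache classes with different constructor signatures.
class StaticLikeCache:
    def __init__(self, max_batch_size, max_cache_len, layer_device_map=None):
        self.max_batch_size = max_batch_size
        self.max_cache_len = max_cache_len

class MambaLikeCache:
    def __init__(self, max_batch_size):  # no max_cache_len / layer_device_map
        self.max_batch_size = max_batch_size

def build_cache(cache_cls, **all_possible_cache_kwargs):
    # Build a superset of kwargs, then keep only the ones this cache
    # class's __init__ actually accepts (the pattern from the diff).
    cache_signature = inspect.signature(cache_cls.__init__)
    cache_kwargs = {
        k: v for k, v in all_possible_cache_kwargs.items()
        if k in cache_signature.parameters
    }
    return cache_cls(**cache_kwargs)

kwargs = {"max_batch_size": 2, "max_cache_len": 128, "layer_device_map": None}
static = build_cache(StaticLikeCache, **kwargs)  # receives all three kwargs
mamba = build_cache(MambaLikeCache, **kwargs)    # extras are silently dropped
```

This lets one call site pass the full set of possible kwargs without each cache class having to accept (and ignore) arguments it doesn't use.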
Force-pushed cc635a6 to 883ac39.
(PR on hold: some slow cache tests are failing due to reasons unrelated to this PR, fixing them first before re-requesting a review)

(caches have been refactored, better to start from scratch)
What does this PR do?
Supersedes #37389
Partially solves #35444
This PR makes our `max_batch_size` argument in compilable caches finally true: we can now use a cache object with a batch size smaller than the one defined in the cache. Compile once and run with multiple input shapes, which is particularly useful for export, as mentioned in #35444. It also adds other minor related fixes (see PR comments).
We can see in the following test script that this does not degrade compiled performance:
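The mechanism being benchmarked can be sketched without torch. The sketch below is a hedged, stdlib-only toy (names like `ReusableStaticCache` are illustrative, not the real transformers API): the cache keeps fixed-shape buffers sized for `max_batch_size`, and a smaller batch simply writes into a slice of those buffers, so a graph compiled against the static buffer shape can be reused across batch sizes.

```python
# Toy sketch of compile-once cache reuse: buffers are preallocated for the
# maximum batch size, and smaller batches touch only a leading slice.
class ReusableStaticCache:
    def __init__(self, max_batch_size: int, max_cache_len: int):
        self.max_batch_size = max_batch_size
        self.max_cache_len = max_cache_len
        # Fixed-shape buffer: [max_batch_size][max_cache_len]. The shape
        # never changes, which is what keeps a compiled graph reusable.
        self.keys = [[None] * max_cache_len for _ in range(max_batch_size)]

    def update(self, new_keys, position: int):
        batch_size = len(new_keys)
        if batch_size > self.max_batch_size:
            raise ValueError("batch larger than the preallocated cache")
        # Only the first `batch_size` rows are written.
        for b in range(batch_size):
            self.keys[b][position] = new_keys[b]
        # Return the populated slice up to the current position.
        return [row[: position + 1] for row in self.keys[:batch_size]]


cache = ReusableStaticCache(max_batch_size=8, max_cache_len=16)
out_full = cache.update(["k"] * 8, position=0)   # batch == max_batch_size
out_small = cache.update(["k"] * 2, position=1)  # smaller batch also works
```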