
Conversation

@gante
Contributor

@gante gante commented Apr 9, 2025

What does this PR do?

⚠️ this PR needs to be rebased, don't review/merge

Supersedes #37389
Partially solves #35444

This PR finally makes the `max_batch_size` argument of our compilable caches a true maximum: we can now run inputs with a batch size smaller than the one defined in the cache. Compile once and run with multiple input shapes -- particularly useful for export, as mentioned in #35444.

Adds other minor related fixes (see PR comments).


We can see in the following test script that this does not degrade compiled performance:

from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
import time

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", device_map="auto", torch_dtype=torch.float16)

input_ids = tokenizer(["The quick brown"], return_tensors="pt").input_ids.to(model.device)
cache_position = torch.arange(input_ids.shape[1]).to(model.device)

with torch.no_grad():
    #------------------------------------------------------------------------------------------------
    # OLD, cache batch size = input batch size
    # Measured on an RTX 4090: `main` = 0.223ms; this PR = 0.223ms
    cache = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=100,
        device=model.device,
        dtype=model.dtype
    )
    model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

    # warmup
    for _ in range(3):
        outputs = model(input_ids, cache_position=cache_position, past_key_values=cache)

    # measure
    start = time.time()
    for _ in range(100):
        outputs = model(input_ids, cache_position=cache_position, past_key_values=cache)
    end = time.time()
    print(f"[Old] Average time taken: {((end - start) / 100) * 1000} ms")

    #------------------------------------------------------------------------------------------------
    # clear torch compile cache
    torch._dynamo.reset()

    #------------------------------------------------------------------------------------------------
    # NEW, cache batch size > input batch size
    # Measured on an RTX 4090: `main` = Doesn't work; this PR = 0.224ms
    cache = StaticCache(
        config=model.config,
        max_batch_size=16,  # 16 >> 1
        max_cache_len=100,
        device=model.device,
        dtype=model.dtype
    )
    model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

    # warmup
    for _ in range(3):
        outputs = model(input_ids, cache_position=cache_position, past_key_values=cache)

    # measure
    start = time.time()
    for _ in range(100):
        outputs = model(input_ids, cache_position=cache_position, past_key_values=cache)
    end = time.time()
    print(f"[New] Average time taken: {((end - start) / 100) * 1000} ms")

@github-actions
Contributor

github-actions bot commented Apr 9, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@github-actions github-actions bot marked this pull request as draft April 9, 2025 16:35
@gante gante marked this pull request as ready for review April 9, 2025 16:35
Cache for mamba model which does not have attention mechanism and key value states.
Cache for mamba model which does not have attention mechanism and key value states. At initialization, the cache
is preallocated to its maximum possible shape. Unlike other caches, `max_batch_size` must match the
batch size used at inference time.
Contributor Author

@gante gante Apr 9, 2025


Note: adding the feature to MambaCache would require model-level changes, and it wouldn't work with the fast kernels.
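
For readers less familiar with the cache internals, a rough sketch of the general idea (not this PR's actual diff): a buffer pre-allocated for `max_batch_size` sequences can serve a smaller input batch by slicing along the batch dimension when writing and reading, whereas MambaCache's conv/ssm states are updated in place by the fast kernels mentioned above, so slicing would not be transparent. All shapes below are made up for illustration.

import torch

# Rough sketch only -- not the PR's implementation.
max_batch_size, num_heads, max_cache_len, head_dim = 16, 8, 100, 64
key_cache = torch.zeros(max_batch_size, num_heads, max_cache_len, head_dim)

def update(key_states: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    # key_states: [input_batch, num_heads, seq_len, head_dim], with input_batch <= max_batch_size
    input_batch = key_states.shape[0]
    key_cache[:input_batch, :, cache_position] = key_states  # write into a batch slice
    return key_cache[:input_batch]                           # read back only that slice

new_keys = torch.randn(1, num_heads, 3, head_dim)  # input batch of 1, smaller than 16
out = update(new_keys, cache_position=torch.arange(3))
print(out.shape)  # torch.Size([1, 8, 100, 64])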

self._cache.reset()
return self._cache

def _supports_default_dynamic_cache(self) -> bool:
Contributor Author

@gante gante Apr 9, 2025


The rest of the diff in this file is to make mamba + compile work again (MambaIntegrationTests::test_compile_mamba_cache was red).

In general, models with unique caches are messy to use with generate and need some work. A model should be able to tell generate "hey, I can only use this cache class".
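
Purely as an illustration of that last point (nothing like this exists in the PR or in transformers today; the attribute name `_required_cache_class` and the toy classes are invented for the example): a model could advertise the only cache class it supports, and generation code could consult that instead of special-casing model types.

# Hypothetical sketch only -- not real transformers API.
class ToyDynamicCache:
    pass

class ToyMambaCache:
    pass

class ToyMambaModel:
    # the model tells generation code "I can only use this cache class"
    _required_cache_class = ToyMambaCache

def pick_cache_class(model):
    return getattr(model, "_required_cache_class", ToyDynamicCache)

print(pick_cache_class(ToyMambaModel()).__name__)  # ToyMambaCache
print(pick_cache_class(object()).__name__)         # ToyDynamicCache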



@require_torch_accelerator
@slow
Contributor Author


Not all tests here are @slow anymore; only the tests that take >1s kept the decorator.

("sdpa", "static"),
]
)
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
Contributor Author


Generation is not meant to work well with right-padding, so there is no need to spend resources testing it.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@zucchini-nlp zucchini-nlp left a comment


Great, thanks for adding support for this, and happy to see it can be done with minimal changes!

"layer_device_map": layer_device_map,
}
cache_signature = inspect.signature(cache_cls.__init__)
cache_kwargs = {k: v for k, v in all_possible_cache_kwargs.items() if k in cache_signature.parameters}
Member


We didn't change the signature -- why is this needed?

Contributor Author


MambaCache + mamba was broken ☠️ This is needed to fix it
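
To make the motivation concrete, here is a small self-contained illustration of the filtering pattern in the snippet above (the two toy cache classes are invented for the example; the point is that a cache class with a narrower __init__ signature no longer receives kwargs it doesn't accept):

import inspect

class WideCache:
    def __init__(self, max_batch_size, max_cache_len, layer_device_map=None):
        pass

class NarrowCache:
    # e.g. a cache whose __init__ does not accept layer_device_map
    def __init__(self, max_batch_size, max_cache_len):
        pass

all_possible_cache_kwargs = {"max_batch_size": 1, "max_cache_len": 100, "layer_device_map": None}

for cache_cls in (WideCache, NarrowCache):
    cache_signature = inspect.signature(cache_cls.__init__)
    cache_kwargs = {k: v for k, v in all_possible_cache_kwargs.items() if k in cache_signature.parameters}
    cache_cls(**cache_kwargs)  # without the filtering, NarrowCache(**all_possible_cache_kwargs) raises TypeError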

@gante gante force-pushed the cache_support_smaller_bs branch from cc635a6 to 883ac39 on April 17, 2025 13:15
@gante
Contributor Author

gante commented Apr 22, 2025

(PR on hold: some slow cache tests are failing for reasons unrelated to this PR; fixing them first before re-requesting a review)

@gante
Contributor Author

gante commented Aug 12, 2025

(caches have been refactored, better start from scratch)
