Static cache is locked after torch.compile with model.generate #30351

@mobicham

System Info

  • transformers version: 4.39.0.dev0
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.1
  • Safetensors version: 0.4.1
  • Accelerate version: 0.21.0

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When using torch.compile(model.forward) with the static cache, the cache seems to be locked to the first prompt that was used at compilation time. I re-implemented the generate logic and the same issue happens, so it's not just a bug in model.generate (a minimal sketch of that loop is included after the snippet below). This happens with both older and newer versions of transformers.

Here's a code snippet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval();
tokenizer = AutoTokenizer.from_pretrained(model_id) 
tokenizer.add_bos_token = False  # the prompt below already contains <s>

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
# Run a few times so the compiled graph is warmed up with this first prompt
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)
print(tokenizer.decode(gen_out[0]))

# Output: OK
#  <s>  [INST] Write an essay about large language models [/INST]   Large language models have revolutionized the field of natural language processing in recent years. 
# These models are trained on vast amounts of text data and are capable of generating text, classifying text, and answering questions with remarkable accuracy. 
# In this essay, we will explore the current state of large language models, their potential applications, and the challenges and limitations that come with their use.....

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)
print(tokenizer.decode(gen_out[0]))

# Output: WRONG: the model still talks about the previous prompt.
# <s>  [INST] How to make a chocolate cake? [/INST]  ge language models (LLMs) are a class of artificial intelligence (AI) models that have gained significant 
#attention in recent years due to their impressive language processing capabilities. Here, we will explore the concept of LLMs, their applications, 
# and their potential impact on various fields.
# What are Large Language Models?
# LLMs are neural network-based models that are trained on vast amounts of text data to generate language outputs that are coherent and natural
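
For reference, here is roughly the shape of the re-implemented greedy decoding loop mentioned above. It is only a minimal sketch: it uses the generic forward / past_key_values API, the static-cache wiring (which differs between transformers versions) is omitted, and the names are illustrative rather than an exact copy of my script.

import torch

@torch.no_grad()
def greedy_generate(model, tokenizer, prompt, max_new_tokens=100):
    # Tokenize the prompt and keep track of everything generated so far.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    generated = inputs["input_ids"]
    next_input = generated
    past_key_values = None
    for _ in range(max_new_tokens):
        # The (compiled) forward is called once per new token; the cache returned
        # by the previous call is fed back in.
        out = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        generated = torch.cat([generated, next_token], dim=-1)
        next_input = next_token
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(generated[0])

With torch.compile applied to model.forward, this loop shows the same symptom: after compiling on the first prompt, later prompts keep producing continuations of the first one.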

Expected behavior

The output should correspond to the input prompt, not the prompt the model was first compiled with.

Thank you!
