41 changes: 40 additions & 1 deletion fastchat/model/model_adapter.py
@@ -23,6 +23,7 @@
LlamaTokenizer,
LlamaForCausalLM,
T5Tokenizer,
Gemma3ForCausalLM,
)

from fastchat.constants import CPU_ISA
@@ -36,6 +37,7 @@
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.model.model_xfastertransformer import generate_stream_xft
from fastchat.model.model_cllm import generate_stream_cllm
from fastchat.model.model_gemma3 import generate_stream_gemma3

from fastchat.model.monkey_patch_non_inplace import (
replace_llama_attn_with_non_inplace_operations,
@@ -253,7 +255,12 @@ def load_model(
kwargs = {"torch_dtype": torch.float16}
import transformers

version = tuple(int(v) for v in transformers.__version__.split("."))
try:
version = tuple(int(v) for v in transformers.__version__.split("."))
except ValueError:
# Some transformers versions use a non-numeric format (e.g. 4.50.0.dev0)
# that breaks this parser, so we fall back to a default version.
version = (4, 36, 0)
if version < (4, 35, 0):
# NOTE: Recent transformers library seems to fix the mps issue, also
# it has made some changes causing compatibility issues with our
@@ -414,6 +421,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
is_xft = "xft" in model_type
is_yuan = "yuan" in model_type
is_cllm = "consistency-llm" in model_path.lower()
is_gemma3 = "gemma-3" in model_path.lower()

if is_chatglm:
return generate_stream_chatglm
@@ -429,6 +437,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
return generate_stream_yuan2
elif is_cllm:
return generate_stream_cllm
elif is_gemma3:
return generate_stream_gemma3

elif peft_share_base_weights and is_peft:
# Return a curried stream function that loads the right adapter
@@ -453,6 +463,7 @@ def generate_stream_peft(
is_xft = "xft" in base_model_type
is_yuan = "yuan" in base_model_type
is_cllm = "consistency-llm" in model_path.lower()
is_gemma3 = "gemma-3" in model_path.lower()

generate_stream_function = generate_stream
if is_chatglm:
@@ -469,6 +480,8 @@
generate_stream_function = generate_stream_yuan2
elif is_cllm:
generate_stream_function = generate_stream_cllm
elif is_gemma3:
generate_stream_function = generate_stream_gemma3
for x in generate_stream_function(
model,
tokenizer,
@@ -818,6 +831,31 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class Gemma3Adapter(BaseModelAdapter):
"""The model adapter for google/gemma-3"""

def match(self, model_path: str):
return "gemma-3" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
device_map = from_pretrained_kwargs.get("device_map", None)
if device_map == "sequential":
device_map = "auto"
# print("From pretrained kwargs", from_pretrained_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)


Hi,

I have a small suggestion:

Suggested change
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision, pad_to_multiple_of=8)

See this similar issue in huggingface/transformers: huggingface/transformers#36815

Some prompts may trigger an error similar to the following:

ERROR | stderr | Exception in thread Thread-5 (<lambda>):
ERROR | stderr | Traceback (most recent call last):
ERROR | stderr |   File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
ERROR | stderr |     self.run()
ERROR | stderr |   File "/usr/lib/python3.10/threading.py", line 953, in run
ERROR | stderr |     self._target(*self._args, **self._kwargs)
ERROR | stderr |   File "/home/example/projects/FastChat/fastchat/model/model_gemma3.py", line 81, in <lambda>
ERROR | stderr |     target=lambda: model.generate(input_ids=input_ids, **generate_kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR | stderr |     return func(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2465, in generate
ERROR | stderr |     result = self._sample(
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 3434, in _sample
ERROR | stderr |     outputs = model_forward(**model_inputs, return_dict=True)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR | stderr |     return self._call_impl(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR | stderr |     return forward_call(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
ERROR | stderr |     output = func(self, *args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
ERROR | stderr |     return func(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 942, in forward
ERROR | stderr |     outputs: BaseModelOutputWithPast = self.model(
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR | stderr |     return self._call_impl(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR | stderr |     return forward_call(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
ERROR | stderr |     output = func(self, *args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 722, in forward
ERROR | stderr |     layer_outputs = decoder_layer(
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR | stderr |     return self._call_impl(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR | stderr |     return forward_call(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 420, in forward
ERROR | stderr |     hidden_states, self_attn_weights = self.self_attn(
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR | stderr |     return self._call_impl(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR | stderr |     return forward_call(*args, **kwargs)
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 342, in forward
ERROR | stderr |     attn_output, attn_weights = attention_interface(
ERROR | stderr |   File "/home/example/projects/fastchat-venv/lib/python3.10/site-packages/transformers/integrations/sdpa_attention.py", line 54, in sdpa_attention_forward
ERROR | stderr |     attn_output = torch.nn.functional.scaled_dot_product_attention(
ERROR | stderr | RuntimeError: p.attn_bias_ptr is not correctly aligned
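
As a minimal sketch only (not part of this PR, and assuming the tokenizer, model, and prompt from the surrounding code), the same idea applied at encode time would pad the prompt to a multiple of 8 and pass the resulting attention mask through to generate(); whether this sidesteps the alignment check is an assumption to verify against huggingface/transformers#36815:

# Hypothetical illustration of the reviewer's suggestion, applied at call time.
# padding=True plus pad_to_multiple_of=8 rounds the sequence length up to a multiple of 8.
inputs = tokenizer(prompt, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(model.device)
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # mask out the extra padding tokens
    max_new_tokens=256,
)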

Author


Hi,
Thanks for this. We actually ended up creating https://www.github.com/transformerlab/transformerlab-inference and use that instead, since FastChat hasn't been merging PRs and new development there has stopped.
The model has been added there and works without flash attention, which was causing your original issue.
Could you let me know whether the error also occurs for you with flash attention disabled?
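
A minimal sketch (not from this PR) of how the loader below could pin a non-SDPA attention backend through transformers' attn_implementation argument; whether eager attention avoids the alignment error in the traceback above is an assumption to verify:

# Hypothetical variant of Gemma3Adapter.load_model; attn_implementation is a
# standard from_pretrained argument, and "eager" bypasses the SDPA/flash-attention kernels.
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    revision=revision,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="eager",
)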

model = Gemma3ForCausalLM.from_pretrained(
model_path,
revision=revision,
torch_dtype=torch.bfloat16,
device_map=device_map,
)
return model, tokenizer

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("gemma")


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for Koala"""

@@ -2502,6 +2540,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:

# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(Gemma3Adapter)
register_model_adapter(PeftModelAdapter)
register_model_adapter(StableVicunaAdapter)
register_model_adapter(VicunaAdapter)
145 changes: 145 additions & 0 deletions fastchat/model/model_gemma3.py
@@ -0,0 +1,145 @@
from threading import Thread
import gc
import torch
from transformers import TextIteratorStreamer


def generate_stream_gemma3(
model,
tokenizer,
params,
device,
context_len,
stream_interval=2,
judge_sent_end=False,
):
"""Custom generate stream function for Gemma-3 models"""
# Get parameters from the request
prompt = params.get("prompt", "")
messages = params.get("messages", None)
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 256))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
model_name = params.get("model", None)

if tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(tokenizer.eos_token_id)

is_base_model = "pt" in model_name.lower() or "base" in model_name.lower()

if not is_base_model:
# Format input based on whether we have messages or a plain prompt
if messages:
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
else:
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
else:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

input_ids = inputs["input_ids"]
input_echo_len = input_ids.shape[1]

# Configure generation parameters
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": temperature > 0.0,
"temperature": temperature if temperature > 0.0 else 1.0,
}

if top_p < 1.0:
generate_kwargs["top_p"] = top_p
if top_k > 0:
generate_kwargs["top_k"] = top_k
if repetition_penalty > 1.0:
generate_kwargs["repetition_penalty"] = repetition_penalty

streamer = TextIteratorStreamer(
tokenizer, skip_prompt=not echo, skip_special_tokens=True
)
generate_kwargs["streamer"] = streamer

# Start generation in a separate thread
thread = Thread(
target=lambda: model.generate(input_ids=input_ids, **generate_kwargs)
)
thread.start()

# Track generation progress
generated_tokens = 0
output_text = ""

# Stream tokens
for new_text in streamer:
output_text += new_text
generated_tokens += 1

# Check for stop strings
should_stop = False
if stop_str:
if isinstance(stop_str, str):
if stop_str in output_text:
output_text = output_text[: output_text.find(stop_str)]
should_stop = True
elif isinstance(stop_str, list):
for stop in stop_str:
if stop in output_text:
output_text = output_text[: output_text.find(stop)]
should_stop = True
break

# Stream at intervals or when stopping
if generated_tokens % stream_interval == 0 or should_stop:
yield {
"text": output_text,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": generated_tokens,
"total_tokens": input_echo_len + generated_tokens,
},
"finish_reason": "stop" if should_stop else None,
}

if should_stop:
break

# Final output with finish reason
if thread.is_alive():
thread.join(
timeout=3600
) # Arbitrary value, but if it doesn't complete in this much time then something is wrong

yield {
"text": output_text,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": generated_tokens,
"total_tokens": input_echo_len + generated_tokens,
},
"finish_reason": "length",
}

# Clean up
gc.collect()
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
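
For illustration only (not part of the diff), a minimal sketch of driving generate_stream_gemma3, assuming a Gemma-3 model and tokenizer have already been loaded through Gemma3Adapter.load_model; the model name, prompt, and sampling values are placeholders:

# Hypothetical caller; model and tokenizer are assumed to come from Gemma3Adapter.load_model.
params = {
    "model": "google/gemma-3-4b-it",  # placeholder model name
    "prompt": "Explain beam search in one paragraph.",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "echo": False,  # stream only newly generated text
}
last_chunk = None
for chunk in generate_stream_gemma3(model, tokenizer, params, device="cuda", context_len=8192):
    last_chunk = chunk  # each chunk carries the full decoded text so far plus token usage
print(last_chunk["text"])
print(last_chunk["usage"], last_chunk["finish_reason"])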