diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 16cf5d2b6..69db01aa6 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -23,6 +23,7 @@
     LlamaTokenizer,
     LlamaForCausalLM,
     T5Tokenizer,
+    Gemma3ForCausalLM,
 )
 
 from fastchat.constants import CPU_ISA
@@ -36,6 +37,7 @@
 from fastchat.model.model_exllama import generate_stream_exllama
 from fastchat.model.model_xfastertransformer import generate_stream_xft
 from fastchat.model.model_cllm import generate_stream_cllm
+from fastchat.model.model_gemma3 import generate_stream_gemma3
 
 from fastchat.model.monkey_patch_non_inplace import (
     replace_llama_attn_with_non_inplace_operations,
@@ -253,7 +255,12 @@ def load_model(
         kwargs = {"torch_dtype": torch.float16}
         import transformers
 
-        version = tuple(int(v) for v in transformers.__version__.split("."))
+        try:
+            version = tuple(int(v) for v in transformers.__version__.split("."))
+        except ValueError:
+            # Some transformers versions use a different version format
+            # (e.g. 4.50.0.dev0) that breaks this parser, so fall back to a default.
+            version = (4, 36, 0)
         if version < (4, 35, 0):
             # NOTE: Recent transformers library seems to fix the mps issue, also
             # it has made some changes causing compatibility issues with our
@@ -414,6 +421,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
     is_xft = "xft" in model_type
     is_yuan = "yuan" in model_type
     is_cllm = "consistency-llm" in model_path.lower()
+    is_gemma3 = "gemma-3" in model_path.lower()
 
     if is_chatglm:
         return generate_stream_chatglm
@@ -429,6 +437,8 @@
         return generate_stream_yuan2
     elif is_cllm:
         return generate_stream_cllm
+    elif is_gemma3:
+        return generate_stream_gemma3
     elif peft_share_base_weights and is_peft:
         # Return a curried stream function that loads the right adapter
@@ -453,6 +463,7 @@ def generate_stream_peft(
             is_xft = "xft" in base_model_type
             is_yuan = "yuan" in base_model_type
             is_cllm = "consistency-llm" in model_path.lower()
+            is_gemma3 = "gemma-3" in model_path.lower()
 
             generate_stream_function = generate_stream
             if is_chatglm:
@@ -469,6 +480,8 @@
                 generate_stream_function = generate_stream_yuan2
             elif is_cllm:
                 generate_stream_function = generate_stream_cllm
+            elif is_gemma3:
+                generate_stream_function = generate_stream_gemma3
             for x in generate_stream_function(
                 model,
                 tokenizer,
@@ -818,6 +831,31 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
+class Gemma3Adapter(BaseModelAdapter):
+    """The model adapter for google/gemma-3"""
+
+    def match(self, model_path: str):
+        return "gemma-3" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        device_map = from_pretrained_kwargs.get("device_map", None)
+        if device_map == "sequential":
+            device_map = "auto"
+        # print("From pretrained kwargs", from_pretrained_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+        model = Gemma3ForCausalLM.from_pretrained(
+            model_path,
+            revision=revision,
+            torch_dtype=torch.bfloat16,
+            device_map=device_map,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("gemma")
+
+
 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for Koala"""
 
@@ -2502,6 +2540,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 
 # Note: the registration order matters.
 # The one registered earlier has a higher matching priority.
+register_model_adapter(Gemma3Adapter)
 register_model_adapter(PeftModelAdapter)
 register_model_adapter(StableVicunaAdapter)
 register_model_adapter(VicunaAdapter)
diff --git a/fastchat/model/model_gemma3.py b/fastchat/model/model_gemma3.py
new file mode 100644
index 000000000..1b09f096e
--- /dev/null
+++ b/fastchat/model/model_gemma3.py
@@ -0,0 +1,145 @@
+from threading import Thread
+import gc
+import torch
+from transformers import TextIteratorStreamer
+
+
+def generate_stream_gemma3(
+    model,
+    tokenizer,
+    params,
+    device,
+    context_len,
+    stream_interval=2,
+    judge_sent_end=False,
+):
+    """Custom generate stream function for Gemma-3 models"""
+    # Get parameters from the request
+    prompt = params.get("prompt", "")
+    messages = params.get("messages", None)
+    temperature = float(params.get("temperature", 1.0))
+    repetition_penalty = float(params.get("repetition_penalty", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = int(params.get("top_k", -1))  # -1 means disable
+    max_new_tokens = int(params.get("max_new_tokens", 256))
+    echo = bool(params.get("echo", True))
+    stop_str = params.get("stop", None)
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    model_name = params.get("model") or ""
+
+    if tokenizer.eos_token_id not in stop_token_ids:
+        stop_token_ids.append(tokenizer.eos_token_id)
+
+    is_base_model = "pt" in model_name.lower() or "base" in model_name.lower()
+
+    if not is_base_model:
+        # Format input based on whether we have messages or a plain prompt
+        if messages:
+            inputs = tokenizer.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            ).to(model.device)
+        else:
+            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+            inputs = tokenizer.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            ).to(model.device)
+    else:
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    input_ids = inputs["input_ids"]
+    input_echo_len = input_ids.shape[1]
+
+    # Configure generation parameters
+    generate_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "do_sample": temperature > 0.0,
+        "temperature": temperature if temperature > 0.0 else 1.0,
+    }
+
+    if top_p < 1.0:
+        generate_kwargs["top_p"] = top_p
+    if top_k > 0:
+        generate_kwargs["top_k"] = top_k
+    if repetition_penalty > 1.0:
+        generate_kwargs["repetition_penalty"] = repetition_penalty
+
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=not echo, skip_special_tokens=True
+    )
+    generate_kwargs["streamer"] = streamer
+
+    # Start generation in a separate thread
+    thread = Thread(
+        target=lambda: model.generate(input_ids=input_ids, **generate_kwargs)
+    )
+    thread.start()
+
+    # Track generation progress (chunk count approximates the token count)
+    generated_tokens = 0
+    output_text = ""
+    should_stop = False  # in case the streamer yields nothing
+    # Stream tokens
+    for new_text in streamer:
+        output_text += new_text
+        generated_tokens += 1
+
+        # Check for stop strings
+        should_stop = False
+        if stop_str:
+            if isinstance(stop_str, str):
+                if stop_str in output_text:
+                    output_text = output_text[: output_text.find(stop_str)]
+                    should_stop = True
+            elif isinstance(stop_str, list):
+                for stop in stop_str:
+                    if stop in output_text:
+                        output_text = output_text[: output_text.find(stop)]
+                        should_stop = True
+                        break
+
+        # Stream at intervals or when stopping
+        if generated_tokens % stream_interval == 0 or should_stop:
+            yield {
+                "text": output_text,
+                "usage": {
+                    "prompt_tokens": input_echo_len,
+                    "completion_tokens": generated_tokens,
+                    "total_tokens": input_echo_len + generated_tokens,
+                },
+                "finish_reason": "stop" if should_stop else None,
+            }
+
+        if should_stop:
+            break
+
+    # Final output with finish reason
+    if thread.is_alive():
+        thread.join(
+            timeout=3600
+        )  # Arbitrary value, but if it doesn't complete in this much time then something is wrong
+
+    yield {
+        "text": output_text,
+        "usage": {
+            "prompt_tokens": input_echo_len,
+            "completion_tokens": generated_tokens,
+            "total_tokens": input_echo_len + generated_tokens,
+        },
+        "finish_reason": "stop" if should_stop else "length",
+    }
+
+    # Clean up
+    gc.collect()
+    torch.cuda.empty_cache()
+    if device == "xpu":
+        torch.xpu.empty_cache()
+    if device == "npu":
+        torch.npu.empty_cache()