Gemma3 text Keras to HF checkpoint conversion #2433
base: master
@@ -0,0 +1,177 @@
import keras.ops as ops


def get_gemma3_config(backbone):
    """Convert Keras Gemma3 config to Hugging Face config dictionary."""
    token_embedding_layer = backbone.get_layer("token_embedding")
    hf_config = {
        "architectures": ["Gemma3ForCausalLM"],
        "model_type": "gemma3_text",
        "vocab_size": backbone.vocabulary_size,
        "num_hidden_layers": backbone.num_layers,
        "num_attention_heads": backbone.num_query_heads,
        "num_key_value_heads": backbone.num_key_value_heads,
        "hidden_size": backbone.hidden_dim,
        "intermediate_size": backbone.intermediate_dim,
        "head_dim": backbone.head_dim,
        "max_position_embeddings": 32768,
        "tie_word_embeddings": token_embedding_layer.tie_weights,
        "rms_norm_eps": 1e-6,
        "rope_theta": 10000.0,
        "attention_bias": False,
        "attention_dropout": 0.0,
        "hidden_activation": "gelu_pytorch_tanh",
    }
    return hf_config
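
# --- Usage sketch (illustration only, not part of this PR) ------------------
# The dict returned above can be written out as `config.json`, the filename
# the Transformers `AutoConfig` / `AutoModel` loaders look for.
# `save_gemma3_config` and `export_dir` are hypothetical names used here for
# illustration.
import json
import os


def save_gemma3_config(backbone, export_dir):
    """Write the converted config to `<export_dir>/config.json`."""
    os.makedirs(export_dir, exist_ok=True)
    with open(os.path.join(export_dir, "config.json"), "w") as f:
        json.dump(get_gemma3_config(backbone), f, indent=2)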

def get_gemma3_weights_map(backbone, include_lm_head=False):
    """Convert Keras Gemma3 weights to a Hugging Face weights mapping.

    Args:
        backbone: The Keras Gemma3 backbone whose weights are exported.
        include_lm_head: If True, export for CausalLM (keys get a "model."
            prefix, and `lm_head.weight` is added when the embeddings are
            untied). If False, export the backbone only (no prefix).
    """
    weights_dict = {}

    # For CausalLM export, use the "model." prefix;
    # for backbone-only export, use no prefix.
    prefix = "model." if include_lm_head else ""

    # Token embeddings - use .weights[0] to get the backend tensor.
    token_embedding_layer = backbone.get_layer("token_embedding")
    token_embedding = token_embedding_layer.weights[0]
    weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding

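    # Shape note (added for clarity, inferred from the permutes below): the
    # Keras EinsumDense kernels appear to be laid out as
    # (num_heads, hidden_dim, head_dim) for the q/k/v projections and
    # (num_heads, head_dim, hidden_dim) for the output projection, while the
    # Hugging Face Linear weights are (out_features, in_features), e.g.
    # q_proj.weight is (num_query_heads * head_dim, hidden_dim). The
    # transpose/reshape pairs below convert between the two layouts.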
    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")

        # Attention query projection
        q_kernel = block.attention.query_dense.weights[0]
        q_kernel = ops.transpose(q_kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        q_kernel = ops.reshape(q_kernel, (backbone.hidden_dim, -1))
        q_kernel = ops.transpose(q_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel

        # Attention key projection
        k_kernel = block.attention.key_dense.weights[0]
        k_kernel = ops.transpose(k_kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        k_kernel = ops.reshape(k_kernel, (backbone.hidden_dim, -1))
        k_kernel = ops.transpose(k_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel

        # Attention value projection
        v_kernel = block.attention.value_dense.weights[0]
        v_kernel = ops.transpose(v_kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        v_kernel = ops.reshape(v_kernel, (backbone.hidden_dim, -1))
        v_kernel = ops.transpose(v_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel

        # Attention output projection
        o_kernel = block.attention.output_dense.weights[0]
        o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1))  # permute(2, 0, 1)
        o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
        weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel

        # Query and key normalization
        q_norm = block.attention.query_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm

        k_norm = block.attention.key_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm

        # MLP gate projection
        gate_kernel = block.gating_ffw.weights[0]
        gate_kernel = ops.transpose(gate_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = block.gating_ffw_2.weights[0]
        up_kernel = ops.transpose(up_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = block.ffw_linear.weights[0]
        down_kernel = ops.transpose(down_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel

        # Pre-attention normalization
        input_layer_norm = block.pre_attention_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
            input_layer_norm
        )

        # Post-attention normalization
        if hasattr(block, "post_attention_norm"):
            post_attn_norm = block.post_attention_norm.weights[0]
        else:
            # Fallback to pre_ffw_norm if post_attention_norm doesn't exist
            post_attn_norm = block.pre_ffw_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.post_attention_layernorm.weight"] = (
            post_attn_norm
        )

        # Pre-feedforward normalization
        pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
            pre_feedforward_layernorm
        )

        # Post-feedforward normalization
        if hasattr(block, "post_ffw_norm"):
            post_feedforward_layernorm = block.post_ffw_norm.weights[0]
        else:
            # Fallback to pre_ffw_norm if post_ffw_norm doesn't exist
            post_feedforward_layernorm = block.pre_ffw_norm.weights[0]
        weights_dict[
            f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
        ] = post_feedforward_layernorm

    # Final normalization
    final_norm = backbone.get_layer("final_normalization").weights[0]
    weights_dict[f"{prefix}norm.weight"] = final_norm

    if include_lm_head and not token_embedding_layer.tie_weights:
        weights_dict["lm_head.weight"] = ops.transpose(
            token_embedding_layer.reverse_embeddings
        )

    return weights_dict
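
# --- Usage sketch (illustration only, not part of this PR) ------------------
# The mapping returned above can be materialized as `model.safetensors`, the
# default weights file that `from_pretrained` loads. `save_gemma3_weights` and
# `export_dir` are hypothetical names; converting every tensor to NumPy
# assumes the weights fit in host memory.
import os

import numpy as np
from safetensors.numpy import save_file


def save_gemma3_weights(backbone, export_dir, include_lm_head=False):
    """Write the converted weights to `<export_dir>/model.safetensors`."""
    os.makedirs(export_dir, exist_ok=True)
    weights = get_gemma3_weights_map(
        backbone, include_lm_head=include_lm_head
    )
    tensors = {
        key: np.ascontiguousarray(ops.convert_to_numpy(value))
        for key, value in weights.items()
    }
    save_file(tensors, os.path.join(export_dir, "model.safetensors"))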

def get_gemma3_tokenizer_config(tokenizer):
    """Convert a Keras Gemma3 tokenizer to a Hugging Face tokenizer config."""
    tokenizer_config = {
        "tokenizer_class": "GemmaTokenizer",
        "clean_up_tokenization_spaces": False,
        "bos_token": "<bos>",
        "eos_token": "<eos>",
        "pad_token": "<pad>",
        "unk_token": "<unk>",
        "add_bos_token": True,
        "add_eos_token": False,
        "model_max_length": 32768,
    }
    # Add added_tokens_decoder
    added_tokens_decoder = {}
    special_tokens = [
        "<pad>",
        "<bos>",
        "<eos>",
        "<unk>",
        "<start_of_image>",
        "<end_of_image>",
        "<img>",
    ]
    for token in special_tokens:
        token_id = tokenizer.token_to_id(token)
        if token_id is not None:
            added_tokens_decoder[str(token_id)] = {
                "content": token,
                "special": True,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
            }
    tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
    return tokenizer_config
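
# Note (added for clarity): together, the three helpers above cover the main
# files a Transformers checkpoint directory typically contains:
#   config.json           <- get_gemma3_config(backbone)
#   model.safetensors     <- get_gemma3_weights_map(backbone, ...)
#   tokenizer_config.json <- get_gemma3_tokenizer_config(tokenizer)
#   tokenizer.model       <- the SentencePiece proto used by the Keras
#                            tokenizer (not produced by these helpers)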
@@ -0,0 +1,161 @@
import os

import numpy as np
from transformers import AutoModel
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM
from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
    Gemma3CausalLMPreprocessor,
)
from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
from keras_hub.src.tests.test_case import TestCase


class TestGemma3Export(TestCase):
    def test_export_to_hf(self):
        proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm")
        tokenizer = Gemma3Tokenizer(proto=proto)

        # Create a small backbone (text-only, no vision encoder)
        backbone = Gemma3Backbone(
            vocabulary_size=tokenizer.vocabulary_size(),
            image_size=896,  # Default value even for text-only
            num_layers=2,
            num_query_heads=4,
            num_key_value_heads=1,
            hidden_dim=512,
            intermediate_dim=1028,
            head_dim=128,
            query_head_dim_normalize=True,
            use_query_key_norm=True,
            use_post_ffw_norm=True,  # Real Gemma3 models have these
            use_post_attention_norm=True,  # Real Gemma3 models have these
            attention_logit_soft_cap=None,
            final_logit_soft_cap=None,
            use_sliding_window_attention=False,
            sliding_window_size=4096,
            vision_encoder=None,  # Text-only model for testing
            layer_norm_epsilon=1e-6,
            dropout=0,
        )

        # Create preprocessor
        preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer)

        # Create the causal LM model
        keras_model = Gemma3CausalLM(
            backbone=backbone, preprocessor=preprocessor
        )

        # Set all weights to random values
        rng = np.random.default_rng(42)
        weights = keras_model.get_weights()
        for i in range(len(weights)):
            weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
        keras_model.set_weights(weights)

        # Export to Hugging Face format using the new methods
        export_path_backbone = os.path.join(
            self.get_temp_dir(), "export_backbone"
        )
        backbone.export_to_transformers(export_path_backbone)

        export_path_tokenizer = os.path.join(
            self.get_temp_dir(), "export_tokenizer"
        )
        preprocessor.tokenizer.export_to_transformers(export_path_tokenizer)

        export_path_task = os.path.join(self.get_temp_dir(), "export_task")
        keras_model.export_to_transformers(export_path_task)

        # Load Hugging Face models and tokenizer
        # Note: We only test the slow tokenizer because the test vocab file
        # may not be compatible with fast tokenizer conversion
        hf_backbone = AutoModel.from_pretrained(export_path_backbone)
        hf_tokenizer_slow = AutoTokenizer.from_pretrained(
            export_path_tokenizer, use_fast=False
        )
        hf_full_model = AutoModelForCausalLM.from_pretrained(export_path_task)

        # Verify configuration
        hf_config = hf_backbone.config
        self.assertEqual(
            hf_config.vocab_size,
            backbone.vocabulary_size,
            "Vocabulary sizes do not match",
        )
        self.assertEqual(
            hf_config.num_hidden_layers,
            backbone.num_layers,
            "Number of layers does not match",
        )
        self.assertEqual(
            hf_config.num_attention_heads,
            backbone.num_query_heads,
            "Number of query heads does not match",
        )
        self.assertEqual(
            hf_config.num_key_value_heads,
            backbone.num_key_value_heads,
            "Number of key/value heads does not match",
        )
        self.assertEqual(
            hf_config.hidden_size,
            backbone.hidden_dim,
            "Hidden dimensions do not match",
        )
        self.assertEqual(
            hf_config.intermediate_size,
            backbone.intermediate_dim,
            "Intermediate sizes do not match",
        )
        self.assertEqual(
            hf_config.head_dim,
            backbone.head_dim,
            "Head dimensions do not match",
        )
        self.assertEqual(
            hf_config.max_position_embeddings,
            32768,
            "Max position embeddings do not match",
        )
        self.assertEqual(
            hf_config.tie_word_embeddings,
            backbone.token_embedding.tie_weights,
            "Tie word embeddings do not match",
        )

        # Verify tokenizer compatibility (using slow tokenizer)
        self.assertEqual(
            hf_tokenizer_slow.vocab_size,
            tokenizer.vocabulary_size(),
            "Tokenizer vocabulary sizes do not match",
        )

        # Compare generated outputs using full model
        prompt = "the quick"

        # Generate with Keras model
        keras_output = keras_model.generate(prompt, max_length=20)

        # Generate with HuggingFace model using slow tokenizer
        input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt")
        output_ids_slow = hf_full_model.generate(
            input_ids_slow, max_length=20, do_sample=False
        )
        hf_slow_output = hf_tokenizer_slow.decode(
            output_ids_slow[0], skip_special_tokens=True
        )

        # Debug print to see the actual outputs
        print(f"Keras output: '{keras_output}'")
        print(f"HF slow output: '{hf_slow_output}'")

        self.assertEqual(
            keras_output,
            hf_slow_output,
            "Generated outputs do not match",
        )
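
Beyond the unit test above, a minimal end-to-end sketch of how these export
hooks could be used on a real checkpoint. The preset name and output directory
below are illustrative; the tokenizer is exported into the same directory so
that AutoTokenizer can find it, mirroring the separate export calls in the
test:

import keras_hub
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# Hypothetical preset name; substitute any available Gemma3 text preset.
keras_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
keras_lm.export_to_transformers("./gemma3_hf_export")
keras_lm.preprocessor.tokenizer.export_to_transformers("./gemma3_hf_export")

hf_tokenizer = AutoTokenizer.from_pretrained(
    "./gemma3_hf_export", use_fast=False
)
hf_model = AutoModelForCausalLM.from_pretrained("./gemma3_hf_export")

inputs = hf_tokenizer("the quick", return_tensors="pt")
output_ids = hf_model.generate(**inputs, max_length=20, do_sample=False)
print(hf_tokenizer.decode(output_ids[0], skip_special_tokens=True))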