diff --git a/keras_hub/src/utils/transformers/export/gemma3.py b/keras_hub/src/utils/transformers/export/gemma3.py
new file mode 100644
index 0000000000..8120de7f19
--- /dev/null
+++ b/keras_hub/src/utils/transformers/export/gemma3.py
@@ -0,0 +1,181 @@
+import keras.ops as ops
+
+
+def get_gemma3_config(backbone):
+    """Convert a Keras Gemma3 config to a Hugging Face config dictionary."""
+    token_embedding_layer = backbone.get_layer("token_embedding")
+    hf_config = {
+        "architectures": ["Gemma3ForCausalLM"],
+        "model_type": "gemma3_text",
+        "vocab_size": backbone.vocabulary_size,
+        "num_hidden_layers": backbone.num_layers,
+        "num_attention_heads": backbone.num_query_heads,
+        "num_key_value_heads": backbone.num_key_value_heads,
+        "hidden_size": backbone.hidden_dim,
+        "intermediate_size": backbone.intermediate_dim,
+        "head_dim": backbone.head_dim,
+        "max_position_embeddings": 32768,
+        "tie_word_embeddings": token_embedding_layer.tie_weights,
+        "rms_norm_eps": 1e-6,
+        "rope_theta": 10000.0,
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "hidden_activation": "gelu_pytorch_tanh",
+    }
+    return hf_config
+
+
+def get_gemma3_weights_map(backbone, include_lm_head=False):
+    """Build a mapping of Hugging Face weight names to Keras Gemma3 weights.
+
+    Args:
+        include_lm_head: If True, exports for CausalLM (with "model."
+            prefix). If False, exports for the backbone only (no prefix).
+    """
+
+    def _convert_qkv_kernel(kernel, hidden_dim):
+        """Helper to convert Q/K/V projection kernels to HF format.
+
+        Args:
+            kernel: The kernel weight tensor to convert.
+            hidden_dim: The hidden dimension size for reshaping.
+
+        Returns:
+            Converted kernel in HF format.
+        """
+        kernel = ops.transpose(kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
+        kernel = ops.reshape(kernel, (hidden_dim, -1))
+        kernel = ops.transpose(kernel)  # .T
+        return kernel
+
+    weights_dict = {}
+
+    # For CausalLM export, use "model." prefix
+    # For backbone export, use no prefix
+    prefix = "model." if include_lm_head else ""
+
+    # Token embeddings - use .weights[0] to get backend tensor
+    token_embedding_layer = backbone.get_layer("token_embedding")
+    token_embedding = token_embedding_layer.weights[0]
+    weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding
+
+    for i in range(backbone.num_layers):
+        block = backbone.get_layer(f"decoder_block_{i}")
+
+        # Attention query projection
+        q_kernel = _convert_qkv_kernel(
+            block.attention.query_dense.weights[0], backbone.hidden_dim
+        )
+        weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel
+
+        # Attention key projection
+        k_kernel = _convert_qkv_kernel(
+            block.attention.key_dense.weights[0], backbone.hidden_dim
+        )
+        weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel
+
+        # Attention value projection
+        v_kernel = _convert_qkv_kernel(
+            block.attention.value_dense.weights[0], backbone.hidden_dim
+        )
+        weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel
+
+        # Attention output projection
+        o_kernel = block.attention.output_dense.weights[0]
+        o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1))  # permute(2, 0, 1)
+        o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
+        weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel
+
+        # Query and key normalization
+        q_norm = block.attention.query_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm
+
+        k_norm = block.attention.key_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm
+
+        # MLP gate projection
+        gate_kernel = block.gating_ffw.weights[0]
+        gate_kernel = ops.transpose(gate_kernel)  # .T
+        weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel
+
+        # MLP up projection
+        up_kernel = block.gating_ffw_2.weights[0]
+        up_kernel = ops.transpose(up_kernel)  # .T
+        weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel
+
+        # MLP down projection
+        down_kernel = block.ffw_linear.weights[0]
+        down_kernel = ops.transpose(down_kernel)  # .T
+        weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel
+
+        # Pre-attention normalization
+        input_layer_norm = block.pre_attention_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
+            input_layer_norm
+        )
+
+        # Post-attention normalization
+        if hasattr(block, "post_attention_norm"):
+            post_attn_norm = block.post_attention_norm.weights[0]
+            weights_dict[
+                f"{prefix}layers.{i}.post_attention_layernorm.weight"
+            ] = post_attn_norm
+
+        # Pre-feedforward normalization
+        pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
+            pre_feedforward_layernorm
+        )
+
+        # Post-feedforward normalization
+        if hasattr(block, "post_ffw_norm"):
+            post_feedforward_layernorm = block.post_ffw_norm.weights[0]
+            weights_dict[
+                f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
+            ] = post_feedforward_layernorm
+
+    # Final normalization
+    final_norm = backbone.get_layer("final_normalization").weights[0]
+    weights_dict[f"{prefix}norm.weight"] = final_norm
+
+    if include_lm_head and not token_embedding_layer.tie_weights:
+        weights_dict["lm_head.weight"] = ops.transpose(
+            token_embedding_layer.reverse_embeddings
+        )
+
+    return weights_dict
+
+
+def get_gemma3_tokenizer_config(tokenizer):
+    tokenizer_config = {
+        "tokenizer_class": "GemmaTokenizer",
+        "clean_up_tokenization_spaces": False,
+        "bos_token": "<bos>",
+        "eos_token": "<eos>",
+        "pad_token": "<pad>",
+        "unk_token": "<unk>",
+        "add_bos_token": True,
"add_eos_token": False, + "model_max_length": 32768, + } + # Add added_tokens_decoder + added_tokens_decoder = {} + special_tokens = [ + "", + "", + "", + "", + "", + "", + "", + ] + for token in special_tokens: + token_id = tokenizer.token_to_id(token) + if token_id is not None: + added_tokens_decoder[str(token_id)] = { + "content": token, + "special": True, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + } + tokenizer_config["added_tokens_decoder"] = added_tokens_decoder + return tokenizer_config diff --git a/keras_hub/src/utils/transformers/export/gemma3_test.py b/keras_hub/src/utils/transformers/export/gemma3_test.py new file mode 100644 index 0000000000..85ada4e69c --- /dev/null +++ b/keras_hub/src/utils/transformers/export/gemma3_test.py @@ -0,0 +1,164 @@ +import os + +import numpy as np +from transformers import AutoModel +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM +from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( + Gemma3CausalLMPreprocessor, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer +from keras_hub.src.tests.test_case import TestCase + + +class TestGemma3Export(TestCase): + def test_export_to_hf(self): + proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm") + tokenizer = Gemma3Tokenizer(proto=proto) + + # Create a small backbone (text-only, no vision encoder) + backbone = Gemma3Backbone( + vocabulary_size=tokenizer.vocabulary_size(), + image_size=896, + num_layers=2, + num_query_heads=2, + num_key_value_heads=1, + hidden_dim=128, + intermediate_dim=256, + head_dim=64, + query_head_dim_normalize=True, + use_query_key_norm=True, + use_post_ffw_norm=True, + use_post_attention_norm=True, + attention_logit_soft_cap=None, + final_logit_soft_cap=None, + use_sliding_window_attention=False, + sliding_window_size=4096, + vision_encoder=None, # TODO: enable for vision models + layer_norm_epsilon=1e-6, + dropout=0, + ) + + # Create preprocessor + preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer) + + # Create the causal LM model + keras_model = Gemma3CausalLM( + backbone=backbone, preprocessor=preprocessor + ) + + # Set all weights to random values + rng = np.random.default_rng(42) + weights = keras_model.get_weights() + for i in range(len(weights)): + weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype) + keras_model.set_weights(weights) + + # Export to Hugging Face format using the new methods + export_path_backbone = os.path.join( + self.get_temp_dir(), "export_backbone" + ) + backbone.export_to_transformers(export_path_backbone) + + export_path_tokenizer = os.path.join( + self.get_temp_dir(), "export_tokenizer" + ) + preprocessor.tokenizer.export_to_transformers(export_path_tokenizer) + + export_path_task = os.path.join(self.get_temp_dir(), "export_task") + keras_model.export_to_transformers(export_path_task) + + # Load Hugging Face models and tokenizer + hf_backbone = AutoModel.from_pretrained(export_path_backbone) + # Note: We only test the slow tokenizer because the test vocab file + # is not compatible with the fast tokenizer (Unigram vs BPE mismatch). 
+        # Using fast tokenizer raises: "You're trying to run a `Unigram` model
+        # but you're file was trained with a different algorithm"
+        hf_tokenizer_slow = AutoTokenizer.from_pretrained(
+            export_path_tokenizer, use_fast=False
+        )
+        hf_full_model = AutoModelForCausalLM.from_pretrained(export_path_task)
+
+        # Verify configuration
+        hf_config = hf_backbone.config
+        self.assertEqual(
+            hf_config.vocab_size,
+            backbone.vocabulary_size,
+            "Vocabulary sizes do not match",
+        )
+        self.assertEqual(
+            hf_config.num_hidden_layers,
+            backbone.num_layers,
+            "Number of layers does not match",
+        )
+        self.assertEqual(
+            hf_config.num_attention_heads,
+            backbone.num_query_heads,
+            "Number of query heads does not match",
+        )
+        self.assertEqual(
+            hf_config.num_key_value_heads,
+            backbone.num_key_value_heads,
+            "Number of key value heads does not match",
+        )
+        self.assertEqual(
+            hf_config.hidden_size,
+            backbone.hidden_dim,
+            "Hidden dimensions do not match",
+        )
+        self.assertEqual(
+            hf_config.intermediate_size,
+            backbone.intermediate_dim,
+            "Intermediate sizes do not match",
+        )
+        self.assertEqual(
+            hf_config.head_dim,
+            backbone.head_dim,
+            "Head dimensions do not match",
+        )
+        self.assertEqual(
+            hf_config.max_position_embeddings,
+            32768,
+            "Max position embeddings do not match",
+        )
+        self.assertEqual(
+            hf_config.tie_word_embeddings,
+            backbone.token_embedding.tie_weights,
+            "Tie word embeddings do not match",
+        )
+
+        # Verify tokenizer compatibility (using slow tokenizer)
+        self.assertEqual(
+            hf_tokenizer_slow.vocab_size,
+            tokenizer.vocabulary_size(),
+            "Tokenizer vocabulary sizes do not match",
+        )
+
+        # Compare generated outputs using the full model. Since the weights
+        # are seeded and the prompt is small, both models should produce the
+        # same text.
+        prompt = "the quick"
+
+        # Generate with the Keras model
+        keras_output = keras_model.generate(prompt, max_length=20)
+
+        # Generate with the HuggingFace model using the slow tokenizer
+        input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt")
+        output_ids_slow = hf_full_model.generate(
+            input_ids_slow, max_length=20, do_sample=False
+        )
+        hf_slow_output = hf_tokenizer_slow.decode(
+            output_ids_slow[0], skip_special_tokens=True
+        )
+
+        # Debug print to see the actual outputs
+        print(f"Keras output: '{keras_output}'")
+        print(f"HF slow output: '{hf_slow_output}'")
+
+        self.assertEqual(
+            keras_output,
+            hf_slow_output,
+            "Generated outputs do not match (slow)",
+        )
diff --git a/keras_hub/src/utils/transformers/export/hf_exporter.py b/keras_hub/src/utils/transformers/export/hf_exporter.py
index 1593987ca9..b3a55fb27b 100644
--- a/keras_hub/src/utils/transformers/export/hf_exporter.py
+++ b/keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -10,19 +10,29 @@
     get_gemma_tokenizer_config,
 )
 from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
+from keras_hub.src.utils.transformers.export.gemma3 import get_gemma3_config
+from keras_hub.src.utils.transformers.export.gemma3 import (
+    get_gemma3_tokenizer_config,
+)
+from keras_hub.src.utils.transformers.export.gemma3 import (
+    get_gemma3_weights_map,
+)
 
 MODEL_CONFIGS = {
     "GemmaBackbone": get_gemma_config,
+    "Gemma3Backbone": get_gemma3_config,
     # Add for future models, e.g., "MistralBackbone": get_mistral_config
 }
 
 MODEL_EXPORTERS = {
     "GemmaBackbone": get_gemma_weights_map,
+    "Gemma3Backbone": get_gemma3_weights_map,
     # Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
 }
 
 MODEL_TOKENIZER_CONFIGS = {
     "GemmaTokenizer": get_gemma_tokenizer_config,
+    "Gemma3Tokenizer": get_gemma3_tokenizer_config,
     # Add for future models, e.g., "MistralTokenizer":
     # get_mistral_tokenizer_config
 }
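
For reviewers who want to try the new export path end to end, here is a minimal usage sketch that mirrors what `gemma3_test.py` does, but against a real preset. The preset name and output directories are illustrative placeholders, not part of this change; only the `export_to_transformers` entry points exercised by the test above are assumed.

```python
# Hedged sketch of the Gemma3 -> Transformers export added in this PR.
# "gemma3_1b" and the output directories are placeholders.
import keras_hub
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load a Gemma3 causal LM in Keras.
keras_model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_1b")

# Export the task and its tokenizer, mirroring gemma3_test.py.
keras_model.export_to_transformers("./gemma3_hf_task")
keras_model.preprocessor.tokenizer.export_to_transformers("./gemma3_hf_tok")

# Reload with Transformers and generate greedily.
hf_model = AutoModelForCausalLM.from_pretrained("./gemma3_hf_task")
hf_tokenizer = AutoTokenizer.from_pretrained("./gemma3_hf_tok", use_fast=False)
input_ids = hf_tokenizer.encode("the quick", return_tensors="pt")
output_ids = hf_model.generate(input_ids, max_length=20, do_sample=False)
print(hf_tokenizer.decode(output_ids[0], skip_special_tokens=True))
```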