@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True
 
@@ -5299,6 +5302,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # per-instance containers (mutable class-level defaults would be shared between instances)
+        self.module_paths: list[str] = []
+        self.dense_features_dims: dict[str, tuple[int, int]] = {}
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                    for mod in mods:
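+                        # a typical entry (layout assumed from the sentence-transformers format) looks like:
+                        #   {"idx": 2, "name": "2", "path": "2_Dense", "type": "sentence_transformers.models.Dense"}
+                        # only Dense modules that ship their own model.safetensors are collected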
+                        if mod["type"] == "sentence_transformers.models.Dense":
+                            mod_path = mod["path"]
+                            # check if model.safetensors file for Dense layer exists
+                            model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                            if model_tensors_file.is_file():
+                                self.module_paths.append(mod_path)
+                                # read config.json of the Dense layer to get in/out features
+                                mod_conf_file = self.dir_model / mod_path / "config.json"
+                                if mod_conf_file.is_file():
+                                    with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                        mod_conf = json.load(mod_conf_json_file)
+                                        # hparams dense_2_feat_out and dense_3_feat_in are required when loading the model's dense weights
+                                        prefix = self._get_dense_prefix(mod_path)
+                                        # use .get() so a config.json without these keys does not raise KeyError
+                                        if mod_conf.get("in_features") is not None and mod_conf.get("out_features") is not None:
+                                            self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
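+        """Yield the Dense modules' projection weights (collected in __init__) as extra tensors."""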
+        from safetensors.torch import load_file
+        for module_path in self.module_paths:
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
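+                # sentence-transformers stores the Dense weight as "linear.weight"; rename it with the
+                # dense_2/dense_3 prefix so map_tensor_name can resolve it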
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path: str) -> str:
+        """Get the tensor name prefix for the Dense layer from its module path."""
+        return "dense_2" if module_path == "2_Dense" else "dense_3"
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -5315,6 +5365,10 @@ def set_gguf_parameters(self):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
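+                # dense is "dense_2" or "dense_3"; dims is the (in_features, out_features) pair read from the module's config.json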
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
 
         self._try_set_pooling_type()
 
@@ -9365,6 +9419,13 @@ def parse_args() -> argparse.Namespace:
         )
     )
 
+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Include the sentence-transformers Dense modules. "
+              "Useful for sentence-transformers models such as google/embeddinggemma-300m. "
+              "By default these modules are not included.")
+    )
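+    # example invocation (script name assumed): convert_hf_to_gguf.py google/embeddinggemma-300m --remote --sentence-transformers-dense-modules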
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -9427,9 +9488,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include the sentence-transformers Dense modules' safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
@@ -9497,7 +9562,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )
 
     if args.vocab_only: