@@ -15,7 +15,7 @@
 from transformers.utils import logging

 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@

 from .interfaces import SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 from .vision_siglip_navit import get_siglip_vision_model

 # <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ def __init__(self,
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
-        if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
-            embd_drop = config.embd_pdrop if hasattr(
-                config, 'embd_pdrop') else config.embed_pdrop
-            self.drop = nn.Dropout(embd_drop)
-        else:
-            self.drop = None

         # layer_idx to output the img features
         if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         ],
     }

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "base_layer.": "",
+        },
+        orig_to_new_prefix={
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
+            "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
+            "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
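
Note: `hf_to_vllm_mapper` replaces the hand-rolled name munging in the old `load_weights` (removed in a later hunk) with a declarative table. As a rough sketch of the semantics, assuming `WeightsMapper` applies the substring rewrites first and then the first matching prefix rewrite, a checkpoint key is transformed like this (`map_name` is a hypothetical illustration, not vLLM API):

    # Illustrative only: approximates the rewrite rules declared above.
    def map_name(name: str) -> str:
        # substring rewrite: unwrap LoRA's "base_layer." indirection
        name = name.replace("base_layer.", "")
        # prefix rewrites: the specific audio_projection.* entries are
        # listed before the bare "audio_embed." prefix so they match first
        prefixes = {
            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
            "embed_tokens_extend.audio_projection_for_vision.",
            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
            "embed_tokens_extend.audio_projection.",
            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
        }
        for old, new in prefixes.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name

For example, a (hypothetical) key "model.embed_tokens_extend.image_embed.img_projection.base_layer.weight" would land on "vision_encoder.img_projection.weight".
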
@@ -1445,8 +1453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lora_config = lora_config

         # Tensor/Pipeline parallel not supported for now.
-        assert get_tensor_model_parallel_world_size(
-        ) == 1, "tensor parallel is not supported"
         assert get_pp_group(
         ).world_size == 1, "pipeline parallel is not supported"
@@ -1686,44 +1692,6 @@ def merge_image_features_to_inputs_embeds(
         )
         return merged_embeds

-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-
-        for name, weight in weights.items():
-            # NOTE vision-speech tasks use a separate projection layer
-            audio_proj_4v = \
-                "model.embed_tokens_extend.audio_embed.audio_projection.vision"
-            if name.startswith(audio_proj_4v):
-                name = name.replace(
-                    audio_proj_4v,
-                    "embed_tokens_extend.audio_projection_for_vision")
-
-            name = (name.replace(
-                "model.embed_tokens_extend.audio_embed." \
-                    "audio_projection.speech.",
-                "embed_tokens_extend.audio_projection.",
-            ).replace(
-                "model.embed_tokens_extend.audio_embed.",
-                "embed_tokens_extend.",
-            ).replace("model.embed_tokens_extend.image_embed.",
-                      "vision_encoder."))
-            # NOTE: this is deal with LoRA injection, where `base_layer`
-            # remains as the original layer in the model
-            if name.endswith(".base_layer.weight"):
-                name = name.replace(".base_layer.weight", ".weight")
-            adjusted_weights[name] = weight
-
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        logger.debug("*** missing keys:")
-        for key in missing_keys:
-            logger.debug(key)
-        logger.debug("**** unexpected keys:")
-        for key in unexpected_keys:
-            logger.debug(key)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> None:
+        weights = ((name, data) for name, data in weights
+                   if "lora" not in name)
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
     def get_mm_mapping(self) -> MultiModelKeys:
         """
         Get the module prefix in multimodal models
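
Note: the rewritten `load_weights` drops LoRA adapter tensors up front (any key containing "lora"; adapters are applied by vLLM's LoRA machinery rather than merged into the base weights) and hands the remaining weights to `AutoWeightsLoader` together with the class-level mapper. A minimal sketch of that delegation pattern, assuming a flat name-to-parameter copy (the real `AutoWeightsLoader` also recurses into submodules that define their own `load_weights`):

    from typing import Callable, Iterable, Tuple

    import torch

    # Hypothetical stand-in for AutoWeightsLoader.load_weights().
    def load_weights_sketch(
        model: torch.nn.Module,
        weights: Iterable[Tuple[str, torch.Tensor]],
        map_name: Callable[[str], str],
    ) -> None:
        params = dict(model.named_parameters())
        for hf_name, tensor in weights:
            if "lora" in hf_name:  # adapter weights are loaded separately
                continue
            vllm_name = map_name(hf_name)  # apply hf_to_vllm_mapper rules
            if vllm_name in params:
                params[vllm_name].data.copy_(tensor)
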
@@ -1804,4 +1779,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
             language_model="model.",
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
-        )
+        )