@@ -1508,7 +1508,8 @@ def get_layers_start_end_indices(
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"
                 or self.hf_config.model_type == "glm4_moe_mtp"
-                or self.hf_config.model_type == "ernie_mtp"):
+                or self.hf_config.model_type == "ernie_mtp"
+                or self.hf_config.model_type == "qwen3_next_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
         else:
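
For every MTP variant the draft "model" is just a small stack of next-token-prediction layers, so the hidden-layer count is read from `num_nextn_predict_layers` rather than `num_hidden_layers`. A minimal sketch of that lookup, using a mocked config object (the attribute value is illustrative):

```python
from types import SimpleNamespace

# Mocked MTP draft config: qwen3_next_mtp exposes num_nextn_predict_layers
# (typically 1) instead of a conventional num_hidden_layers.
mtp_config = SimpleNamespace(model_type="qwen3_next_mtp",
                             num_nextn_predict_layers=1)
total_num_hidden_layers = getattr(mtp_config, "num_nextn_predict_layers", 0)
assert total_num_hidden_layers == 1
```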
@@ -1571,15 +1572,28 @@ def get_num_layers_by_block_type(
             if attn_type_list:
                 return sum(t == 1 for t in attn_type_list[start:end])
 
-            if layers_block_type_value is None and attn_type_list is None:
+            # Hybrid model Qwen3Next
+            layer_types_value = getattr(self.hf_config, "layer_types", None)
+            if layer_types_value is not None:
+                if getattr(block_type, "value", block_type) == "attention":
+                    return sum(t == "full_attention"
+                               for t in layer_types_value[start:end])
+                elif getattr(block_type, "value",
+                             block_type) == "linear_attention":
+                    return sum(t == "linear_attention"
+                               for t in layer_types_value[start:end])
+                else:
+                    return sum(t == getattr(block_type, "value", block_type)
+                               for t in layer_types_value[start:end])
+
+            if (layers_block_type_value is None and attn_type_list is None
+                    and layer_types_value is None):
                 raise ValueError(
                     "The model is an hybrid without a"
-                    "layers_block_type or an attn_type_list in the hf_config, "
-                    "cannot determine the num of "
+                    "layers_block_type, an attn_type_list, or a layer_types "
+                    "in the hf_config, cannot determine the num of "
                     f"{block_type.value} layers")
 
-            return sum(t == 1 for t in attn_type_list[start:end])
-
     def get_mamba_chunk_size(self) -> Optional[int]:
         """
         Returns the mamba chunk size if it exists
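
The new `layer_types` branch handles hybrid models such as Qwen3-Next, whose HF config labels each decoder layer `"full_attention"` or `"linear_attention"`; a request for `attention` blocks is matched against the `"full_attention"` label. A self-contained sketch of the counting (the `layer_types` list below is illustrative; real Qwen3-Next configs interleave three linear-attention layers per full-attention layer):

```python
# Illustrative layer_types list in the 3:1 linear/full pattern.
layer_types = ["linear_attention", "linear_attention",
               "linear_attention", "full_attention"] * 2
start, end = 0, len(layer_types)

def count(block_type: str) -> int:
    # "attention" requests map onto the "full_attention" label,
    # mirroring the branch added above.
    target = "full_attention" if block_type == "attention" else block_type
    return sum(t == target for t in layer_types[start:end])

assert count("attention") == 2
assert count("linear_attention") == 6
```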
@@ -1866,7 +1880,7 @@ def __post_init__(self):
 
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp"]
+                            "ernie_mtp", "qwen3_next_mtp"]
 
 
 @config
@@ -2007,7 +2021,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
                 "n_predict": n_predict,
                 "architectures": ["ErnieMTPModel"]
             })
-            return hf_config
+
+        if hf_config.model_type == "qwen3_next":
+            hf_config.model_type = "qwen3_next_mtp"
+        if hf_config.model_type == "qwen3_next_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["Qwen3NextMTP"]
+            })
 
         return hf_config
 
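A sketch of what the override does to a target config, exercised directly on a bare `transformers.PretrainedConfig` (the attribute values are made up for illustration):

```python
from transformers import PretrainedConfig

hf_config = PretrainedConfig(model_type="qwen3_next",
                             num_nextn_predict_layers=1)

# Same two steps as the branch above: retag the model type, then record
# the MTP depth as n_predict and swap in the MTP architecture.
if hf_config.model_type == "qwen3_next":
    hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
    n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
    hf_config.update({"n_predict": n_predict,
                      "architectures": ["Qwen3NextMTP"]})

assert hf_config.n_predict == 1
assert hf_config.architectures == ["Qwen3NextMTP"]
```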
@@ -2028,9 +2050,13 @@ def __post_init__(self):
                 (self.target_model_config.hf_text_config.model_type \
                         == "deepseek_v3" or
                     self.target_model_config.hf_text_config.model_type in
-                        ("mimo", "ernie4_5_moe")):
+                        ("mimo", "ernie4_5_moe", "qwen3_next")):
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
+                # Align the quantization of the draft model for cases such as
+                # --quantization fp8 with a bf16 checkpoint.
+                if not self.quantization:
+                    self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
             else:
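
The quantization fallback only fills in a draft setting that was left unset, so an explicit draft-side choice still wins. A toy illustration of that precedence (plain dataclasses, not the real config types):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Cfg:
    quantization: Optional[str] = None

target = Cfg(quantization="fp8")  # e.g. --quantization fp8 on a bf16 checkpoint

# An unset draft quantization inherits from the target.
draft = Cfg()
if not draft.quantization:
    draft.quantization = target.quantization
assert draft.quantization == "fp8"

# An explicit draft setting is left untouched.
draft2 = Cfg(quantization="gptq")
if not draft2.quantization:
    draft2.quantization = target.quantization
assert draft2.quantization == "gptq"
```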
@@ -2140,6 +2166,15 @@ def __post_init__(self):
                                 "one layer. Might need some code changes " \
                                 "to support multiple layers."
                             )
+                elif (self.draft_model_config.hf_config.model_type ==
+                      "qwen3_next_mtp"):
+                    self.method = "qwen3_next_mtp"
+                    if self.num_speculative_tokens > 1:
+                        logger.warning(
+                                "All Qwen3Next MTP models only have " \
+                                "one layer. Might need some code changes " \
+                                "to support multiple layers."
+                            )
                 else:
                     self.method = "draft_model"
                     raise NotImplementedError(
@@ -2355,7 +2390,8 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
+                               "qwen3_next_mtp")
 
     def __repr__(self) -> str:
         method = self.method
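
Taken together, enabling MTP for a Qwen3-Next checkpoint should only need a speculative config naming the method; the draft model and its quantization are inherited from the target as shown above. A hedged usage sketch (the model name is the public Qwen3-Next checkpoint; exact argument spelling may vary across vLLM versions):

```python
from vllm import LLM

# num_speculative_tokens=1 matches the single MTP layer; larger values
# trigger the warning added in this change.
llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    speculative_config={
        "method": "qwen3_next_mtp",
        "num_speculative_tokens": 1,
    },
)
```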