
Commit e93f4cc

Add support for the Qwen3 Next model (a hybrid attention model) (#24526)
Signed-off-by: Tao He <[email protected]> Co-authored-by: Jee Jee Li <[email protected]>
Parent: 2048c4e

29 files changed: +2476 −61 lines
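For context on what this commit enables, here is a minimal offline-inference sketch (not part of the diff); it assumes a vLLM build that includes this change and a transformers release meeting the 4.56.2 floor used below.

from vllm import LLM, SamplingParams

# Checkpoint name taken from the supported-models table added in this commit.
llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct")
outputs = llm.generate(["Briefly explain hybrid attention."],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)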

.yapfignore

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 collect_env.py
+vllm/model_executor/layers/fla/ops/*.py

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -403,6 +403,7 @@ th {
 | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3NextForCausalLM` | Qwen3.5MoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
 | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ fo = "fo"
 ba = "ba"

 [tool.typos.type.py.extend-words]
+ba = "ba"

 [tool.typos.type.cpp]
 extend-glob = ["*.cu"]

tests/models/registry.py

Lines changed: 5 additions & 1 deletion
@@ -326,6 +326,8 @@ def check_available_online(
     "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
+    "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                                            min_transformers_version="4.56.2"),
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct",  # noqa: E501
                                           trust_remote_code=True,
@@ -640,7 +642,9 @@ def check_available_online(
                                      is_available_online=False),
     "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                     trust_remote_code=True,
-                                    speculative_model="XiaomiMiMo/MiMo-7B-RL")
+                                    speculative_model="XiaomiMiMo/MiMo-7B-RL"),
+    "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                                    min_transformers_version="4.56.2"),
 }

 _TRANSFORMERS_BACKEND_MODELS = {
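The new registry entries gate the Qwen3-Next examples on transformers >= 4.56.2. As an aside (illustrative only, not vLLM's registry code), a check like the following is all that is needed to skip such an entry on older installations; the helper name is hypothetical.

from importlib.metadata import version

from packaging.version import Version

MIN_TRANSFORMERS_VERSION = "4.56.2"  # value used in the registry entries above

def transformers_is_new_enough(minimum: str = MIN_TRANSFORMERS_VERSION) -> bool:
    # Hypothetical helper: compare the installed transformers release
    # against the minimum required for Qwen3-Next support.
    return Version(version("transformers")) >= Version(minimum)

if not transformers_is_new_enough():
    print("Skipping Qwen3-Next examples: transformers is too old.")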

vllm/config/__init__.py

Lines changed: 46 additions & 10 deletions
@@ -1508,7 +1508,8 @@ def get_layers_start_end_indices(
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"
                 or self.hf_config.model_type == "glm4_moe_mtp"
-                or self.hf_config.model_type == "ernie_mtp"):
+                or self.hf_config.model_type == "ernie_mtp"
+                or self.hf_config.model_type == "qwen3_next_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
         else:
@@ -1571,15 +1572,28 @@ def get_num_layers_by_block_type(
         if attn_type_list:
             return sum(t == 1 for t in attn_type_list[start:end])

-        if layers_block_type_value is None and attn_type_list is None:
+        # Hybrid model Qwen3Next
+        layer_types_value = getattr(self.hf_config, "layer_types", None)
+        if layer_types_value is not None:
+            if getattr(block_type, "value", block_type) == "attention":
+                return sum(t == "full_attention"
+                           for t in layer_types_value[start:end])
+            elif getattr(block_type, "value",
+                         block_type) == "linear_attention":
+                return sum(t == "linear_attention"
+                           for t in layer_types_value[start:end])
+            else:
+                return sum(t == getattr(block_type, "value", block_type)
+                           for t in layer_types_value[start:end])
+
+        if (layers_block_type_value is None and attn_type_list is None
+                and layer_types_value is None):
             raise ValueError(
                 "The model is an hybrid without a"
-                "layers_block_type or an attn_type_list in the hf_config,"
-                "cannot determine the num of "
+                "layers_block_type or an attn_type_list, or a layer_types "
+                "in the hf_config, cannot determine the num of "
                 f"{block_type.value} layers")

-        return sum(t == 1 for t in attn_type_list[start:end])
-
     def get_mamba_chunk_size(self) -> Optional[int]:
         """
         Returns the mamba chunk size if it exists
@@ -1866,7 +1880,7 @@ def __post_init__(self):

 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp"]
+                            "ernie_mtp", "qwen3_next_mtp"]


 @config
@@ -2007,7 +2021,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
                 "n_predict": n_predict,
                 "architectures": ["ErnieMTPModel"]
             })
-            return hf_config
+
+        if hf_config.model_type == "qwen3_next":
+            hf_config.model_type = "qwen3_next_mtp"
+        if hf_config.model_type == "qwen3_next_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["Qwen3NextMTP"]
+            })

         return hf_config

@@ -2028,9 +2050,13 @@ def __post_init__(self):
                 (self.target_model_config.hf_text_config.model_type \
                     == "deepseek_v3" or
                     self.target_model_config.hf_text_config.model_type in
-                    ("mimo","ernie4_5_moe")):
+                    ("mimo","ernie4_5_moe", "qwen3_next")):
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
+                # Align the quantization of draft model for cases such as
+                # --quantization fp8 with a bf16 checkpoint.
+                if not self.quantization:
+                    self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
             else:
@@ -2140,6 +2166,15 @@ def __post_init__(self):
                         "one layer. Might need some code changes " \
                         "to support multiple layers."
                     )
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "qwen3_next_mtp"):
+                self.method = "qwen3_next_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "All Qwen3Next MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
                 raise NotImplementedError(
@@ -2355,7 +2390,8 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens

     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
+                               "qwen3_next_mtp")

     def __repr__(self) -> str:
         method = self.method
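The `get_num_layers_by_block_type` change above reads the per-layer `layer_types` list that hybrid models such as Qwen3-Next publish in their HF config, and counts layers of a given kind within a pipeline-parallel slice. A standalone sketch of the same counting idea (assumed names and an illustrative layer pattern, not the vLLM code itself):

# Example layer_types list for a hybrid model: mostly linear attention with a
# full-attention layer every fourth position (pattern chosen for illustration).
layer_types = [
    "linear_attention", "linear_attention", "linear_attention", "full_attention",
    "linear_attention", "linear_attention", "linear_attention", "full_attention",
]

def count_layers(layer_types: list[str], kind: str, start: int, end: int) -> int:
    # The "attention" block type corresponds to "full_attention" entries.
    target = "full_attention" if kind == "attention" else kind
    return sum(t == target for t in layer_types[start:end])

print(count_layers(layer_types, "attention", 0, 8))         # -> 2
print(count_layers(layer_types, "linear_attention", 0, 8))  # -> 6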

vllm/config/compilation.py

Lines changed: 1 addition & 0 deletions
@@ -341,6 +341,7 @@ class CompilationConfig:
         "vllm.short_conv",
         "vllm.linear_attention",
         "vllm.plamo2_mamba_mixer",
+        "vllm.gdn_attention",
     ]

     def compute_hash(self) -> str:

vllm/model_executor/layers/fla/ops/chunk_delta_h.py

Lines changed: 3 additions & 2 deletions
@@ -14,7 +14,7 @@
 from vllm.triton_utils import tl, triton

 from .index import prepare_chunk_indices, prepare_chunk_offsets
-from .op import exp, safe_exp
+from .op import exp
 from .utils import is_nvidia_hopper, use_cuda_graph

 NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
                              boundary_check=(0, 1))

     if USE_G:
+        m_t = (i_t * BT + tl.arange(0, BT)) < T
         last_idx = min((i_t + 1) * BT, T) - 1
         b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
         p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
                                 (i_t * BT, ), (BT, ), (0, ))
         b_g = tl.load(p_g, boundary_check=(0, ))
-        b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
+        b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
         b_g_last = exp(b_g_last)
         b_h1 = b_h1 * b_g_last
         if K > 64:

vllm/model_executor/layers/fla/ops/chunk_o.py

Lines changed: 5 additions & 4 deletions
@@ -16,7 +16,7 @@
 from vllm.triton_utils import tl, triton

 from .index import prepare_chunk_indices
-from .op import exp, safe_exp
+from .op import exp
 from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper

 BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
@@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
         p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
         b_g = tl.load(p_g, boundary_check=(0, ))
         b_o = b_o * exp(b_g)[:, None]
-        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
+        b_A = b_A * exp(b_g[:, None] - b_g[None, :])

-    o_i = tl.arange(0, BT)
-    m_A = o_i[:, None] >= o_i[None, :]
+    o_t = i_t * BT + tl.arange(0, BT)
+    m_t = o_t < T
+    m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
     b_A = tl.where(m_A, b_A, 0)

     p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
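The kernel change above switches from tile-local indices to global token positions, so the causal mask also zeroes entries that fall past the true sequence length T in the final, partial tile. A NumPy sketch of the same mask construction (illustrative only, outside the Triton kernel):

import numpy as np

BT, i_t, T = 4, 1, 6                 # tile size, tile index, sequence length
o_t = i_t * BT + np.arange(BT)       # global positions in this tile: [4, 5, 6, 7]
m_t = o_t < T                        # in-bounds mask: [True, True, False, False]
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t[None, :])

b_A = np.ones((BT, BT))
# Causal lower-triangular block with rows/columns for positions 6 and 7 zeroed.
print(np.where(m_A, b_A, 0.0))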

vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py

Lines changed: 6 additions & 4 deletions
@@ -14,7 +14,7 @@
 from vllm.triton_utils import tl, triton

 from .index import prepare_chunk_indices
-from .op import safe_exp
+from .op import exp


 @triton.heuristics({
@@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
         T = eos - bos
     else:
         bos, eos = i_b * T, i_b * T + T
-    o_t = tl.arange(0, BT)
+    o_t = i_t * BT + tl.arange(0, BT)
+    m_t = o_t < T

     p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
                                (i_t * BT, ), (BT, ), (0, ))
@@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
                                 (i_t * BT, ), (BT, ), (0, ))
         b_g = tl.load(p_g, boundary_check=(0, ))
         b_g_diff = b_g[:, None] - b_g[None, :]
-        b_A = b_A * safe_exp(b_g_diff)
+        b_A = b_A * exp(b_g_diff)

-    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
+    m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
+    b_A = tl.where(m_A, b_A, 0)
     p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
                             (i_t * BT, 0), (BT, BT), (1, 0))
     tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

vllm/model_executor/layers/fla/ops/fused_recurrent.py

Lines changed: 2 additions & 2 deletions
@@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_g = tl.load(p_g).to(tl.float32)

         if USE_QK_L2NORM_IN_KERNEL:
-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
+            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
         b_q = b_q * scale
         # [BK, BV]
         b_h *= exp(b_g)
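The change above moves the epsilon inside the square root of the q/k L2 normalization. A small numeric comparison of the two placements (illustrative only, outside the Triton kernel):

import numpy as np

def l2norm_eps_outside(x, eps=1e-6):
    # Previous form: epsilon added to the norm itself.
    return x / (np.sqrt(np.sum(x * x)) + eps)

def l2norm_eps_inside(x, eps=1e-6):
    # New form: epsilon added under the square root.
    return x / np.sqrt(np.sum(x * x) + eps)

x = np.array([3.0, 4.0])
print(l2norm_eps_outside(x))  # ~[0.6, 0.8]
print(l2norm_eps_inside(x))   # ~[0.6, 0.8]; differs from the above only at ~1e-7 relative scale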
