Skip to content

Commit c4cb0af

Browse files
[spec decode] Fix MTP inference path for MiMo-7B model (#25136)
Signed-off-by: zixi-qi <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent 1c3b163 commit c4cb0af

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

examples/offline_inference/spec_decode.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,6 @@ def parse_args():
5353
"--method",
5454
type=str,
5555
default="eagle",
56-
choices=["ngram", "eagle", "eagle3", "mtp"],
5756
)
5857
parser.add_argument("--num-spec-tokens", type=int, default=2)
5958
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,6 +117,11 @@ def main():
118117
"prompt_lookup_max": args.prompt_lookup_max,
119118
"prompt_lookup_min": args.prompt_lookup_min,
120119
}
120+
elif args.method.endswith("mtp"):
121+
speculative_config = {
122+
"method": args.method,
123+
"num_speculative_tokens": args.num_spec_tokens,
124+
}
121125
else:
122126
raise ValueError(f"unknown method: {args.method}")
123127

vllm/config/speculative.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131

3232
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
3333
"mlp_speculator", "draft_model", "deepseek_mtp",
34-
"ernie_mtp", "qwen3_next_mtp"]
34+
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
3535

3636

3737
@config

vllm/model_executor/models/mimo_mtp.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -241,17 +241,27 @@ def load_weights(self, weights: Iterable[tuple[str,
241241

242242
def map_model_name_to_mtp_param_name(self, name: str) -> str:
243243
import regex as re
244+
245+
# append mtp_start_layer_idx
246+
pattern = r"(model\.mtp_layers\.)(\d+)(\.)"
247+
match = re.match(pattern, name)
248+
if match:
249+
original_num = int(match.group(2))
250+
new_num = original_num + self.config.num_hidden_layers
251+
name = name.replace(match.group(), f"{match.group(1)}{new_num}.")
252+
# check for early return
244253
name_without_prefix = [
245254
"token_layernorm", "hidden_layernorm", "input_proj",
246255
"final_layernorm"
247256
]
248257
for sub_name in name_without_prefix:
249258
if sub_name in name:
250259
return name
251-
pattern = r"model.mtp_layers.(\d+)."
252-
group = re.match(pattern, name)
253-
if group is not None:
254-
name = name.replace(group.group(), group.group() + "mtp_block.")
260+
# add mtp_block
261+
pattern = r"(model\.mtp_layers\.\d+\.)"
262+
match = re.match(pattern, name)
263+
if match:
264+
name = name.replace(match.group(), match.group() + "mtp_block.")
255265
return name
256266

257267
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:

0 commit comments

Comments
 (0)