Merged

Commits (38)
9fe9994  feat: Add Qwen3 omni moe thinker (Sep 24, 2025)
93efc39  update registry and models page (Sep 24, 2025)
81fd24b  Merge branch 'main' into dev/qwen3-omni-moe (DarkLight1337, Sep 27, 2025)
f0d057a  Update w.r.t. #16229 (DarkLight1337, Sep 27, 2025)
c3e15a6  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 3, 2025)
d59ac08  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 3, 2025)
0b24c98  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 3, 2025)
087a936  remove attn mask (ywang96, Oct 3, 2025)
8ffc26e  update (ywang96, Oct 4, 2025)
fb1d82b  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 5, 2025)
7f42fb0  fix backend import (ywang96, Oct 5, 2025)
7408b9c  fix prompt update (ywang96, Oct 5, 2025)
8e1f5aa  yapf (ywang96, Oct 5, 2025)
3c44f89  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 5, 2025)
a985baa  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 6, 2025)
0525d27  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 8, 2025)
7484970  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 9, 2025)
3c6243b  cleanup (ywang96, Oct 9, 2025)
b8ec4d6  fix (ywang96, Oct 9, 2025)
4c749d1  add (ywang96, Oct 9, 2025)
650855a  fix (ywang96, Oct 9, 2025)
03e1310  fix mixed modality (ywang96, Oct 9, 2025)
7796103  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 9, 2025)
2dee5f6  remove unnecessary tensor creation (ywang96, Oct 9, 2025)
a6cb680  add note (ywang96, Oct 9, 2025)
b4137ab  simplify (ywang96, Oct 9, 2025)
d82c17e  cleanup (ywang96, Oct 9, 2025)
677412d  cleanup (ywang96, Oct 9, 2025)
bb00572  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 9, 2025)
24c2c95  add guard (ywang96, Oct 9, 2025)
14c6903  update (ywang96, Oct 9, 2025)
51366bd  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 9, 2025)
dc17c61  Merge branch 'main' into dev/qwen3-omni-moe (ywang96, Oct 10, 2025)
ae0c930  fix (ywang96, Oct 10, 2025)
3ec407f  add qwen3-omni processor test (Isotr0py, Oct 10, 2025)
897245a  fix non audio_in_video update (Isotr0py, Oct 10, 2025)
2465f49  fix qwen3-omni processor test (Isotr0py, Oct 10, 2025)
5d7ddf7  fix registry (Isotr0py, Oct 10, 2025)
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -688,6 +688,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
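For context, the new table row corresponds to loading the checkpoint through vLLM's offline `LLM` API. The snippet below is a minimal, text-only sketch (the model name is taken from the row above; sampling settings and hardware assumptions are illustrative, not part of this PR); image, video, and audio inputs would additionally be passed via `multi_modal_data`.

from vllm import LLM, SamplingParams

# Minimal sketch: load the newly supported Qwen3-Omni Thinker checkpoint and
# run a plain text prompt. Multimodal data (image/video/audio) is omitted here.
llm = LLM(model="Qwen/Qwen3-Omni-30B-A3B-Instruct", max_model_len=4096)
outputs = llm.generate(
    ["Describe what an omni-modal language model can do."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)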
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -567,6 +567,9 @@ def check_available_online(
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False),
"Qwen3OmniMoeThinkerForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-Omni-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
233 changes: 228 additions & 5 deletions vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -1028,6 +1028,214 @@

return llm_positions, mrope_position_delta


@classmethod
def _omni3_get_input_positions_tensor(
cls,
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
audio_seqlens: Optional[torch.LongTensor] = None,
second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
Comment on lines 1123 to 1132 (Contributor, critical):

The function `_omni3_get_input_positions_tensor` is very long and complex, making it difficult to understand and maintain. More importantly, it processes input sequences one by one in a Python loop (`for i, input_ids in enumerate(total_input_ids):`), which is not vectorized and will cause significant performance degradation, especially at larger batch sizes. The use of `.tolist()` and list methods such as `.index()` inside the loop adds further inefficiency. This implementation should be refactored to be vectorized over the batch dimension to meet vLLM's performance standards. Consider using tensor operations to find indices and to process modalities in parallel for all sequences in the batch.
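As an illustration of the batched approach the comment suggests, here is a minimal sketch that counts image/video/audio segments for all sequences at once with tensor operations instead of `.tolist()` and `.index()`. The token ids and helper name are placeholders (the real values come from the thinker config), and it only covers the counting step, not the per-segment position-id emission:

import torch

# Placeholder token ids; the real values come from the HF config.
VISION_START, IMAGE_TOKEN, VIDEO_TOKEN, AUDIO_START = 10, 11, 12, 13

def count_modalities(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Count image/video/audio segments per sequence without Python loops."""
    valid = attention_mask.bool()
    # The token right after each vision-start marker identifies the modality.
    # (Assumes the marker is never the final token of a row, so the wrap-around
    # introduced by torch.roll is harmless.)
    is_vision_start = (input_ids == VISION_START) & valid
    next_token = torch.roll(input_ids, shifts=-1, dims=1)
    image_nums = (is_vision_start & (next_token == IMAGE_TOKEN)).sum(dim=1)
    video_nums = (is_vision_start & (next_token == VIDEO_TOKEN)).sum(dim=1)
    audio_nums = ((input_ids == AUDIO_START) & valid).sum(dim=1)
    return image_nums, video_nums, audio_nums

ids = torch.tensor([[1, 10, 11, 2, 13, 3, 10, 12, 4]])
mask = torch.ones_like(ids)
print(count_modalities(ids, mask))  # (tensor([1]), tensor([1]), tensor([1]))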


def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
return output_lengths
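# Worked example of the mapping above: for input_lengths = 137, the remainder
# 37 gives feat_lengths = 19 and contributes ((19 - 1) // 2) // 2 + 1 = 5
# positions, while the single full chunk of 100 adds 13, so output_lengths = 18.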
spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
multimodal_nums = (

[Check failure, line 1087: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/layers/rotary_embedding/mrope.py:1087:81: E501 Line too long (96 > 80)]
image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
remain_videos > 0 or remain_images > 0
):
ed_vision_start = input_tokens.index(vision_start_token_id, st)
else:
ed_vision_start = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio_start = input_tokens.index(audio_start_token_id, st)
else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1
remain_audios -= 1
elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = ((torch.arange(grid_t)) * 1 * position_id_per_seconds)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id and not use_audio_in_video:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
(torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start and use_audio_in_video:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
(torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
)
video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]:
llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1])
video_data_index += 1
else:
llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1])
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]
)
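# Note: the two-pointer merge above interleaves the video and audio position
# ids by their temporal component (row 0), emitting whichever chunk comes
# earlier in time first; the two trailing appends flush whatever remains of
# either stream.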
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1

[Check failure, line 1213: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/layers/rotary_embedding/mrope.py:1213:81: E501 Line too long (116 > 80)]
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.long().max() + 1 - len(input_ids))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas.long()
else:
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
return position_ids, mrope_position_deltas.long()

@classmethod
def _omni_get_input_positions_tensor(
cls,
@@ -1060,7 +1268,27 @@
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.

model_type = hf_config.model_type
thinker_config = hf_config.thinker_config

if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)

if "qwen3_omni" in model_type:
llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(thinker_config,

[Check failure, line 1280: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/layers/rotary_embedding/mrope.py:1280:81: E501 Line too long (90 > 80)]
torch.tensor([input_tokens]),
image_grid_thw,
video_grid_thw,
None,
use_audio_in_video,
audio_feature_lengths,
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
Comment (Contributor, high):

This line creates a tensor in a highly inefficient way. `torch.tensor(video_grid_thw)` is redundant, since `video_grid_thw` is already a tensor at this point, and building a Python list of 1s only to convert it into a tensor is also wasteful. This can be simplified and made more performant.

Suggested change:
- torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
+ torch.ones(video_grid_thw.shape[0], dtype=torch.long, device=video_grid_thw.device))

Comment (Member):

Suggested change:
- torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
+ torch.ones(len(video_grid_thw)))

Simplify this.

llm_positions = llm_positions.squeeze(1)
mrope_position_delta = mrope_position_delta.squeeze()
return llm_positions, mrope_position_delta

audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
@@ -1073,11 +1301,6 @@
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)

if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)

src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
@@ -1129,7 +1352,7 @@
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1

[Check failure, line 1355: GitHub Actions / pre-commit, Ruff (E501): vllm/model_executor/layers/rotary_embedding/mrope.py:1355:81: E501 Line too long (88 > 80)]
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]