
Conversation

wangxiongts
Contributor

@wangxiongts wangxiongts commented Sep 24, 2025

This PR from the Qwen team adds the qwen3-omni-moe thinker part.

Testing has been conducted internally across four configurations (v0/v1, eager/CUDA) on several representative benchmarks, with results meeting expectations.

Known issues (we hope to resolve them together with the vLLM team):

  • In v1 mode, use_audio_in_video will raise errors because the video mm_data and placeholders are not updated.

We sincerely appreciate the great work and support from the vLLM team, and look forward to your feedback.

Closes #25472


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the new-model (Requests to new models) and qwen (Related to Qwen models) labels Sep 24, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for the Qwen3-Omni-Moe model. The changes include a new model implementation file, modifications to handle multimodal rotary embeddings, and registration of the new model. While the implementation is comprehensive, I've identified several critical and high-severity issues related to performance and maintainability. Specifically, there are non-vectorized loops and inefficient tensor operations in the position embedding calculation, which will significantly impact performance. Additionally, there are uses of NumPy within core logic that should be replaced with PyTorch operations to avoid CPU-GPU synchronization. I've also found a few potential bugs related to tensor shape calculations that could lead to runtime errors. Addressing these points will be crucial for integrating this model into vLLM effectively.

Comment on lines 1033 to 1043
def _omni3_get_input_positions_tensor(
    cls,
    config,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
    audio_seqlens: Optional[torch.LongTensor] = None,
    second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
Contributor

critical

The function _omni3_get_input_positions_tensor is very long and complex, making it difficult to understand and maintain. More importantly, it processes input sequences one by one within a for loop (for i, input_ids in enumerate(total_input_ids):), which is not vectorized and will lead to significant performance degradation, especially with larger batch sizes. The use of .tolist() and list methods like .index() inside the loop further contributes to the inefficiency. This implementation should be refactored to be vectorized over the batch dimension to meet the performance standards of vLLM. Consider using tensor operations to find indices and process modalities in parallel for all sequences in the batch.
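As a hypothetical sketch of the vectorized lookup suggested above (the helper name and token id below are illustrative assumptions, not the actual vLLM code), the per-sequence .tolist()/.index() calls can be replaced with a single batched tensor operation:

    import torch

    # Locate modality start tokens for the whole batch at once, instead of
    # looping over sequences and calling .tolist()/.index() per sequence.
    def find_modality_starts(input_ids: torch.Tensor, start_token_id: int) -> torch.Tensor:
        # input_ids: (batch, seq_len) -> (num_matches, 2) of (batch_idx, position)
        return (input_ids == start_token_id).nonzero()

    batch = torch.tensor([[1, 151652, 7, 7, 2],
                          [1, 3, 151652, 7, 2]])
    print(find_modality_starts(batch, 151652))  # tensor([[0, 1], [1, 2]])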

if name == "feature_attention_mask":
    dim = -1
if isinstance(mm_input, torch.Tensor):
    return torch.concat(list(mm_input), dim=dim)
Contributor

critical

The implementation of _validate_and_reshape_mm_tensor seems to have a bug when handling a torch.Tensor. The line return torch.concat(list(mm_input), dim=dim) is problematic. When mm_input is a tensor, list(mm_input) iterates over its first dimension. torch.concat then joins these tensors along dim. For example, if mm_input has shape (B, C, L) and dim=1, the result will have shape (C, B*L), which is likely incorrect for batch processing where one would expect to flatten the batch dimension. This will likely cause shape mismatches in downstream processing.
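A small standalone repro of the shape behavior described above (illustrative PyTorch only, not the vLLM code path):

    import torch

    mm_input = torch.zeros(4, 3, 5)             # (B, C, L)
    out = torch.concat(list(mm_input), dim=1)   # iterates over B, concatenates along dim 1
    print(out.shape)                            # torch.Size([3, 20]), i.e. (C, B*L)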

multimodal_embeddings[index] = embeddings_main
multimodal_embeddings_multiscale.append(embeddings_multiscale)
if len(multimodal_embeddings_multiscale) > 0:
    deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))
Contributor

critical

There appears to be a bug in the shape calculation for deepstack_input_embeds. The second dimension is calculated as multiscale_len * inputs_embeds.size(1), which resolves to multiscale_len * text_config.hidden_size. However, this tensor is later populated with multimodal_embeddings_multiscale which have a feature dimension of multi_dim (multiscale_len * visual_dim), and then reshaped using visual_dim. This will raise a runtime error if text_config.hidden_size is not equal to visual_dim (vision_config.out_hidden_size). The correct size for the second dimension should be multi_dim (i.e., multiscale_len * visual_dim), which is computed a few lines above.

                    deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multi_dim)
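A minimal repro of the mismatch, assuming hypothetical sizes where text_config.hidden_size differs from vision_config.out_hidden_size:

    import torch

    hidden_size, visual_dim, multiscale_len, n_tokens = 2048, 1024, 3, 16
    multi_dim = multiscale_len * visual_dim                  # 3072
    inputs_embeds = torch.zeros(n_tokens, hidden_size)
    feats = torch.zeros(n_tokens, multi_dim)                 # multiscale features

    buggy = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))
    # buggy.copy_(feats)  # would raise: (16, 6144) vs (16, 3072)
    fixed = inputs_embeds.new_zeros(inputs_embeds.size(0), multi_dim)
    fixed.copy_(feats)    # shapes match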

None,
use_audio_in_video,
audio_feature_lengths,
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
Contributor

high

This line creates a tensor in a highly inefficient way. torch.tensor(video_grid_thw) is redundant as video_grid_thw is already a tensor at this point. Creating a list of 1s and then converting it to a tensor is also inefficient. This can be simplified and made more performant.

Suggested change
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
torch.ones(video_grid_thw.shape[0], dtype=torch.long, device=video_grid_thw.device))

Comment on lines 381 to 382
h_idxs = np.linspace(0, num_grid_per_side-1, h)
w_idxs = np.linspace(0, num_grid_per_side-1, w)
Contributor

high

This function uses numpy for calculations (np.linspace), which can lead to performance bottlenecks due to CPU-GPU synchronization and data transfers. The comment on line 379 already indicates this. These operations should be replaced with their torch equivalents to keep the computation on the GPU and within the computation graph.

Suggested change
h_idxs = np.linspace(0, num_grid_per_side-1, h)
w_idxs = np.linspace(0, num_grid_per_side-1, w)
h_idxs = torch.linspace(0, num_grid_per_side-1, h, device=self.pos_embed.weight.device)
w_idxs = torch.linspace(0, num_grid_per_side-1, w, device=self.pos_embed.weight.device)

Member

Are you able to finish this TODO before you have to go OOO?

Comment on lines 688 to 696
audio_token_indices = np.arange(next(iter([audio_len])))
curr_video_grid_thw = next(iter([video_grid_thw]))
height = curr_video_grid_thw[1] // spatial_merge_size
width = curr_video_grid_thw[2] // spatial_merge_size
video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
video_token_indices = np.broadcast_to(
    video_token_indices, (video_token_indices.shape[0], height, width)
).reshape(-1)
video_token_indices = ((video_token_indices + shift) * next(iter([video_second_per_grid_t])) * position_id_per_seconds)
Contributor

high

This function uses numpy for array creation and manipulation (np.arange, np.broadcast_to). This forces data transfers between CPU and GPU and can be a performance bottleneck. These should be replaced with torch equivalents to maintain performance.

        audio_token_indices = torch.arange(next(iter([audio_len])))
        curr_video_grid_thw = next(iter([video_grid_thw]))
        height = curr_video_grid_thw[1] // spatial_merge_size
        width = curr_video_grid_thw[2] // spatial_merge_size
        video_token_indices = torch.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
        video_token_indices = video_token_indices.expand(video_token_indices.shape[0], height, width).reshape(-1)
        video_token_indices = ((video_token_indices + shift) * next(iter([video_second_per_grid_t])) * position_id_per_seconds)

Member

@DarkLight1337 DarkLight1337 left a comment

Thanks, can you update tests/models/registry.py to be able to pass the CI?

Also please update the Supported Models page

@wangxiongts
Contributor Author

wangxiongts commented Sep 24, 2025

Alright, I'll handle these parts. Currently, I'm still working on adding audio-in-video support in v1. In the meantime, one known issue is that I may not be able to straightforwardly reuse relevant modules from Qwen3-VL, because our model has already been made public and some checkpoint keys and configurations are incompatible with Qwen3-VL. This stems from the fact that our internal iterations were not synchronized. This issue may require further careful discussion.

I might go on vacation starting tomorrow and probably won't resume modifications until after October 4th :) You can proceed with the review based on the current version.

@ywang96 ywang96 self-assigned this Sep 24, 2025
@mergify mergify bot added the documentation (Improvements or additions to documentation) label Sep 24, 2025
None,
use_audio_in_video,
audio_feature_lengths,
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
Member

Suggested change
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
torch.ones(len(video_grid_thw))

Simplify this

@CHNtentes

Thanks for your work. May I ask, will the talker model be supported in the future? It seems Qwen2.5-Omni still only supports the thinker model now.

@Wesley-Jzy

LGTM! May I know whether the Talker model will be supported by vLLM?

@ywang96
Member

ywang96 commented Sep 24, 2025

Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have plans to support this model in a different project that leverages the thinker model support from vLLM. Stay tuned!

@Wesley-Jzy

Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have plans to support this model in a different project that leverages the thinker model support from vLLM. Stay tuned!

So that means the vLLM project will support the thinker model, which behaves just like a normal LLM, and a new multimodal inference project will support the end-to-end Qwen3-Omni model? Can I learn more about this new project?

@ywang96
Member

ywang96 commented Sep 24, 2025

Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have plans to support this model in a different project that leverages the thinker model support from vLLM. Stay tuned!

So that means the vLLM project will support the thinker model, which behaves just like a normal LLM, and a new multimodal inference project will support the end-to-end Qwen3-Omni model? Can I learn more about this new project?

Yea that's the right understanding! We're still planning for the new project so stay tuned!

@Wesley-Jzy

Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have plans to support this model in a different project that leverages the thinker model support from vLLM. Stay tuned!

So that means the vLLM project will support the thinker model, which behaves just like a normal LLM, and a new multimodal inference project will support the end-to-end Qwen3-Omni model? Can I learn more about this new project?

Yea that's the right understanding! We're still planning for the new project so stay tuned!

Great! And may I know whether the new project will also handle single-model multimodal models such as Kimi-Audio, or will they be supported by vLLM?

@CHNtentes

Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have plans to support this model in a different project that leverages the thinker model support from vLLM. Stay tuned!

So that means the vLLM project will support the thinker model, which behaves just like a normal LLM, and a new multimodal inference project will support the end-to-end Qwen3-Omni model? Can I learn more about this new project?

Yea that's the right understanding! We're still planning for the new project so stay tuned!

Really wish the new project is fast and efficient. Tried transformers and audio output was SLOW...

wenbinc-Bin added a commit to wenbinc-Bin/vllm-fork that referenced this pull request Sep 25, 2025
@eschmidbauer

Really wish the new project is fast and efficient. Tried transformers and audio output was SLOW...

Same, even with flash-attn2 it is very slow

@CHNtentes

Really wish the new project is fast and efficient. Tried transformers and audio output was SLOW...

Same, even with flash-attn2 it is very slow

I tried this PR and it's like >20x faster than transformers :)

@facebook-github-bot

@houseroad has imported this pull request. If you are a Meta employee, you can view this in D83274891.

@ywang96 ywang96 mentioned this pull request Sep 29, 2025
@UmutAlihan

Looking forward to being able to run inference with vLLM for this omni model. Compiling vLLM from a specific branch's source is a real hassle on some limited-access corporate standalone servers 😿.

@ywang96
Member

ywang96 commented Oct 3, 2025

FYI I'm back to working on this PR

@tensorboy

FYI I'm back to working on this PR

you are the best!

@fikrikarim

Thank you for all the hard work. We really appreciate it.

From what I understand this PR is only for the thinker model and it only supports text output.

Is there any rough timeline when audio output will be added? Thank you so much!

@ywang96
Member

ywang96 commented Oct 5, 2025

This PR should be functional again - I probably won't fix the audio_in_video issue here; it will be fixed together in #26156, so we should be able to merge this pretty soon.

@ywang96
Member

ywang96 commented Oct 5, 2025

Thank you for all the hard work. We really appreciate it.

From what I understand this PR is only for the thinker model and it only supports text output.

Is there any rough timeline when audio output will be added? Thank you so much!

We don't have a rough timeline yet but hopefully by the end of this year/early next year

@fikrikarim

We don't have a rough timeline yet but hopefully by the end of this year/early next year

Got it. Thanks!


mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wangxiongts.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 7, 2025
@ai-and-i

ai-and-i commented Oct 7, 2025

Hi, thanks for the great work on this PR! I tried to run it and it works great when providing (one of audio/image/video)+text or image+video+text. However, when I'm running it with audio+image+text or audio+video+text, it crashes. I made a gist with a small example: https://gist.github.com/ai-and-i/76b75f1bef2f2df6b1ea5998c3911918 (I tested it in a jupyter notebook with a985baa).


mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--25550.org.readthedocs.build/en/25550/

@aldazero

aldazero commented Oct 8, 2025

Hi, thanks for the great work on this. Does this support audio output? @wangxiongts

@mergify mergify bot removed the needs-rebase label Oct 8, 2025
Labels
documentation (Improvements or additions to documentation), new-model (Requests to new models), qwen (Related to Qwen models)
Development

Successfully merging this pull request may close these issues.

[Feature]: Support Qwen3-Omni-30B