Add Qwen3-Omni moe thinker #25550
@@ -1028,6 +1028,214 @@
        return llm_positions, mrope_position_delta

    @classmethod
    def _omni3_get_input_positions_tensor(
        cls,
        config,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
        audio_seqlens: Optional[torch.LongTensor] = None,
        second_per_grids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:

Comment on lines 1123 to 1132: The function

        def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            input_lengths_leave = input_lengths % 100
            feat_lengths = (input_lengths_leave - 1) // 2 + 1
            output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
            return output_lengths
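
For intuition, here is a standalone copy of the helper above with a couple of illustrative inputs. The chunking constants (100 mel frames per block, 13 output positions per full block, three stride-2 reductions for the remainder) come straight from the code; the sample lengths are made up.

import torch

def get_feat_extract_output_lengths(input_lengths: torch.LongTensor) -> torch.LongTensor:
    # Same arithmetic as the nested helper in the PR: audio is processed in
    # blocks of 100 frames; each full block contributes 13 output positions,
    # and the remainder is downsampled roughly 8x by three stride-2 stages.
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13

print(get_feat_extract_output_lengths(torch.tensor([80, 150, 250])))
# -> tensor([10, 20, 33])
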
        spatial_merge_size = config.vision_config.spatial_merge_size
        image_token_id = config.image_token_id
        video_token_id = config.video_token_id
        audio_token_id = config.audio_token_id
        vision_start_token_id = config.vision_start_token_id
        audio_start_token_id = config.audio_start_token_id
        position_id_per_seconds = config.position_id_per_seconds
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.zeros(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_idx, video_idx, audio_idx = 0, 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums, audio_nums = 0, 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                audio_nums = torch.sum(input_ids == audio_start_token_id)
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (
                    (vision_tokens == audio_start_token_id).sum()
                    if use_audio_in_video
                    else (vision_tokens == video_token_id).sum()
                )
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
                multimodal_nums = (
                    image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
                )
                for _ in range(multimodal_nums):
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    if (image_token_id in input_tokens or video_token_id in input_tokens) and (
                        remain_videos > 0 or remain_images > 0
                    ):
                        ed_vision_start = input_tokens.index(vision_start_token_id, st)
                    else:
                        ed_vision_start = len(input_tokens) + 1
                    if audio_token_id in input_tokens and remain_audios > 0:
                        ed_audio_start = input_tokens.index(audio_start_token_id, st)
                    else:
                        ed_audio_start = len(input_tokens) + 1
                    min_ed = min(ed_vision_start, ed_audio_start)
                    if min_ed == ed_audio_start:
                        text_len = min_ed - st
                        if text_len != 0:
                            st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        bos_len = 1
                        llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                        llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                        llm_pos_ids_list.append(llm_pos_ids)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        eos_len = 1
                        llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                        st += text_len + bos_len + audio_len + eos_len
                        audio_idx += 1
                        remain_audios -= 1
                    elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id:
                        text_len = min_ed - st
                        if text_len != 0:
                            st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        bos_len = 1
                        llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        grid_t = image_grid_thw[image_idx][0]
                        grid_hs = image_grid_thw[:, 1]
                        grid_ws = image_grid_thw[:, 2]
                        t_index = (torch.arange(grid_t)) * 1 * position_id_per_seconds
                        llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                            st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                        )
                        image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
                        llm_pos_ids_list.append(llm_pos_ids)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        eos_len = 1
                        llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                        st += text_len + bos_len + image_len + eos_len
                        image_idx += 1
                        remain_images -= 1
                    elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id and not use_audio_in_video:
                        text_len = min_ed - st
                        if text_len != 0:
                            st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        bos_len = 1
                        llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        grid_t = video_grid_thw[video_idx][0]
                        grid_hs = video_grid_thw[:, 1]
                        grid_ws = video_grid_thw[:, 2]
                        t_index = (
                            (torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
                        )
                        llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                            st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                        )
                        video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                        llm_pos_ids_list.append(llm_pos_ids)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        eos_len = 1
                        llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                        st += text_len + bos_len + video_len + eos_len
                        video_idx += 1
                        remain_videos -= 1
                    elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start and use_audio_in_video:
                        text_len = min_ed - st
                        if text_len != 0:
                            st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        bos_len = 1
                        llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                        llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                        audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                        grid_t = video_grid_thw[video_idx][0]
                        grid_hs = video_grid_thw[:, 1]
                        grid_ws = video_grid_thw[:, 2]
                        t_index = (
                            (torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
                        )
                        video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                            st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                        )
                        video_data_index, audio_data_index = 0, 0
                        while (
                            video_data_index < video_llm_pos_ids.shape[-1]
                            and audio_data_index < audio_llm_pos_ids.shape[-1]
                        ):
                            if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]:
                                llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1])
                                video_data_index += 1
                            else:
                                llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1])
                                audio_data_index += 1
                        if video_data_index < video_llm_pos_ids.shape[-1]:
                            llm_pos_ids_list.append(
                                video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]
                            )
                        if audio_data_index < audio_llm_pos_ids.shape[-1]:
                            llm_pos_ids_list.append(
                                audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]
                            )
                        video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        eos_len = 1
                        llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                        llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                        st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
                        audio_idx += 1
                        video_idx += 1
                        remain_videos -= 1
                        remain_audios -= 1
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.long().max() + 1 - len(input_ids))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas.long()
        else:
            position_ids = attention_mask.cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
            return position_ids, mrope_position_deltas.long()
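
In the use_audio_in_video branch, audio and video position ids are merged by walking both streams and emitting whichever has the smaller temporal index (row 0), ties going to video. A minimal standalone sketch of that merge, with fabricated position ids; interleave_by_time is a hypothetical helper written for illustration, not part of the PR.

import torch

def interleave_by_time(video_pos: torch.Tensor, audio_pos: torch.Tensor) -> torch.Tensor:
    # video_pos, audio_pos: (3, N) mrope position ids; row 0 is the temporal axis.
    chunks = []
    v, a = 0, 0
    while v < video_pos.shape[-1] and a < audio_pos.shape[-1]:
        if video_pos[0, v] <= audio_pos[0, a]:
            chunks.append(video_pos[:, v:v + 1])
            v += 1
        else:
            chunks.append(audio_pos[:, a:a + 1])
            a += 1
    # Append whichever stream still has positions left.
    chunks.append(video_pos[:, v:])
    chunks.append(audio_pos[:, a:])
    return torch.cat(chunks, dim=1)

# Made-up inputs: video frames at coarse temporal strides, audio at every step.
video = torch.arange(6).repeat(3, 1) * 5   # temporal ids 0, 5, 10, ...
audio = torch.arange(8).repeat(3, 1)       # temporal ids 0..7
merged = interleave_by_time(video, audio)  # shape (3, 14), temporally ordered
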

    @classmethod
    def _omni_get_input_positions_tensor(
        cls,
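
As an aside on the temporal axis used above: t_index scales each temporal grid index by the seconds covered per grid step and by position_id_per_seconds, so frames far apart in time receive proportionally distant temporal position ids. A tiny illustration with assumed values (25 position ids per second and 0.5 s per grid step; both numbers are made up for the example).

import torch

position_id_per_seconds = 25   # assumed value, for illustration only
second_per_grid = 0.5          # assumed seconds covered by one temporal grid step
grid_t = 4                     # temporal grid size of one example video

# Mirrors the t_index construction in the video branches above.
t_index = torch.arange(grid_t) * second_per_grid * position_id_per_seconds
print(t_index)  # -> 0.0, 12.5, 25.0, 37.5
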
@@ -1060,7 +1268,27 @@
        # TODO(fyabc): refactor and share more code with
        # _vl_get_input_positions_tensor.

        model_type = hf_config.model_type
        thinker_config = hf_config.thinker_config

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        if "qwen3_omni" in model_type:
            llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(thinker_config,
                torch.tensor([input_tokens]),
                image_grid_thw,
                video_grid_thw,
                None,
                use_audio_in_video,
                audio_feature_lengths,
                torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
Suggested change (outdated):
-                torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
+                torch.ones(video_grid_thw.shape[0], dtype=torch.long, device=video_grid_thw.device))
Suggested change:
-                torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
+                torch.ones(len(video_grid_thw))

Simplify this
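
For comparison, a side-by-side of the current default and the reviewer's simplification; the video_grid_thw values below are made up. The two expressions produce the same shape but differ in dtype (int64 vs. float32), which looks harmless here since each entry is only used as a scalar factor when building t_index. Wrapping an existing tensor in torch.tensor() also triggers a copy-construct warning, which the simplification avoids.

import torch

video_grid_thw = torch.tensor([[4, 36, 64], [8, 36, 64]])  # example values only

# Current PR code: Python list of ones, then a tensor (int64).
a = torch.tensor([1] * torch.tensor(video_grid_thw).shape[0])

# Reviewer's simplification: one entry of 1.0 per video (float32).
b = torch.ones(len(video_grid_thw))

assert a.shape == b.shape  # both are (num_videos,)
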