diff --git a/requirements/common.txt b/requirements/common.txt
index f537b3aab541..ae095f410631 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -47,3 +47,5 @@ opentelemetry-sdk>=1.26.0 # vllm.tracing
 opentelemetry-api>=1.26.0 # vllm.tracing
 opentelemetry-exporter-otlp>=1.26.0 # vllm.tracing
 opentelemetry-semantic-conventions-ai>=0.4.1 # vllm.tracing
+numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding & Qwen2/2.5-VL/Omni
+numba == 0.61.2; python_version > '3.9'
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index a71d9728f38a..16c18ac79cbe 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -1,9 +1,6 @@
 # Common dependencies
 -r common.txt
 
-numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
-numba == 0.61.2; python_version > '3.9'
-
 # Dependencies for NVIDIA GPUs
 ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required for pipeline parallelism in V1.
 torch==2.7.0
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 8a84f2ff1ed0..1b79ee3cdaec 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -1,9 +1,6 @@
 # Common dependencies
 -r common.txt
 
-numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
-numba == 0.61.2; python_version > '3.9'
-
 # Dependencies for AMD GPUs
 boto3
 botocore
diff --git a/tests/model_executor/test_mrope_positions.py b/tests/model_executor/test_mrope_positions.py
new file mode 100644
index 000000000000..900133cef20e
--- /dev/null
+++ b/tests/model_executor/test_mrope_positions.py
@@ -0,0 +1,680 @@
+# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.mrope_positions import (
+    mrope_assign_next_input_positions, mrope_get_input_positions_and_delta,
+    mrope_get_next_input_positions_tensor)
+
+IMAGE = 101
+VIDEO = 102
+AUDIO = 103
+VISION_START = 201
+VISION_END = 202
+AUDIO_START = 203
+AUDIO_END = 204
+
+SPATIAL_MERGE_SIZE_2 = 2
+
+
+def create_image(t: int, h: int, w: int, merge_size: int):
+    return [
+        VISION_START
+    ] + [IMAGE] * t * (h // merge_size) * (w // merge_size) + [VISION_END]
+
+
+def create_video(t: int, h: int, w: int, merge_size: int):
+    return [
+        VISION_START
+    ] + [VIDEO] * t * (h // merge_size) * (w // merge_size) + [VISION_END]
+
+
+def create_audio(audio_feature_length: int):
+    audio_token_num = (((audio_feature_length - 1) // 2 + 1 - 2) // 2 + 1)
+    return [AUDIO_START] + [AUDIO] * audio_token_num + [AUDIO_END]
+
+
+def create_video_with_audio(num_t: int, num_h: int, num_w: int,
+                            merge_size: int, audio_feature_length: int,
+                            t_ntoken_per_chunk: int, tokens_per_grid_t: float):
+    audio_token_num = (((audio_feature_length - 1) // 2 + 1 - 2) // 2 + 1)
+    added_audio_token_num = 0
+
+    ret = [VISION_START, AUDIO_START]
+    next_chunk_t = t_ntoken_per_chunk
+
+    for t in range(num_t):
+        video_t = int(t * tokens_per_grid_t)
+
+        # audio tokens
+        if video_t >= next_chunk_t:
+            next_chunk_t += t_ntoken_per_chunk
+            if added_audio_token_num < audio_token_num:
+                chunked_audio_token_num = min(
+                    t_ntoken_per_chunk,
+                    audio_token_num - added_audio_token_num)
+                ret.extend([AUDIO] * chunked_audio_token_num)
+                added_audio_token_num += chunked_audio_token_num
+
+        # video tokens
+        ret.extend([VIDEO] * (num_h // merge_size * num_w // merge_size))
+
+    # remaining audio tokens
+    if added_audio_token_num < audio_token_num:
ret.extend([AUDIO] * (audio_token_num - added_audio_token_num)) + + return ret + [AUDIO_END, VISION_END] + + +vl_test_cases = [ + # text and image and video + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "image", + "t": 1, + "h": 4 * SPATIAL_MERGE_SIZE_2, + "w": 6 * SPATIAL_MERGE_SIZE_2, + }, + { + "type": "video", + "t": 4, + "h": 8 * SPATIAL_MERGE_SIZE_2, + "w": 12 * SPATIAL_MERGE_SIZE_2, + "second_per_grid_t": 1.0, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + "tokens_per_second": + 1.0, + }, + # text and image + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "image", + "t": 1, + "h": 2 * SPATIAL_MERGE_SIZE_2, + "w": 2 * SPATIAL_MERGE_SIZE_2, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + "tokens_per_second": + 1.0, + }, + # text and video + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "video", + "t": 4, + "h": 6 * SPATIAL_MERGE_SIZE_2, + "w": 8 * SPATIAL_MERGE_SIZE_2, + "second_per_grid_t": 1.0, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + }, + # text only + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + ], + "spatial_merge_size": SPATIAL_MERGE_SIZE_2, + }, +] + + +def make_vl_hf_config(spatial_merge_size=2, tokens_per_second=2): + hf_config = Mock() + hf_config.image_token_id = IMAGE + hf_config.video_token_id = VIDEO + hf_config.vision_start_token_id = VISION_START + hf_config.vision_end_token_id = VISION_END + + hf_config.vision_config = Mock() + hf_config.vision_config.spatial_merge_size = spatial_merge_size + hf_config.vision_config.tokens_per_second = tokens_per_second + + hf_config.vision_config.rope_scaling = { + "mrope_section": [16, 24, 24], + } + hf_config.thinker_config = None + + return hf_config + + +@pytest.mark.parametrize("test_case", vl_test_cases) +def test_vl_get_input_positions_and_delta_correctness(test_case): + input = test_case["input"] + spatial_merge_size = test_case["spatial_merge_size"] + tokens_per_second = test_case.get("tokens_per_second", 1.0) + + hf_config = make_vl_hf_config(spatial_merge_size, tokens_per_second) + + input_tokens = [] + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + + for item in input: + if item["type"] == "tokens": + input_tokens.extend(item["tokens"]) + elif item["type"] == "image": + input_tokens.extend( + create_image(item["t"], item["h"], item["w"], + spatial_merge_size)) + image_grid_thw.append([item["t"], item["h"], item["w"]]) + elif item["type"] == "video": + input_tokens.extend( + create_video(item["t"], item["h"], item["w"], + spatial_merge_size)) + video_grid_thw.append([item["t"], item["h"], item["w"]]) + second_per_grid_ts.append(item["second_per_grid_t"]) + + input_positions_torch, mrope_position_delta_torch = \ + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_numba=False, + ) + + input_positions_numba, mrope_position_delta_numba = \ + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_numba=True, + ) + assert input_positions_torch.dtype == 
input_positions_numba.dtype + assert input_positions_torch.shape == input_positions_numba.shape + + assert torch.equal(input_positions_torch, input_positions_numba) + assert mrope_position_delta_torch == mrope_position_delta_numba + + +omni_test_cases = [ + # text and image and video (with aduio) + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "image", + "t": 1, + "h": 2 * SPATIAL_MERGE_SIZE_2, + "w": 4 * SPATIAL_MERGE_SIZE_2, + }, + { + "type": "video", + "t": 4, + "h": 2 * SPATIAL_MERGE_SIZE_2, + "w": 4 * SPATIAL_MERGE_SIZE_2, + "second_per_grid_t": 1.0, + "audio_feature_length": 50, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + "use_audio_with_video": + True, + }, + # text and image and video + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "image", + "t": 1, + "h": 4 * SPATIAL_MERGE_SIZE_2, + "w": 6 * SPATIAL_MERGE_SIZE_2, + }, + { + "type": "video", + "t": 4, + "h": 8 * SPATIAL_MERGE_SIZE_2, + "w": 12 * SPATIAL_MERGE_SIZE_2, + "second_per_grid_t": 1.0, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + "tokens_per_second": + 1.0, + }, + # text and image + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "image", + "t": 1, + "h": 2 * SPATIAL_MERGE_SIZE_2, + "w": 2 * SPATIAL_MERGE_SIZE_2, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + "tokens_per_second": + 1.0, + }, + # text and video + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "video", + "t": 4, + "h": 2 * SPATIAL_MERGE_SIZE_2, + "w": 3 * SPATIAL_MERGE_SIZE_2, + "second_per_grid_t": 1.0, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + }, + # text and audio + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + { + "type": "audio", + "audio_feature_length": 144, + }, + { + "type": "tokens", + "tokens": [4, 5, 6, 7, 8], + }, + ], + "spatial_merge_size": + SPATIAL_MERGE_SIZE_2, + }, + # text only + { + "input": [ + { + "type": "tokens", + "tokens": [0, 1, 2, 3], + }, + ], + "spatial_merge_size": SPATIAL_MERGE_SIZE_2, + }, +] + + +def make_omni_hf_config( + spatial_merge_size=2, + tokens_per_second=25, + seconds_per_chunk=2.0, +): + hf_config = Mock() + + hf_config.thinker_config = Mock() + hf_config.thinker_config.image_token_index = IMAGE + hf_config.thinker_config.video_token_index = VIDEO + hf_config.thinker_config.audio_token_index = AUDIO + hf_config.thinker_config.vision_start_token_id = VISION_START + hf_config.thinker_config.vision_end_token_id = VISION_END + hf_config.thinker_config.audio_start_token_id = AUDIO_START + hf_config.thinker_config.audio_end_token_id = AUDIO_END + + hf_config.thinker_config.seconds_per_chunk = seconds_per_chunk + + vision_config = Mock() + hf_config.thinker_config.vision_config = vision_config + vision_config.spatial_merge_size = spatial_merge_size + vision_config.tokens_per_second = tokens_per_second + + hf_config.thinker_config.text_config = Mock() + hf_config.thinker_config.text_config.rope_scaling = { + "mrope_section": [16, 24, 24], + } + + return hf_config + + +@pytest.mark.parametrize("test_case", omni_test_cases) +def test_omni_get_input_positions_and_delta_correctness(test_case): + input = test_case["input"] + spatial_merge_size = 
test_case["spatial_merge_size"] + use_audio_with_video = test_case.get("use_audio_with_video", False) + tokens_per_second = test_case.get("tokens_per_second", 25) + seconds_per_chunk = test_case.get("seconds_per_chunk", 2.0) + + hf_config = make_omni_hf_config( + spatial_merge_size, + tokens_per_second, + seconds_per_chunk, + ) + + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + + input_tokens = [] + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + + for item in input: + if item["type"] == "tokens": + input_tokens.extend(item["tokens"]) + elif item["type"] == "image": + input_tokens.extend( + create_image(item["t"], item["h"], item["w"], + spatial_merge_size)) + image_grid_thw.append([item["t"], item["h"], item["w"]]) + elif item["type"] == "audio": + input_tokens.extend(create_audio(item["audio_feature_length"])) + audio_feature_lengths.append(item["audio_feature_length"]) + elif item["type"] == "video": + if use_audio_with_video: + tokens_per_grid_t = tokens_per_second * item[ + "second_per_grid_t"] + input_tokens.extend( + create_video_with_audio(item["t"], item["h"], item["w"], + spatial_merge_size, + item["audio_feature_length"], + t_ntoken_per_chunk, + tokens_per_grid_t)) + audio_feature_lengths.append(item["audio_feature_length"]) + else: + input_tokens.extend( + create_video(item["t"], item["h"], item["w"], + spatial_merge_size)) + video_grid_thw.append([item["t"], item["h"], item["w"]]) + second_per_grid_ts.append(item["second_per_grid_t"]) + + audio_feature_lengths = torch.tensor(audio_feature_lengths, + dtype=torch.int64) + + input_positions_torch, mrope_position_delta_torch = \ + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_with_video, + use_numba=False, + ) + + input_positions_numba, mrope_position_delta_numba = \ + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_with_video, + use_numba=True, + ) + assert input_positions_torch.dtype == input_positions_numba.dtype + assert input_positions_torch.shape == input_positions_numba.shape + + assert torch.equal(input_positions_torch, input_positions_numba) + assert mrope_position_delta_torch == mrope_position_delta_numba + + +@pytest.mark.parametrize("is_omni, modality", [ + (True, "image_grid_thw"), + (True, "video_grid_thw"), + (True, "audio_feature_lengths"), + (False, "image_grid_thw"), + (False, "video_grid_thw"), +]) +def test_missing_mm_item_error(is_omni, modality): + hf_config = make_omni_hf_config() if is_omni else make_vl_hf_config() + input_tokens = [1, 2, 3, 4] + image_grid_thw: list[list[int]] = [] + video_grid_thw: list[list[int]] = [] + second_per_grid_ts: list[float] = [] + audio_feature_lengths: list[int] = [] + if modality == "image_grid_thw": + input_tokens.extend( + [VISION_START, IMAGE, IMAGE, IMAGE, IMAGE, VISION_END]) + elif modality == "video_grid_thw": + if is_omni: + input_tokens.extend( + [VISION_START, VIDEO, VIDEO, VIDEO, VIDEO, AUDIO, VISION_END]) + else: + input_tokens.extend( + [VISION_START, VIDEO, VIDEO, VIDEO, VIDEO, VISION_END]) + elif modality == "audio_feature_lengths": + 
input_tokens.extend( + [AUDIO_START, AUDIO, AUDIO, AUDIO, AUDIO, AUDIO_END]) + + with pytest.raises(ValueError) as exc_info: + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=torch.tensor(audio_feature_lengths, + dtype=torch.float64), + use_audio_in_video=is_omni, + use_numba=True, + ) + + assert f"{modality}[0] is missing" in str(exc_info.value) + + +@pytest.mark.parametrize("is_omni, modality", [ + (True, "image_grid_thw"), + (True, "video_grid_thw"), + (True, "audio_feature_lengths"), + (False, "image_grid_thw"), + (False, "video_grid_thw"), +]) +def test_tokens_out_of_bound_error(is_omni, modality): + hf_config = make_omni_hf_config() if is_omni else make_vl_hf_config() + input_tokens = [1, 2, 3, 4] + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + if modality == "image_grid_thw": + input_tokens.extend( + [VISION_START, IMAGE, IMAGE, IMAGE, IMAGE, VISION_END]) + image_grid_thw.append([1, 8, 8]) + elif modality == "video_grid_thw": + video_grid_thw.append([1, 8, 8]) + if is_omni: + audio_feature_lengths.append(1000) + input_tokens.extend( + [VISION_START, VIDEO, VIDEO, VIDEO, VIDEO, AUDIO, VISION_END]) + else: + input_tokens.extend( + [VISION_START, VIDEO, VIDEO, VIDEO, VIDEO, VISION_END]) + second_per_grid_ts.append(1.0) + elif modality == "audio_feature_lengths": + audio_feature_lengths.append(1000) + input_tokens.extend( + [AUDIO_START, AUDIO, AUDIO, AUDIO, AUDIO, AUDIO_END]) + + with pytest.raises(ValueError) as exc_info: + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=torch.tensor(audio_feature_lengths, + dtype=torch.float64), + use_audio_in_video=is_omni, + use_numba=True, + ) + assert f"input_tokens out of bounds while processing {modality}[0]" in str( + exc_info.value) + + +@pytest.mark.parametrize("is_omni, modality", [ + (True, "image_grid_thw"), + (True, "video_grid_thw"), + (True, "audio_feature_lengths"), + (False, "image_grid_thw"), + (False, "video_grid_thw"), +]) +def test_unused_mm_items_error(is_omni, modality): + hf_config = make_omni_hf_config() if is_omni else make_vl_hf_config() + input_tokens = [1, 2, 3, 4] + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + if modality == "image_grid_thw": + image_grid_thw.append([1, 4, 4]) + image_grid_thw.append([1, 4, 4]) + input_tokens.extend( + [VISION_START, IMAGE, IMAGE, IMAGE, IMAGE, VISION_END]) + elif modality == "video_grid_thw": + video_grid_thw.append([1, 4, 4]) + video_grid_thw.append([1, 4, 4]) + if is_omni: + audio_feature_lengths.append(16) + audio_feature_lengths.append(16) + input_tokens.extend([ + VISION_START, VIDEO, VIDEO, VIDEO, VIDEO, AUDIO, AUDIO, AUDIO, + AUDIO, VISION_END + ]) + else: + input_tokens.extend( + [VISION_START, VIDEO, VIDEO, VIDEO, VIDEO, VISION_END]) + second_per_grid_ts.append(1.0) + second_per_grid_ts.append(1.0) + elif modality == "audio_feature_lengths": + audio_feature_lengths.append(16) + audio_feature_lengths.append(16) + input_tokens.extend( + [AUDIO_START, AUDIO, AUDIO, AUDIO, AUDIO, AUDIO_END]) + + with pytest.raises(ValueError) as exc_info: + mrope_get_input_positions_and_delta( + input_tokens=input_tokens, + hf_config=hf_config, + 
image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=torch.tensor(audio_feature_lengths, + dtype=torch.float64), + use_audio_in_video=is_omni, + use_numba=True, + ) + + assert f"{modality} has 1 unused item" in str(exc_info.value) + + +@pytest.mark.parametrize( + "mrope_position_delta, context_len, seq_len, expected_output", [ + (0, 0, 1, [[0], [0], [0]]), + (0, 5, 7, [[5, 6], [5, 6], [5, 6]]), + (-10, 160, 163, [[150, 151, 152], [150, 151, 152], [150, 151, 152]]), + (-10, 200, 201, [[190], [190], [190]]), + ]) +def test_mrope_get_next_input_positions_tensor(mrope_position_delta, + context_len, seq_len, + expected_output): + input_positions = mrope_get_next_input_positions_tensor( + mrope_position_delta=mrope_position_delta, + context_len=context_len, + seq_len=seq_len, + ) + + assert torch.equal(input_positions, + torch.tensor(expected_output, dtype=torch.int64)) + + +@pytest.mark.parametrize( + "mrope_position_delta, out_offset, context_len, seq_len, expected_output", + [ + (0, 0, 0, 1, [[0], [0], [0]]), + (0, 1, 5, 7, [[0, 5, 6], [0, 5, 6], [0, 5, 6]]), + (-10, 2, 160, 163, [[0, 0, 150, 151, 152], [0, 0, 150, 151, 152], + [0, 0, 150, 151, 152]]), + (-10, 4, 200, 201, [[0, 0, 0, 0, 190], [0, 0, 0, 0, 190], + [0, 0, 0, 0, 190]]), + ]) +def test_mrope_assign_next_input_positions(mrope_position_delta, out_offset, + context_len, seq_len, + expected_output): + out = torch.zeros((3, out_offset + seq_len - context_len), + dtype=torch.int64) + out_np = out.numpy() + mrope_assign_next_input_positions( + out=out_np, + out_offset=out_offset, + mrope_position_delta=mrope_position_delta, + context_len=context_len, + num_new_tokens=seq_len - context_len, + ) + + assert torch.equal(out, torch.tensor(expected_output, dtype=torch.int64)) diff --git a/vllm/model_executor/layers/mrope_positions.py b/vllm/model_executor/layers/mrope_positions.py new file mode 100644 index 000000000000..da1f8f9b0506 --- /dev/null +++ b/vllm/model_executor/layers/mrope_positions.py @@ -0,0 +1,1016 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional, Tuple, Union + +import numba +import numpy as np +import torch +from transformers import PretrainedConfig + + +def mrope_get_input_positions_and_delta( + input_tokens: Union[list[int], np.ndarray], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + use_numba: bool = True, +) -> tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + is_omni = thinker_uses_mrope(hf_config) + + if use_numba: + input_tokens = np.asarray(input_tokens, dtype=np.int64) + + if image_grid_thw is None or len(image_grid_thw) == 0: + image_grid_thw = np.empty((0, 3), dtype=np.int64) + elif isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.numpy() + else: + image_grid_thw = np.array(image_grid_thw, dtype=np.int64) + + if video_grid_thw is None or len(video_grid_thw) == 0: + video_grid_thw = np.empty((0, 3), dtype=np.int64) + elif isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.numpy() + else: + video_grid_thw = np.array(video_grid_thw, dtype=np.int64) + + if second_per_grid_ts is None: + second_per_grid_ts_np = 
np.empty((0, ), dtype=np.float64) + else: + second_per_grid_ts_np = np.array(second_per_grid_ts, + dtype=np.float64) + + if is_omni: + if audio_feature_lengths is None: + audio_feature_lengths = np.empty((0, ), dtype=np.int64) + else: + audio_feature_lengths = audio_feature_lengths.numpy() + + thinker_config = hf_config.thinker_config + ( + input_positions, + mrope_position_delta, + ) = _omni_get_input_positions_numba( + input_tokens=input_tokens, + image_token_id=int(thinker_config.image_token_index), + video_token_id=int(thinker_config.video_token_index), + audio_token_id=int(thinker_config.audio_token_index), + vision_start_token_id=int( + thinker_config.vision_start_token_id), + vision_end_token_id=int(thinker_config.vision_end_token_id), + audio_start_token_id=int(thinker_config.audio_start_token_id), + audio_end_token_id=int(thinker_config.audio_end_token_id), + spatial_merge_size=int( + thinker_config.vision_config.spatial_merge_size), + tokens_per_second=float( + thinker_config.vision_config.tokens_per_second), + seconds_per_chunk=float(thinker_config.seconds_per_chunk), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts_np, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + ( + input_positions, + mrope_position_delta, + ) = _vl_get_input_positions_numba( + input_tokens=input_tokens, + image_token_id=int(hf_config.image_token_id), + video_token_id=int(hf_config.video_token_id), + spatial_merge_size=int( + hf_config.vision_config.spatial_merge_size), + tokens_per_second=float( + getattr(hf_config.vision_config, "tokens_per_second", + 1.0)), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts_np, + ) + + input_positions = torch.from_numpy(input_positions) + if context_len != 0 or seq_len is not None: + input_positions = input_positions[:, context_len:seq_len] + else: + if isinstance(input_tokens, np.ndarray): + input_tokens = input_tokens.tolist() + + if is_omni: + ( + input_positions, + mrope_position_delta, + ) = _omni_get_input_positions_torch( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + ( + input_positions, + mrope_position_delta, + ) = _vl_get_input_positions_torch( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + return input_positions, mrope_position_delta + + +def mrope_get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, +) -> List[List[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + +def mrope_get_next_input_positions_tensor( + mrope_position_delta: int, + context_len: int, + seq_len: int, +) -> torch.Tensor: + return torch.arange( + mrope_position_delta + context_len, + mrope_position_delta + seq_len, + ).expand(3, -1) + + +@numba.jit(nopython=True) +def mrope_assign_next_input_positions( + out: np.ndarray, + out_offset: int, + mrope_position_delta: int, + context_len: int, + num_new_tokens: int, +): + for dim in range(3): + for idx in range(num_new_tokens): + out[dim, + 
out_offset + idx] = mrope_position_delta + context_len + idx + + +def omni_get_updates_use_audio_in_video( + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[List[int], torch.Tensor], + video_second_per_grid_t: float, +) -> List[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * + tokens_per_second).long() + t_index_split_chunk = _split_list_into_ranges(t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates + + +def _vl_get_input_positions_torch( + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]], + context_len: int = 0, + seq_len: Optional[int] = None, +) -> Tuple[torch.Tensor, int]: + """ + Get mrope input positions and delta value for Qwen2/2.5-VL + + This is the original PyTorch implementation + """ + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", + 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, 
h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * + tokens_per_second).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + if context_len != 0 or seq_len is not None: + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + +@numba.jit(nopython=True) +def _vl_get_input_positions_numba( + input_tokens: np.ndarray, + image_token_id: int, + video_token_id: int, + spatial_merge_size: int, + tokens_per_second: float, + image_grid_thw: np.ndarray, + video_grid_thw: np.ndarray, + second_per_grid_ts: np.ndarray, +) -> tuple[np.ndarray, int]: + """ + Get mrope input positions and delta value for Qwen2/2.5-VL + + This is the optimized numba implementation + """ + + mrope_pos = np.empty((3, len(input_tokens)), dtype=np.int64) + + cur_t = -1 + + cur_image_idx = -1 + cur_video_idx = -1 + + i = 0 + while i < len(input_tokens): + token_id = input_tokens[i] + if token_id == image_token_id: + cur_image_idx += 1 + if cur_image_idx >= len(image_grid_thw): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_IMAGE, cur_image_idx) + + i, cur_t = _emit_image_tokens( + mrope_pos, + i=i, + image_grid_thw=image_grid_thw[cur_image_idx], + start_t=cur_t + 1, + spatial_merge_size=spatial_merge_size, + ) + + if i == ERR_EXCEEDED: + with numba.objmode(): + _raise_tokens_out_of_bound_error(MM_TYPE_IMAGE, + cur_image_idx) + elif token_id == video_token_id: + cur_video_idx += 1 + if cur_video_idx >= len(video_grid_thw): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_VIDEO, cur_video_idx) + + i, cur_t = _emit_video_tokens( + mrope_pos, + i=i, + video_grid_thw=video_grid_thw[cur_video_idx], + start_t=cur_t + 1, + spatial_merge_size=spatial_merge_size, + tokens_per_second=tokens_per_second, + second_per_grid_t=1.0 \ + if cur_video_idx >= len(second_per_grid_ts) \ + else second_per_grid_ts[cur_video_idx], + ) + + if i == ERR_EXCEEDED: + with numba.objmode(): + _raise_tokens_out_of_bound_error(MM_TYPE_VIDEO, + cur_video_idx) + else: + cur_t += 1 + i = 
_emit_1d_token( + mrope_pos, + i=i, + t=cur_t, + ) + + num_unused_images = len(image_grid_thw) - cur_image_idx - 1 + if num_unused_images > 0: + with numba.objmode(): + _raise_unused_mm_items_error(MM_TYPE_IMAGE, num_unused_images) + + num_unused_videos = len(video_grid_thw) - cur_video_idx - 1 + if num_unused_videos > 0: + with numba.objmode(): + _raise_unused_mm_items_error(MM_TYPE_VIDEO, num_unused_videos) + + mrope_position_delta = cur_t + 1 - len(input_tokens) + return mrope_pos, mrope_position_delta + + +def _omni_get_input_positions_torch( + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, +) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_torch. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [ + audio_token_id, video_token_id, image_token_id + ]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and \ + src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and \ + src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], + dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = 
audio_seqlens[audio_idx] + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = _get_llm_pos_ids_for_vision(start_idx, image_idx, + spatial_merge_size, + t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * + tokens_per_second).long() + llm_pos_ids = _get_llm_pos_ids_for_vision(start_idx, video_idx, + spatial_merge_size, + t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * second_per_grid_ts[video_idx] * + tokens_per_second).long() + t_index_split_chunk = _split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: List[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = _get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_chunk, grid_hs, + grid_ws).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) * + [audio_token_id]) + audio_start_idx = start_idx if len( + audio_llm_pos_ids_list + ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + if min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = (torch.arange( + min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len)).expand(3, -1) + + audio_start_idx).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand( + 3, -1) + llm_pos_ids_list[-1].max() + 1).split(1, + dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += 
len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, + dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + +@numba.jit(nopython=True) +def _omni_get_input_positions_numba( + input_tokens: np.ndarray, + image_token_id: int, + video_token_id: int, + audio_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + audio_start_token_id: int, + audio_end_token_id: int, + spatial_merge_size: int, + tokens_per_second: float, + seconds_per_chunk: float, + image_grid_thw: np.ndarray, + video_grid_thw: np.ndarray, + second_per_grid_ts: np.ndarray, + audio_feature_lengths: np.ndarray, + use_audio_in_video: bool, +) -> tuple[np.ndarray, int]: + mrope_pos = np.empty((3, len(input_tokens)), dtype=np.int64) + + cur_t = -1 + + cur_image_idx = -1 + cur_video_idx = -1 + cur_audio_idx = -1 + + i = 0 + while i < len(input_tokens): + token_id = input_tokens[i] + if token_id == image_token_id: + cur_image_idx += 1 + if cur_image_idx >= len(image_grid_thw): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_IMAGE, cur_image_idx) + + i, cur_t = _emit_image_tokens( + mrope_pos, + image_grid_thw=image_grid_thw[cur_image_idx], + i=i, + start_t=cur_t + 1, + spatial_merge_size=spatial_merge_size, + ) + + if i == ERR_EXCEEDED: + with numba.objmode(): + _raise_tokens_out_of_bound_error(MM_TYPE_IMAGE, + cur_image_idx) + elif token_id == video_token_id and use_audio_in_video: + # audio and vision position ids split into chunks and interleaved. + # + # |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + # |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + + cur_video_idx += 1 + if cur_video_idx >= len(video_grid_thw): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_VIDEO, cur_video_idx) + + cur_audio_idx += 1 + if cur_audio_idx >= len(audio_feature_lengths): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_AUDIO, cur_audio_idx) + + i, cur_t = _emit_video_with_audio( + mrope_pos, + i=i, + video_grid_thw=video_grid_thw[cur_video_idx], + start_t=cur_t + 1, + spatial_merge_size=spatial_merge_size, + tokens_per_second=tokens_per_second, + second_per_grid_t=1.0 \ + if cur_video_idx >= len(second_per_grid_ts) \ + else second_per_grid_ts[cur_video_idx], + seconds_per_chunk=seconds_per_chunk, + audio_feature_length=audio_feature_lengths[cur_audio_idx], + ) + + if i == ERR_EXCEEDED: + with numba.objmode(): + _raise_tokens_out_of_bound_error(MM_TYPE_VIDEO, + cur_video_idx) + elif token_id == video_token_id: + cur_video_idx += 1 + if cur_video_idx >= len(video_grid_thw): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_VIDEO, cur_video_idx) + + i, cur_t = _emit_video_tokens( + mrope_pos, + i=i, + video_grid_thw=video_grid_thw[cur_video_idx], + start_t=cur_t + 1, + spatial_merge_size=spatial_merge_size, + tokens_per_second=tokens_per_second, + second_per_grid_t=1.0 \ + if cur_video_idx >= len(second_per_grid_ts) \ + else second_per_grid_ts[cur_video_idx], + ) + + if i == ERR_EXCEEDED: + with numba.objmode(): + _raise_tokens_out_of_bound_error(MM_TYPE_VIDEO, + cur_video_idx) + elif token_id == audio_token_id: + cur_audio_idx += 1 + if cur_audio_idx >= len(audio_feature_lengths): + with numba.objmode(): + _raise_missing_mm_item_error(MM_TYPE_AUDIO, cur_audio_idx) + + i, cur_t = _emit_1d_tokens( + mrope_pos, + i=i, + start_t=cur_t + 1, + 
num_tokens=_calc_audio_token_num( + audio_feature_lengths[cur_audio_idx]), + ) + + if i == ERR_EXCEEDED: + with numba.objmode(): + _raise_tokens_out_of_bound_error(MM_TYPE_AUDIO, + cur_audio_idx) + elif token_id == audio_start_token_id \ + and use_audio_in_video \ + and i > 0 \ + and input_tokens[i - 1] == vision_start_token_id: + # handling the <|audio_bos|> after <|vision_bos|> + i = _emit_1d_token( + mrope_pos, + i=i, + t=cur_t, + ) + elif token_id == vision_end_token_id \ + and use_audio_in_video \ + and i > 0 \ + and input_tokens[i - 1] == audio_end_token_id: + # handling the <|vision_eos|> after <|audio_eos|> + i = _emit_1d_token( + mrope_pos, + i=i, + t=cur_t, + ) + else: + cur_t += 1 + i = _emit_1d_token( + mrope_pos, + i=i, + t=cur_t, + ) + + num_unused_images = len(image_grid_thw) - cur_image_idx - 1 + if num_unused_images > 0: + with numba.objmode(): + _raise_unused_mm_items_error(MM_TYPE_IMAGE, num_unused_images) + + num_unused_videos = len(video_grid_thw) - cur_video_idx - 1 + if num_unused_videos > 0: + with numba.objmode(): + _raise_unused_mm_items_error(MM_TYPE_VIDEO, num_unused_videos) + + num_unused_audios = len(audio_feature_lengths) - cur_audio_idx - 1 + if num_unused_audios > 0: + with numba.objmode(): + _raise_unused_mm_items_error(MM_TYPE_AUDIO, num_unused_audios) + + mrope_position_delta = cur_t + 1 - len(input_tokens) + return mrope_pos, mrope_position_delta + + +def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: List[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, +) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = (torch.arange(llm_grid_h).view(1, -1, + 1).expand(len(t_index), -1, + llm_grid_w).flatten()) + w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( + len(t_index), llm_grid_h, -1).flatten()) + t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( + -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + +def _split_list_into_ranges(lst: torch.Tensor, + interval: int) -> List[List[int]]: + ranges: List[List[int]] = [[] for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + +# following functions are used by: +# - _vl_get_input_positions_numba +# - _omni_get_input_positions_numba + +MM_TYPE_IMAGE = "image_grid_thw" +MM_TYPE_VIDEO = "video_grid_thw" +MM_TYPE_AUDIO = "audio_feature_lengths" +ERR_EXCEEDED = -1 + + +@numba.jit(nopython=True, inline="always") +def _emit_2d_tokens( + mrope_pos: np.ndarray, + i: int, + num_h: int, + num_w: int, + cur_t: int, + start_hw: int, +) -> int: + if i + num_h * num_w > mrope_pos.shape[1]: + return ERR_EXCEEDED + + for h in range(num_h): + for w in range(num_w): + mrope_pos[0, i] = cur_t + mrope_pos[1, i] = start_hw + h + mrope_pos[2, i] = start_hw + w + i += 1 + + return i + + +@numba.jit(nopython=True) +def _emit_image_tokens( + mrope_pos: np.ndarray, + i: int, + image_grid_thw: np.ndarray, + start_t: int, + spatial_merge_size: int, +) -> tuple[int, int]: + num_h = image_grid_thw[1] // spatial_merge_size + num_w = image_grid_thw[2] // spatial_merge_size + for t in range(start_t, start_t + image_grid_thw[0]): + i = _emit_2d_tokens( + mrope_pos, + 
i=i, + num_h=num_h, + num_w=num_w, + cur_t=t, + start_hw=start_t, + ) + if i == ERR_EXCEEDED: + return ERR_EXCEEDED, ERR_EXCEEDED + + cur_t = start_t + max(image_grid_thw[0], num_h, num_w) - 1 + return i, cur_t + + +@numba.jit(nopython=True) +def _emit_video_tokens( + mrope_pos: np.ndarray, + i: int, + video_grid_thw: np.ndarray, + start_t: int, + spatial_merge_size: int, + tokens_per_second: int, + second_per_grid_t: float, +) -> tuple[int, int]: + num_h = video_grid_thw[1] // spatial_merge_size + num_w = video_grid_thw[2] // spatial_merge_size + + tokens_per_grid_t = tokens_per_second * second_per_grid_t + + for t in range(video_grid_thw[0]): + i = _emit_2d_tokens( + mrope_pos, + i=i, + num_h=num_h, + num_w=num_w, + cur_t=start_t + int(t * tokens_per_grid_t), + start_hw=start_t, + ) + if i == ERR_EXCEEDED: + return ERR_EXCEEDED, ERR_EXCEEDED + + cur_t = start_t + max( + int((video_grid_thw[0] - 1) * tokens_per_grid_t), + num_h - 1, + num_w - 1, + ) + return i, cur_t + + +@numba.jit(nopython=True) +def _emit_video_with_audio( + mrope_pos: np.ndarray, + i: int, + video_grid_thw: np.ndarray, + start_t: int, + spatial_merge_size: int, + tokens_per_second: int, + second_per_grid_t: float, + seconds_per_chunk: float, + audio_feature_length: int, +) -> tuple[int, int]: + video_num_h = video_grid_thw[1] // spatial_merge_size + video_num_w = video_grid_thw[2] // spatial_merge_size + + tokens_per_grid_t = tokens_per_second * second_per_grid_t + + audio_token_num = _calc_audio_token_num(audio_feature_length) + added_audio_token_num = 0 + + t_ntoken_per_chunk = int(seconds_per_chunk * tokens_per_second) + next_chunk_t = start_t + t_ntoken_per_chunk + + for t in range(video_grid_thw[0]): + video_t = start_t + int(t * tokens_per_grid_t) + + # audio tokens + if video_t >= next_chunk_t: + next_chunk_t += t_ntoken_per_chunk + if added_audio_token_num < audio_token_num: + chunked_audio_token_num = min( + t_ntoken_per_chunk, + audio_token_num - added_audio_token_num) + i, _ = _emit_1d_tokens( + mrope_pos, + i=i, + start_t=start_t + added_audio_token_num, + num_tokens=chunked_audio_token_num, + ) + if i == ERR_EXCEEDED: + return ERR_EXCEEDED, ERR_EXCEEDED + added_audio_token_num += chunked_audio_token_num + + # video tokens + i = _emit_2d_tokens( + mrope_pos, + i=i, + num_h=video_num_h, + num_w=video_num_w, + cur_t=video_t, + start_hw=start_t, + ) + if i == ERR_EXCEEDED: + return ERR_EXCEEDED, ERR_EXCEEDED + + # remaining audio tokens + if added_audio_token_num < audio_token_num: + i, _ = _emit_1d_tokens( + mrope_pos, + i=i, + start_t=start_t + added_audio_token_num, + num_tokens=audio_token_num - added_audio_token_num, + ) + if i == ERR_EXCEEDED: + return ERR_EXCEEDED, ERR_EXCEEDED + + cur_t = max(mrope_pos[0, i - 1], mrope_pos[1, i - 1], mrope_pos[2, i - 1]) + return i, cur_t + + +@numba.jit(nopython=True, inline="always") +def _emit_1d_token( + mrope_pos: np.ndarray, + i: int, + t: int, +) -> int: + mrope_pos[0, i] = t + mrope_pos[1, i] = t + mrope_pos[2, i] = t + return i + 1 + + +@numba.jit(nopython=True, inline="always") +def _emit_1d_tokens( + mrope_pos: np.ndarray, + i: int, + start_t: int, + num_tokens: int, +) -> tuple[int, int]: + if i + num_tokens > mrope_pos.shape[1]: + return ERR_EXCEEDED, ERR_EXCEEDED + + for t in range(start_t, start_t + num_tokens): + i = _emit_1d_token( + mrope_pos, + i=i, + t=t, + ) + + return i, start_t + num_tokens - 1 + + +@numba.jit(nopython=True, inline="always") +def _calc_audio_token_num(audio_feature_length: int): + return (((audio_feature_length - 1) // 2 + 1 - 
2) // 2 + 1) + + +@numba.jit() +def _raise_missing_mm_item_error(mm_type: str, mm_index: int): + raise ValueError(f"Mismatch between input_tokens and {mm_type}" + f" ({mm_type}[{mm_index}] is missing)." + " Please check your prompt and multi_modal_data.") + + +@numba.jit() +def _raise_tokens_out_of_bound_error(mm_type: str, mm_index: int): + raise ValueError( + f"Mismatch between input_tokens and {mm_type}" + f" (input_tokens out of bounds while processing {mm_type}[{mm_index}])." + " Please check your prompt and multi_modal_data.") + + +@numba.jit() +def _raise_unused_mm_items_error(mm_type: str, unused_num: int): + raise ValueError(f"Mismatch between input_tokens and {mm_type}" + f" ({mm_type} has {unused_num} unused items)." + " Please check your prompt and multi_modal_data.") diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index f8392eb679d2..13f48fdb8170 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,6 @@ import torch import torch.nn as nn -from transformers import PretrainedConfig from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform @@ -1021,470 +1020,6 @@ def forward( key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key - @classmethod - def get_input_positions( - cls, - input_tokens: List[int], - hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], - video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], - second_per_grid_ts: Optional[List[float]], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> Tuple[List[List[int]], int]: - """Get mrope input positions and delta value.""" - - image_grid_thw = [] if image_grid_thw is None else image_grid_thw - video_grid_thw = [] if video_grid_thw is None else video_grid_thw - second_per_grid_ts = [] if second_per_grid_ts is None else \ - second_per_grid_ts - - llm_positions, mrope_position_delta = \ - cls.get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - return llm_positions.tolist(), mrope_position_delta - - @classmethod - def get_input_positions_tensor( - cls, - input_tokens: List[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: List[float], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> Tuple[torch.Tensor, int]: - from vllm.transformers_utils.config import thinker_uses_mrope - if thinker_uses_mrope(hf_config): - return cls._omni_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - else: - return cls._vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - 
image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - ) - - @classmethod - def _vl_get_input_positions_tensor( - cls, - input_tokens: List[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: List[float], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> Tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, - "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * - tokens_per_second).long().flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _omni_get_input_positions_tensor( - cls, - 
input_tokens: List[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[List[List[int]], torch.Tensor], - video_grid_thw: Union[List[List[int]], torch.Tensor], - second_per_grid_ts: Optional[List[float]] = None, - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> Tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - - Example: - - (V_i are vision position ids, A_i are audio position ids) - - |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... - |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... - """ - - # TODO(fyabc): refactor and share more code with - # _vl_get_input_positions_tensor. - - thinker_config = hf_config.thinker_config - audio_token_id = thinker_config.audio_token_index - image_token_id = thinker_config.image_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - vision_start_token_id = thinker_config.vision_start_token_id - vision_end_token_id = thinker_config.vision_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) - - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - - src_item = input_tokens - audio_seqlens = audio_feature_lengths - if not second_per_grid_ts: - second_per_grid_ts = [1] * video_grid_thw.shape[0] - audio_idx = 0 - video_idx = 0 - image_idx = 0 - new_src_item: list[int] = [] - llm_pos_ids_list: list[torch.Tensor] = [] - - idx = 0 - while idx < len(src_item): - new_src_item_len = len(new_src_item) - start_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - if src_item[idx] not in [ - audio_token_id, video_token_id, image_token_id - ]: - if use_audio_in_video and idx > 0: - if src_item[idx] == vision_end_token_id and \ - src_item[idx - 1] == audio_end_token_id: - # processing the <|audio_eos|> before <|vision_eos|> - start_idx -= 1 - elif src_item[idx] == audio_start_token_id and \ - src_item[idx - 1] == vision_start_token_id: - # processing the <|audio_bos|> after <|vision_eos|> - start_idx -= 1 - new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], - dtype=torch.long).expand(3, -1) - llm_pos_ids_list.append(llm_pos_ids) - elif src_item[idx] == audio_token_id: - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) - new_src_item.extend([audio_token_id] * place_num) - llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx - llm_pos_ids_list.append(llm_pos_ids) - audio_idx += 1 - elif src_item[idx] == image_token_id: - grid_t = image_grid_thw[image_idx][0] - grid_hs = image_grid_thw[:, 1] - grid_ws = image_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, image_idx, 
-                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
-                    grid_ws)
-                llm_pos_ids_list.append(llm_pos_ids)
-                vision_seqlen = image_grid_thw[image_idx].prod() // (
-                    spatial_merge_size**2)
-                new_src_item.extend([image_token_id] * vision_seqlen)
-                image_idx += 1
-            elif src_item[idx] == video_token_id and not use_audio_in_video:
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (torch.arange(grid_t) *
-                           second_per_grid_ts[video_idx] *
-                           tokens_per_second).long()
-                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
-                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
-                    grid_ws)
-                llm_pos_ids_list.append(llm_pos_ids)
-                vision_seqlen = video_grid_thw[video_idx].prod() // (
-                    spatial_merge_size**2)
-                new_src_item.extend([video_token_id] * vision_seqlen)
-                video_idx += 1
-            else:
-                # read audio from video
-                assert audio_seqlens is not None
-                audio_seqlen = audio_seqlens[audio_idx]
-                vision_seqlen = video_grid_thw[video_idx].prod() // (
-                    spatial_merge_size**2)
-                grid_t = video_grid_thw[video_idx][0]
-                grid_h = video_grid_thw[video_idx][1]
-                grid_w = video_grid_thw[video_idx][2]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
-                t_index = (torch.arange(grid_t) *
-                           second_per_grid_ts[video_idx] *
-                           tokens_per_second).long()
-                t_index_split_chunk = cls._split_list_into_ranges(
-                    t_index, t_ntoken_per_chunk)
-                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
-                pure_audio_len = place_num - 2
-                added_audio_len = 0
-                audio_llm_pos_ids_list: List[torch.Tensor] = []
-                for t_chunk in t_index_split_chunk:
-                    vision_ntoken_per_chunk = len(
-                        t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
-                    new_src_item.extend([video_token_id] *
-                                        vision_ntoken_per_chunk)
-                    vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
-                        start_idx, video_idx, spatial_merge_size, t_chunk,
-                        grid_hs, grid_ws).split(1, dim=1)
-                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)
-                    new_src_item.extend(
-                        min(t_ntoken_per_chunk, pure_audio_len -
-                            added_audio_len) * [audio_token_id])
-                    audio_start_idx = start_idx if len(
-                        audio_llm_pos_ids_list
-                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
-                    if min(t_ntoken_per_chunk,
-                           pure_audio_len - added_audio_len) > 0:
-                        audio_llm_pos_ids_list = (torch.arange(
-                            min(t_ntoken_per_chunk, pure_audio_len -
-                                added_audio_len)).expand(3, -1) +
-                                                  audio_start_idx).split(1,
-                                                                         dim=1)
-                    else:
-                        audio_llm_pos_ids_list = []
-                    added_audio_len += min(t_ntoken_per_chunk,
-                                           pure_audio_len - added_audio_len)
-                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
-                if added_audio_len < pure_audio_len:
-                    new_src_item.extend(
-                        (pure_audio_len - added_audio_len) * [audio_token_id])
-                    audio_llm_pos_ids_list = (
-                        torch.arange(pure_audio_len - added_audio_len).expand(
-                            3, -1) + llm_pos_ids_list[-1].max() + 1).split(
-                                1, dim=1)
-                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
-                audio_idx += 1
-                video_idx += 1
-            # move to the next token
-            idx += len(new_src_item) - new_src_item_len
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
-        mrope_position_delta = torch.cat(llm_pos_ids_list,
-                                         dim=1).max() + 1 - len(src_item)
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions, mrope_position_delta
-
-    @staticmethod
-    def _get_llm_pos_ids_for_vision(
-        start_idx: int,
-        vision_idx: int,
-        spatial_merge_size: int,
-        t_index: List[int],
-        grid_hs: torch.Tensor,
-        grid_ws: torch.Tensor,
-    ) -> torch.Tensor:
-        llm_pos_ids_list = []
-        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
-        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
-        h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
-            len(t_index), -1, llm_grid_w).flatten())
-        w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
-            len(t_index), llm_grid_h, -1).flatten())
-        t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-            -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
-        _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
-        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
-        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
-        return llm_pos_ids
-
-    @staticmethod
-    def _split_list_into_ranges(lst: torch.Tensor,
-                                interval: int) -> List[List[int]]:
-        ranges: List[List[int]] = [[]
-                                   for _ in range((max(lst) // interval) + 1)]
-        for num in lst:
-            index = num // interval
-            ranges[index].append(num)
-        return ranges
-
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(context_len + mrope_position_delta,
-                      seq_len + mrope_position_delta)) for _ in range(3)
-        ]
-
-    @staticmethod
-    def get_next_input_positions_tensor(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.arange(
-            mrope_position_delta + context_len,
-            mrope_position_delta + seq_len,
-        ).expand(3, -1)
-
-    @classmethod
-    def omni_get_updates_use_audio_in_video(
-        cls,
-        thinker_config: PretrainedConfig,
-        audio_len: int,
-        video_grid_thw: Union[List[int], torch.Tensor],
-        video_second_per_grid_t: float,
-    ) -> List[int]:
-        """Get video prompt updates when `use_audio_in_video` is True.
-
-        In this case, audio and vision update ids will be split into
-        chunks and interleaved (details in `_omni_get_input_positions_tensor`).
-
-        <|video_bos|><|VIDEO|><|video_eos|> =>
-        <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
-        """
-
-        audio_token_id = thinker_config.audio_token_index
-        video_token_id = thinker_config.video_token_index
-        audio_start_token_id = thinker_config.audio_start_token_id
-        audio_end_token_id = thinker_config.audio_end_token_id
-        seconds_per_chunk = thinker_config.seconds_per_chunk
-        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
-        tokens_per_second = getattr(thinker_config.vision_config,
-                                    "tokens_per_second", 25)
-
-        grid_t = video_grid_thw[0]
-        grid_h = video_grid_thw[1]
-        grid_w = video_grid_thw[2]
-        t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
-        t_index = (torch.arange(grid_t) * video_second_per_grid_t *
-                   tokens_per_second).long()
-        t_index_split_chunk = cls._split_list_into_ranges(
-            t_index, t_ntoken_per_chunk)
-
-        updates = [audio_start_token_id]
-        added_audio_len = 0
-        for t_chunk in t_index_split_chunk:
-            vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
-                spatial_merge_size**2)
-            updates.extend([video_token_id] * vision_ntoken_per_chunk)
-
-            audio_chunk_size = min(t_ntoken_per_chunk,
-                                   audio_len - added_audio_len)
-            updates.extend(audio_chunk_size * [audio_token_id])
-            added_audio_len += audio_chunk_size
-        if added_audio_len < audio_len:
-            updates.extend((audio_len - added_audio_len) * [audio_token_id])
-        updates.extend([audio_end_token_id])
-
-        return updates
-
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
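Note on the hunks that follow: the MRotaryEmbedding helpers removed above are relocated to module-level functions in vllm/model_executor/layers/mrope_positions.py, and the decode-time position fill in gpu_model_runner.py switches from the torch helper (torch.arange(delta + context_len, delta + seq_len).expand(3, -1)) to mrope_assign_next_input_positions, which writes into a preallocated numpy buffer so numba can compile the loop. The sketch below is illustrative only and assumes the new helper keeps the removed torch semantics; it is not the code added by this patch, and the name carries a _sketch suffix to make that explicit.

# Illustrative sketch only -- not the implementation added by this patch.
# It mirrors the removed torch helper above as a numba-jitted loop that
# fills a preallocated (3, max_num_tokens) numpy buffer, matching the call
# shape used by mrope_assign_next_input_positions in the hunks below.
import numba
import numpy as np


@numba.njit(cache=True)
def _assign_next_input_positions_sketch(out: np.ndarray, out_offset: int,
                                        mrope_position_delta: int,
                                        context_len: int,
                                        num_new_tokens: int) -> None:
    # During decode all three M-RoPE dims share the same 1D position stream,
    # so every row gets delta + context_len + i for the i-th new token.
    for dim in range(3):
        for i in range(num_new_tokens):
            out[dim, out_offset + i] = (mrope_position_delta + context_len +
                                        i)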
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index bd8c87fd9efc..cdc466fe82df 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -20,7 +20,6 @@
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -264,6 +263,14 @@ def __init__(
             device="cpu",
             pin_memory=self.pin_memory)

+        # NOTE: `mrope_positions_np` shares the same underlying data
+        # with `mrope_positions_cpu`.
+        #
+        # `mrope_positions_np` exists for the numba-accelerated
+        # `mrope_assign_next_input_positions`, which can only operate
+        # on numpy arrays.
+        self.mrope_positions_np = self.mrope_positions_cpu.numpy()
+
         # Only relevant for models using ALiBi (e.g, MPT)
         self.use_alibi = check_use_alibi(model_config)
@@ -382,43 +389,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     lora_request=new_req_data.lora_request,
                 )

-            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            if self.uses_mrope:
-                image_grid_thw = []
-                video_grid_thw = []
-                second_per_grid_ts = []
-                audio_feature_lengths = []
-                use_audio_in_video = False
-                for mm_input in self.requests[req_id].mm_inputs:
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
-                            mm_input["video_grid_thw"].tolist())
-                    if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.extend(
-                            mm_input["second_per_grid_ts"])
-                    if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.extend(
-                            mm_input["audio_feature_lengths"])
-                    if mm_input.get("use_audio_in_video") is True:
-                        use_audio_in_video = True
-
-                hf_config = self.model_config.hf_config
-
-                self.requests[req_id].mrope_positions, \
-                    self.requests[req_id].mrope_position_delta = \
-                    MRotaryEmbedding.get_input_positions_tensor(
-                        self.requests[req_id].prompt_token_ids,
-                        hf_config=hf_config,
-                        image_grid_thw=image_grid_thw,
-                        video_grid_thw=video_grid_thw,
-                        second_per_grid_ts=second_per_grid_ts,
-                        audio_feature_lengths=audio_feature_lengths,
-                        use_audio_in_video=use_audio_in_video,
-                    )
-
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
@@ -771,10 +741,13 @@ def _compute_cascade_attn_prefix_len(
         return common_prefix_len if use_cascade else 0

     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+        from vllm.model_executor.layers.mrope_positions import (
+            mrope_assign_next_input_positions,
+            mrope_get_input_positions_and_delta)
+
         mrope_pos_ptr = 0
         for index, req_id in enumerate(self.input_batch.req_ids):
             req = self.requests[req_id]
-            assert req.mrope_positions is not None

             num_computed_tokens = \
                 self.input_batch.num_computed_tokens_cpu[index]
@@ -782,6 +755,41 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
                 scheduler_output.num_scheduled_tokens[req_id]
             num_prompt_tokens = len(req.prompt_token_ids)

+            if req.mrope_positions is None:
+                image_grid_thw = []
+                video_grid_thw = []
+                second_per_grid_ts = []
+                audio_feature_lengths = []
+                use_audio_in_video = False
+                for mm_input in self.requests[req_id].mm_inputs:
+                    if mm_input.get("image_grid_thw") is not None:
+                        image_grid_thw.extend(
+                            mm_input["image_grid_thw"].tolist())
+                    if mm_input.get("video_grid_thw") is not None:
+                        video_grid_thw.extend(
+                            mm_input["video_grid_thw"].tolist())
+                    if mm_input.get("second_per_grid_ts") is not None:
+                        second_per_grid_ts.extend(
+                            mm_input["second_per_grid_ts"])
+                    if mm_input.get("audio_feature_lengths") is not None:
+                        audio_feature_lengths.extend(
+                            mm_input["audio_feature_lengths"])
+                    if mm_input.get("use_audio_in_video") is True:
+                        use_audio_in_video = True
+
+                req.mrope_positions, \
+                    req.mrope_position_delta = \
+                    mrope_get_input_positions_and_delta(
+                        input_tokens=self.input_batch.token_ids_cpu[
+                            index, :num_prompt_tokens],
+                        hf_config=self.model_config.hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
+
             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
                                       num_prompt_tokens - num_computed_tokens)
@@ -808,17 +816,27 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
             if completion_part_len > 0:
                 # compute completion's mrope_positions on-the-fly
                 dst_start = mrope_pos_ptr
-                dst_end = mrope_pos_ptr + completion_part_len
-                self.mrope_positions_cpu[:, dst_start:dst_end] = \
-                    MRotaryEmbedding.get_next_input_positions_tensor(
-                        req.mrope_position_delta,
-                        context_len=num_computed_tokens +
-                        prompt_part_len,
-                        seq_len=num_computed_tokens +
-                        prompt_part_len +
-                        completion_part_len,
-                    )
+                # kept temporarily for benchmarking purposes
+                #dst_end = mrope_pos_ptr + completion_part_len
+
+                mrope_assign_next_input_positions(
+                    out=self.mrope_positions_np,
+                    out_offset=dst_start,
+                    mrope_position_delta=req.mrope_position_delta,
+                    context_len=num_computed_tokens + prompt_part_len,
+                    num_new_tokens=completion_part_len,
+                )
+
+                # self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                #     mrope_get_next_input_positions_tensor(
+                #         req.mrope_position_delta,
+                #         context_len=num_computed_tokens +
+                #         prompt_part_len,
+                #         seq_len=num_computed_tokens +
+                #         prompt_part_len +
+                #         completion_part_len,
+                #     )

             mrope_pos_ptr += completion_part_len
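The gpu_model_runner.py change above relies on torch.Tensor.numpy() returning a zero-copy view of the CPU tensor, so positions written through mrope_positions_np by the numba helper are the same values later copied from mrope_positions_cpu to the device. The V0 runners below instead call .tolist(), since the relocated mrope_get_input_positions_and_delta returns a tensor where the old helper returned nested lists. A small self-contained check of the shared-storage behavior (illustrative only, not part of the patch; variable names chosen to mirror the model runner):

# Illustrative check (not part of this patch): .numpy() on a CPU tensor
# returns a numpy view over the same storage, so writes through the numpy
# array are visible in the torch tensor without an extra copy.
import numpy as np
import torch

mrope_positions_cpu = torch.zeros(3, 16, dtype=torch.int64)
mrope_positions_np = mrope_positions_cpu.numpy()

# Simulate one decode fill: columns [dst_start, dst_start + num_new_tokens)
# get delta + context_len + i on all three M-RoPE dims.
delta, context_len, dst_start, num_new_tokens = -5, 10, 2, 4
mrope_positions_np[:, dst_start:dst_start + num_new_tokens] = np.arange(
    delta + context_len, delta + context_len + num_new_tokens)

# The write through the numpy view shows up in the torch tensor.
assert mrope_positions_cpu[0, dst_start].item() == delta + context_len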
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 710ca1a13b0c..bfaafb6f1fa4 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -18,7 +18,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.model_executor.layers.mrope_positions import (
+    mrope_get_input_positions_and_delta, mrope_get_next_input_positions)
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_lora, supports_multimodal
@@ -267,7 +268,7 @@ def _compute_decode_input_tokens(self, data: ModelInputData,

             # For MRotaryEmbedding
             if seq_data.mrope_position_delta is not None:
-                next_pos = MRotaryEmbedding.get_next_input_positions(
+                next_pos = mrope_get_next_input_positions(
                     seq_data.mrope_position_delta,
                     context_len,
                     seq_len,
@@ -388,7 +389,7 @@ def _compute_multi_modal_input(self,

         token_ids = seq_data.get_token_ids()
         mrope_positions, mrope_position_delta = \
-            MRotaryEmbedding.get_input_positions(
+            mrope_get_input_positions_and_delta(
                 token_ids,
                 hf_config=hf_config,
                 image_grid_thw=image_grid_thw,
@@ -398,6 +399,7 @@ def _compute_multi_modal_input(self,
                 audio_feature_lengths=audio_feature_lengths,
                 use_audio_in_video=use_audio_in_video,
             )
+        mrope_positions = mrope_positions.tolist()
         seq_data.mrope_position_delta = mrope_position_delta

         for i in range(3):
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index d96021cc688e..72f065c00e05 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -34,7 +34,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.model_executor.layers.mrope_positions import (
+    mrope_get_input_positions_and_delta, mrope_get_next_input_positions)
 from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
                                                 get_sampler)
 from vllm.model_executor.model_loader import get_model
@@ -565,7 +566,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
             inter_data.mrope_input_positions = [None] * inter_data.n_seqs

         inter_data.mrope_input_positions[
-            seq_idx] = MRotaryEmbedding.get_next_input_positions(
+            seq_idx] = mrope_get_next_input_positions(
                 seq_data.mrope_position_delta,
                 context_len,
                 seq_len,
@@ -747,7 +748,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,

         token_ids = seq_data.get_token_ids()
         mrope_input_positions, mrope_position_delta = \
-            MRotaryEmbedding.get_input_positions(
+            mrope_get_input_positions_and_delta(
                 token_ids,
                 hf_config=hf_config,
                 image_grid_thw=image_grid_thw,
@@ -758,6 +759,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                 audio_feature_lengths=audio_feature_lengths,
                 use_audio_in_video=use_audio_in_video,
             )
+        mrope_input_positions = mrope_input_positions.tolist()

         seq_data.mrope_position_delta = mrope_position_delta
         inter_data.mrope_input_positions[