38 changes: 38 additions & 0 deletions examples/whisper_example.py
@@ -0,0 +1,38 @@
import requests
from datasets import Audio

from vllm import LLM, SamplingParams


def main():
    # Whisper expects 16 kHz audio.
    sr = 16000
    audio = Audio(sampling_rate=sr)
    llm = LLM(
        model="openai/whisper-large-v3",
        max_num_seqs=1,
        max_model_len=448,
        gpu_memory_utilization=0.4,
        dtype='bfloat16',
    )

    r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/7021-79759-0004.wav')
    y = audio.decode_example(audio.encode_example(r.content))['array']

    # First pass: detect the language. 50258 is <|startoftranscript|>.
    output_lang = llm.generate({
        "prompt_token_ids": [50258],
        "whisper_data": y,
    }, sampling_params=SamplingParams(max_tokens=1, temperature=0))

    # Second pass: transcribe. 50360 is <|transcribe|> in the large-v3 vocab.
    outputs = llm.generate({
        "prompt_token_ids": [50258, output_lang[0].outputs[0].token_ids[0], 50360],
        "whisper_data": y,
    }, sampling_params=SamplingParams(max_tokens=100, temperature=0))

    # ' without going to any such extreme as this we can easily see on reflection how vast an influence on the'
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()
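
The example above hardcodes Whisper special-token IDs: 50258 is <|startoftranscript|> and 50360 is <|transcribe|> in the whisper-large-v3 vocabulary. As a small, illustrative sketch (the variable names are not part of this change), the IDs can be resolved from the tokenizer instead of being hardcoded:

from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
sot_id = tok.convert_tokens_to_ids("<|startoftranscript|>")  # expected: 50258
transcribe_id = tok.convert_tokens_to_ids("<|transcribe|>")  # expected: 50360 for large-v3
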
40 changes: 40 additions & 0 deletions vllm/config.py
@@ -24,6 +24,7 @@

_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_WHISPER_MAX_NUM_BATCHED_TOKENS = 448


class ModelConfig:
@@ -149,6 +150,7 @@ def __init__(
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
self._verify_whisper_mode()
self._verify_quantization()
self._verify_cuda_graph()

@@ -165,6 +167,11 @@ def _verify_embedding_mode(self) -> None:
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)

def _verify_whisper_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.whisper_mode = any(
ModelRegistry.is_whisper_model(arch) for arch in architectures)

def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
@@ -682,6 +689,7 @@ class SchedulerConfig:
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
whisper_mode: Whether the running model is a Whisper speech-to-text model.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
@@ -699,6 +707,7 @@ def __init__(self,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
whisper_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -711,6 +720,9 @@ def __init__(self,
# For embedding, choose specific value for higher throughput
self.max_num_batched_tokens = max(
max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
elif whisper_mode:
self.max_num_batched_tokens = max(
max_model_len, _WHISPER_MAX_NUM_BATCHED_TOKENS)
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
@@ -725,6 +737,7 @@ def __init__(self,
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.whisper_mode = whisper_mode
self.preemption_mode = preemption_mode

self._verify_args()
@@ -1218,6 +1231,30 @@ def as_cli_args_dict(self) -> Dict[str, Any]:

return result

@dataclass
class WhisperConfig:
whisper_input_type: Optional[str] = 'input_features'
whisper_processor: Optional[str] = 'openai/whisper-large-v3'
whisper_processor_revision: Optional[str] = None
sample_rate: Optional[int] = 16000

def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.

Compatible with what llm entrypoint expects.
"""
result: Dict[str, Any] = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, enum.Enum):
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value

return result


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
@@ -1299,6 +1336,8 @@ def _get_and_verify_max_len(
"max_sequence_length",
"max_seq_length",
"seq_len",
# Whisper
"max_length",
]
# Choose the smallest "max_length" from the possible keys.
max_len_key = None
@@ -1435,6 +1474,7 @@ class EngineConfig:
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
whisper_config: Optional[WhisperConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
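
As a quick illustration of the new WhisperConfig above, a minimal sketch of the flattened args (assuming WhisperConfig is imported from vllm.config; none of the fields are enums or tuples, so the values pass through unchanged):

from vllm.config import WhisperConfig

cfg = WhisperConfig(whisper_input_type='input_features',
                    whisper_processor='openai/whisper-large-v3',
                    whisper_processor_revision=None,
                    sample_rate=16000)
print(cfg.as_cli_args_dict())
# {'whisper_input_type': 'input_features',
#  'whisper_processor': 'openai/whisper-large-v3',
#  'whisper_processor_revision': None,
#  'sample_rate': 16000}
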
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1006,6 +1006,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
whisper_data=seq_group.whisper_data,
)
seq_group_metadata_list.append(seq_group_metadata)

55 changes: 54 additions & 1 deletion vllm/engine/arg_utils.py
@@ -9,7 +9,7 @@
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
VisionLanguageConfig, WhisperConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple

@@ -88,6 +88,12 @@ class EngineArgs:
image_processor_revision: Optional[str] = None
disable_image_processor: bool = False

# Related to Whisper
whisper_input_type: Optional[str] = None
whisper_processor: Optional[str] = None
whisper_processor_revision: Optional[str] = None
sample_rate: Optional[int] = 16000

scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False

@@ -156,6 +162,38 @@ def add_cli_args_for_vlm(

return parser

@staticmethod
def add_cli_args_for_whisper(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--whisper-input-type',
type=nullable_str,
default=EngineArgs.whisper_input_type,
choices=[
'input_features'
],
help=('The audio input type for whisper passed into vLLM.'))
parser.add_argument(
'--whisper-processor',
type=str,
default=EngineArgs.whisper_processor,
help='Name or path of the huggingface whisper processor to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--whisper-processor-revision',
type=str,
default=EngineArgs.whisper_processor_revision,
help='Revision of the huggingface whisper processor version to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--sample-rate',
type=int,
default=EngineArgs.sample_rate,
help='Sample rate expected by the whisper processor.')

return parser

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
@@ -513,6 +551,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

# Related to Vision-language models such as llava
parser = EngineArgs.add_cli_args_for_vlm(parser)
parser = EngineArgs.add_cli_args_for_whisper(parser)

parser.add_argument(
'--scheduler-delay-factor',
@@ -717,6 +756,7 @@ def create_engine_config(self, ) -> EngineConfig:
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
whisper_mode=model_config.whisper_mode,
preemption_mode=self.preemption_mode,
)
lora_config = LoRAConfig(
Expand Down Expand Up @@ -772,6 +812,18 @@ def create_engine_config(self, ) -> EngineConfig:
)
else:
vision_language_config = None

if self.whisper_input_type:
if self.whisper_processor is None:
self.whisper_processor = self.model
whisper_config = WhisperConfig(
self.whisper_input_type,
self.whisper_processor,
self.whisper_processor_revision,
self.sample_rate,
)
else:
whisper_config = None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
Expand All @@ -794,6 +846,7 @@ def create_engine_config(self, ) -> EngineConfig:
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
whisper_config=whisper_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config,
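
To tie the new flags together, a minimal sketch of driving this path programmatically rather than from the CLI; it assumes the remaining EngineArgs defaults are acceptable and that a Whisper checkpoint is passed as the model:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model='openai/whisper-large-v3',
    whisper_input_type='input_features',  # triggers WhisperConfig creation
    max_num_seqs=1,
    max_model_len=448,
)
engine_config = engine_args.create_engine_config()
# whisper_processor falls back to the model name when left unset.
assert engine_config.whisper_config is not None
assert engine_config.whisper_config.whisper_processor == 'openai/whisper-large-v3'
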
17 changes: 16 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -278,9 +278,24 @@ async def process_model_inputs_async(
else:
prompt_token_ids = inputs["prompt_token_ids"]

if 'whisper_data' in inputs:
    if self.whisper_config is None:
        raise ValueError(
            "whisper_data was provided, but the engine was not "
            "initialized with a Whisper model.")
    if self.whisper_processor is None:
        raise ValueError("Whisper processor is not initialized.")
    whisper_data = self.whisper_processor(
        inputs['whisper_data'],
        sampling_rate=self.whisper_config.sample_rate,
        return_tensors='pt',
    )
    whisper_data = whisper_data.to(
        self.model_config.dtype).input_features[0]
else:
    whisper_data = None

return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
multi_modal_data=inputs.get("multi_modal_data"),
whisper_data=whisper_data)

async def add_request_async(
self,
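
The preprocessing above relies on the Hugging Face WhisperProcessor to turn a raw waveform into log-mel input features. A standalone sketch of that step outside the engine (the 128-bin shape below assumes whisper-large-v3):

import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
waveform = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
features = processor(waveform, sampling_rate=16000, return_tensors="pt")
# Whisper pads/trims audio to 30 s, so the expected shape is [1, 128, 3000].
print(features.input_features.shape)
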
36 changes: 32 additions & 4 deletions vllm/engine/llm_engine.py
@@ -9,7 +9,7 @@
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
VisionLanguageConfig, WhisperConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
@@ -36,6 +36,7 @@
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.whisper_processor import (
    cached_get_whisper_processor)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
@@ -154,6 +155,7 @@ def __init__(
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
whisper_config: Optional[WhisperConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
@@ -206,6 +208,7 @@ def __init__(
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.whisper_config = whisper_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
@@ -223,10 +226,17 @@ def __init__(
self.tokenizer = None
self.detokenizer = None

if self.whisper_config is not None:
self.whisper_processor = cached_get_whisper_processor(
self.whisper_config.whisper_processor
)
else:
self.whisper_processor = None

self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
@@ -235,6 +245,7 @@ def __init__(
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
whisper_config=whisper_config,
speculative_config=speculative_config,
load_config=load_config,
)
@@ -501,6 +512,36 @@ def process_model_inputs(
if isinstance(inputs, str):
inputs = {"prompt": inputs}

if 'whisper_data' in inputs:
    if self.whisper_config is None:
        raise ValueError(
            "whisper_data was provided, but the engine was not "
            "initialized with a Whisper model.")
    if self.whisper_processor is None:
        raise ValueError("Whisper processor is not initialized.")
    whisper_data = self.whisper_processor(
        inputs['whisper_data'],
        sampling_rate=self.whisper_config.sample_rate,
        return_tensors='pt',
    )
    whisper_data = whisper_data.to(
        self.model_config.dtype).input_features[0]
else:
    whisper_data = None

if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")

prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
lora_request=lora_request,
add_special_tokens=self.whisper_processor is None)
else:
prompt_token_ids = inputs["prompt_token_ids"]


return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
multi_modal_data=inputs.get("multi_modal_data"),
whisper_data=whisper_data)

def add_request(
self,
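
cached_get_whisper_processor comes from vllm.transformers_utils.whisper_processor, which is not shown in this diff. A plausible sketch of that helper, assuming it mirrors the existing cached_get_image_processor pattern (the signature here is an assumption, not the actual implementation):

from functools import lru_cache
from typing import Optional

from transformers import WhisperProcessor


@lru_cache
def cached_get_whisper_processor(
        processor_name: str,
        revision: Optional[str] = None) -> WhisperProcessor:
    # Cache the processor so repeated requests reuse one instance.
    return WhisperProcessor.from_pretrained(processor_name, revision=revision)
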
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
@@ -575,4 +575,4 @@ def _run_engine(
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
return sorted(outputs, key=lambda x: int(x.request_id))