From 1da44c51fc33e7a0a14693d001da27a4a811068e Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 5 Feb 2025 13:22:42 -0800 Subject: [PATCH 01/27] Add support for Phi-4-multimodal-instruct Signed-off-by: Congcong Chen --- examples/offline_inference_phi3o.py | 531 +++ vllm/config.py | 4 +- vllm/entrypoints/chat_utils.py | 4 + vllm/lora/models.py | 3 + vllm/model_executor/models/phi3s_utils.py | 3498 +++++++++++++++++ vllm/model_executor/models/phi4o.py | 1761 +++++++++ vllm/model_executor/models/registry.py | 1 + .../models/vision_siglip_navit.py | 1722 ++++++++ 8 files changed, 7522 insertions(+), 2 deletions(-) create mode 100644 examples/offline_inference_phi3o.py create mode 100644 vllm/model_executor/models/phi3s_utils.py create mode 100644 vllm/model_executor/models/phi4o.py create mode 100644 vllm/model_executor/models/vision_siglip_navit.py diff --git a/examples/offline_inference_phi3o.py b/examples/offline_inference_phi3o.py new file mode 100644 index 000000000000..668834168049 --- /dev/null +++ b/examples/offline_inference_phi3o.py @@ -0,0 +1,531 @@ +# Implements a simple offline inference script for the Phi 3.5 Speech model. +# Code implemented by Jacob Platin (jacobplatin@microsoft.com) + +import soundfile + +from vllm import LLM, SamplingParams +from vllm.utils import FlexibleArgumentParser +from vllm.lora.request import LoRARequest +from vllm.multimodal.utils import fetch_image + +""" +Model file: vllm/model_executor/models/phi3o.py + +Step 1: Download the following model weights to some location. +* Base Model Weight: https://github.com/microsoft/MoE/tree/weijian/phio-hf +* Vision Lora Model Weight: https://llmpretrainingwus3.blob.core.windows.net/users/weijianxu/phio-004-sft-vision-lora-only-from-hf-unified-model +* Speech Lora Model Weight: https://llmpretrainingwus3.blob.core.windows.net/users/weijianxu/phio-004-sft-speech-lora-only-from-hf-unified-model + +Step 2: Run the test +* Run the followling command with the commandline parameters you want to pass into the script. + python examples/offline_inference_phi3s.py +* You should expect to see the output like: + Prompt: '<|user|>\n<|image_1|>\n<|audio_1|>\ntry your best to answer the question<|end|>\n<|assistant|>\n' + Generated text: 'Stop' +""" +def main_pure_text(args: dict) -> None: + """ + Main function for the offline inference script. + """ + llm = LLM( + model=args.model_path, + trust_remote_code=True, + enforce_eager=True) + user_prompt = '<|user|>\n' + assistant_prompt = '<|assistant|>\n' + prompt_suffix = '<|end|>\n' + prompt = f'{user_prompt}what is the answer for 1+1? Explain it.{prompt_suffix}{assistant_prompt}' + print(f'>>> Prompt\n{prompt}') + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = { + "prompt": prompt + } + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate(generate_args, sampling_params=sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + + +def main_with_lora_speech(args: dict, activate_lora_request=True) -> None: + """ + Main function for the offline inference script. + """ + wav_paths = [args.wav_path] + llm = LLM( + model=args.model_path, + trust_remote_code=True, + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + limit_mm_per_prompt={"audio": len(wav_paths)}, + max_loras=5) + + # assert len(wav_paths) == 1, "Only support single audio files for now!" + + prompt = "Generate a comprehensive text transcription of the spoken content." + placeholders = "\n".join( + f"<|audio_{i}|>" for i in range(1, len(wav_paths) + 1) + ) + prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" + + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = { + "prompt": prompt, + "multi_modal_data": { + "audio": [soundfile.read(wav_path) for wav_path in wav_paths] + } + } + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate(generate_args, sampling_params=sampling_params, lora_request= [LoRARequest("speech_adapter", 3, args.speech_lora_path)] if activate_lora_request else None) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + +def main_with_lora_speech_batch(args: dict, activate_lora_request=True) -> None: + """ + Main function for the offline inference script. + """ + wav_paths = [args.wav_path, args.wav_path] + + llm = LLM( + model=args.model_path, + trust_remote_code=True, + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + limit_mm_per_prompt={"audio": len(wav_paths)}, + max_loras=5) + + + # assert len(wav_paths) == 1, "Only support single audio files for now!" + + prompt = "Based on the attached audio, generate a comprehensive text transcription of the spoken content." + placeholders = "\n".join( + f"<|audio_{i}|>" for i in range(1, len(wav_paths) + 1) + ) + prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" + + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = [ + { + "prompt": prompt, + "multi_modal_data": { + "audio": [soundfile.read(wav_path) for wav_path in wav_paths] + } + }, + { + "prompt": prompt, + "multi_modal_data": { + "audio": [soundfile.read(wav_path) for wav_path in wav_paths] + } + }, + ] + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate( + generate_args, + sampling_params=sampling_params, + lora_request= LoRARequest("speech_adapter", 3, args.speech_lora_path) + if activate_lora_request else None) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + +def main_with_lora_vision(args: dict, activate_lora_request=True) -> None: + """ + Main function for the offline inference script. + """ + image_urls=[args.image_url] + llm = LLM( + model=args.model_path, + trust_remote_code=True, + + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + max_loras=5, + # max_model_len=4096, + # max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + # prompt = "what's the traffic sign in the image" + prompt = "What is shown in this image?" + + placeholders = "\n".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" + + image_data=[fetch_image(url) for url in image_urls] + + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = { + "prompt": prompt, + "multi_modal_data": { + "image": image_data, + }, + } + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate( + generate_args, + sampling_params=sampling_params, + lora_request= [LoRARequest("vision_adapter", 3, args.vision_lora_path)] if activate_lora_request else None + ) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + + +def main_with_lora_vision_batch(args: dict, activate_lora_request=True) -> None: + """ + Main function for the offline inference script. + """ + image_urls=[args.image_url, "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg"] + llm = LLM( + model=args.model_path, + trust_remote_code=True, + + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + max_loras=5, + + # max_model_len=4096, + # max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + # prompt = "what's the traffic sign in the image" + prompt = "What is shown in this image?" + + placeholders = "\n".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" + + # image_data=[fetch_image(url) for url in image_urls] + + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = [ + { + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in ["https://www.ilankelman.org/stopsigns/australia.jpg", "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg"]], + }, + }, + { + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in image_urls], + }, + }, + ] + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate( + generate_args, + sampling_params=sampling_params, + lora_request= LoRARequest("vision_adapter", 3, args.vision_lora_path) + if activate_lora_request else None + ) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + + +def main_with_lora_vision_speech(args: dict, activate_lora_request=True) -> None: + """ + Main function for the offline inference script. + """ + image_urls=[args.image_url] + llm = LLM( + model=args.model_path, + trust_remote_code=True, + + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + max_loras=5, + + # max_model_len=4096, + # max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + prompt = "" + + placeholders = "\n".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}<|end|>\n<|assistant|>\n" + + image_data=[fetch_image(url) for url in image_urls] + + wav_paths = ["/scratch/turing_westus3_prm_data/users/congcongchen/MoE_2/hf-models/phio/examples/what_is_the_traffic_sign_in_the_image.wav"] + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = { + "prompt": prompt, + "multi_modal_data": { + "image": image_data, + "audio": [soundfile.read(wav_path) for wav_path in wav_paths], + }, + } + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate( + generate_args, + sampling_params=sampling_params, + lora_request= [LoRARequest("vision_adapter", 3, args.vision_lora_path)] if activate_lora_request else None + ) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + +def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) -> None: + """ + Main function for the offline inference script. + """ + image_urls=[args.image_url, "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg"] + wav_paths = [args.wav_path] + llm = LLM( + model=args.model_path, + trust_remote_code=True, + + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + max_loras=5, + + # max_model_len=40960, + # max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls), "audio": len(wav_paths)}, + ) + + prompt = "try your best to answer the question" + + placeholders = "\n".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}<|end|>\n<|assistant|>\n" + + # image_data=[fetch_image(url) for url in image_urls] + + + # NOTE: soundfile.read will return the audio feature and the sampling rate + generate_args = [ + { + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in image_urls], + "audio": [soundfile.read(wav_path) for wav_path in wav_paths], + }, + }, + { + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in ["https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg", "https://alinasayre.com/wp-content/uploads/2012/01/c3a7c-dsc01668.jpg"]], + "audio": [soundfile.read(wav_path) for wav_path in wav_paths], + }, + }, + ] + # NOTE: you should use the following settings to ensure parity in HF + # generate_ids = model.generate( + # **inputs, + # top_p=1, + # max_new_tokens=1200, + # temperature=0, + # use_cache=False, + # min_p=0, + # top_k=-1, + # ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=1200, + ) + + outputs = llm.generate( + generate_args, + sampling_params=sampling_params, + lora_request= LoRARequest("vision_adapter", 3, args.vision_lora_path) + if activate_lora_request else None + ) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}\n\n") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Demo on using vLLM for offline inference with " + "vision language models that support multi-image input" + ) + parser.add_argument( + "--model-path", + "-p", + type=str, + default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/merged/speech-merged-lora-and-base_model-from-hf-unified-model", + help="Path to the (HuggingFace) model checkpoint.", + ) + + parser.add_argument( + "--vision-lora-path", + "-v", + type=str, + default="/modelblob/users/weijianxu/phi-o/vision-speech-merged-pretraining/official_run/Phio-SFT-long-001-DPO-002/merged-vision-mframerc1.2abl1-2.1k-speech-shadow50k-postsr002-posttrain-vision12k-trial2/vllm_lora/vision-lora-only-from-hf-unified-model/", + help="Path to the (HuggingFace) vision lora model checkpoint.", + ) + + parser.add_argument( + "--speech-lora-path", + "-s", + type=str, + default="/modelblob/users/weijianxu/phi-o/vision-speech-merged-pretraining/official_run/Phio-SFT-long-001-DPO-002/merged-vision-mframerc1.2abl1-2.1k-speech-shadow50k-postsr002-posttrain-vision12k-trial2/vllm_lora/speech-lora-only-from-hf-unified-model/", + help="Path to the (HuggingFace) vision lora model checkpoint.", + ) + + parser.add_argument( + "--wav-path", + "-w", + type=str, + default= + "30s_test_6.wav", + help="Path to the audio file.", + ) + + parser.add_argument( + "--image-url", + "-i", + type=str, + default="https://www.ilankelman.org/stopsigns/australia.jpg", + ) + + parser.add_argument( + "--test-type", + "-t", + type=str, + default="speech_language_with_lora", + ) + + args = parser.parse_args() + ##### Language Only ##### + test_type = args.test_type + if test_type == "language_only": + main_pure_text(args) + ##### Speech + Language ##### + elif test_type == "speech_language_with_lora": + main_with_lora_speech(args) + elif test_type == "speech_language_with_lora_batch": + main_with_lora_speech_batch(args) + elif test_type == "speech_language_without_lora": + main_with_lora_speech(args, activate_lora_request=False) + ##### Vision + Language ##### + elif test_type == "vision_language_with_lora": + main_with_lora_vision(args) + elif test_type == "vision_language_with_lora_batch": + main_with_lora_vision_batch(args) + elif test_type == "vision_language_without_lora": + main_with_lora_vision(args, activate_lora_request=False) + ##### Vision + Speech + Language ##### + elif test_type == "vision_speech_language_with_lora": + main_with_lora_vision_speech(args) + elif test_type == "vision_speech_language_with_lora_batch": + main_with_lora_vision_speech_batch(args) + elif test_type == "vision_speech_language_without_lora": + main_with_lora_vision_speech(args, activate_lora_request=False) diff --git a/vllm/config.py b/vllm/config.py index f87d2d6e82cf..8714b50c6c1a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2284,9 +2284,9 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - # Setting the maximum rank to 256 should be able to satisfy the vast + # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. - possible_max_ranks = (8, 16, 32, 64, 128, 256) + possible_max_ranks = (8, 16, 32, 64, 128, 256, 512) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError( diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index b05842dd27d3..6a1706d828b2 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -395,6 +395,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return f"<|image_{current_count}|>" + if model_type == "phio": + return "<|endoftext10|>" # 200010 (see vocab.json in hf model) if model_type in ("minicpmo", "minicpmv"): return "(./)" if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", @@ -424,6 +426,8 @@ def _placeholder_str(self, modality: ModalityStr, elif modality == "audio": if model_type == "ultravox": return "<|audio|>" + if model_type == "phio": + return "<|endoftext11|>" # 200011 (see vocab.json in hf model) if model_type == "qwen2_audio": return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e1294884ac2a..26fe835d02e8 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -389,6 +389,7 @@ def activate_adapter( for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: + logger.debug(f"setting {module_name}") module_lora.optimize() # Bias is not explicitly enabled with the flag enable_lora_bias. bias = module_lora.bias @@ -404,6 +405,7 @@ def activate_adapter( module_lora.embeddings_tensor, module_lora.bias) else: + logger.debug(f"resetting {module_name}") module.reset_lora(index) return True @@ -505,6 +507,7 @@ def _create_lora_modules(self): # aims to prevent this error if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): + logger.debug("-------- Skipping %s --------", module_name) continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) diff --git a/vllm/model_executor/models/phi3s_utils.py b/vllm/model_executor/models/phi3s_utils.py new file mode 100644 index 000000000000..7387063a8214 --- /dev/null +++ b/vllm/model_executor/models/phi3s_utils.py @@ -0,0 +1,3498 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) +# but implemented by the Phi-Speech team +#!/usr/bin/env python3 +"""ConformerEncoder Module""" +import abc +import backoff +from functools import partial +import math +from typing import Optional, Tuple, Union, List, Literal, Union, Dict, Callable + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + checkpoint_wrapper, + offload_wrapper, + CheckpointImpl, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel, +) +from torch.utils.checkpoint import checkpoint +from transformers import PretrainedConfig + +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + + +class Block(nn.Module): + """Block abstract module""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor( + chunk_start_idx + ).long() # first idx of each chunk, such as [0,18,36,48]. + start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze( + -1 + ) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[ + :, 1 + ] # idx size: [x_len] + boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = ( + torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Swish(nn.Module): + """Implement Swish activation module. + From https://arxiv.org/pdf/2005.03191.pdf + + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function + + Args: + x: torch.Tensor + Input. + """ + return x * self.act_fn(x) + + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. + kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. + + """ + + def __init__( + self, + input_dim, + output_dim, + kernel_size, + glu_type="sigmoid", + bias_in_glu=True, + causal=False, + ): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1), + ) + else: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * ( + x[:, self.output_dim : self.output_dim * 2, :] + ) + else: + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + ) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. + + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d( + input_dim * depthwise_multiplier, + depthwise_seperable_out_channel, + 1, + 1, + 0, + ) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation. + default False + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + chunk_size: int, optional + chunk size for cnn. default 18 + activation: str, optional + activation function used in ConvModule, + default: "relu". + glu_type: str, optional + activation function used for the glu, + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + export: bool, optional, + if set to True, padding is equal to 0. This is for inference, + or onnx export. Typically this is set by the export program or + the decoder program, and it isn't present in your config file. + default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + if export: # Inference only. + padding = 0 # A cache is concatenated to the left. No padding in the kernel. + else: + # Training only. Padding will be added symmetrically on both sides. + # After convolution, clip off kernel_size-1 points on the right. + padding = kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear( + input_dim * depthwise_multiplier, input_dim + ) + + def _add_ext_pw_layer(self): + """ + This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer. + """ + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( + nn.Identity() + ) # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. + + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear( + self.input_dim, + self.ext_pw_out_channel, + self.glu_type, + self.bias_in_glu, + ) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. + """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, : -(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, : -(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + + +class GLULinear(nn.Module): + """Linear + GLU module + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. + """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. + d_inner: int + output size. + dropout_rate: float, + dropout rate. + activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + + +#### positional encoding starts here +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section 2.1 of + the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position + of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. Original T5 bias converted rel position distances into + logarithmically increasing buckets. This is supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't need length + generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default implementation treats + L->R and R->L the same. Asymmetric was found to yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the size of the learnable + bias parameter. Bucketing is not yet supported, so this defaults to -1 which means + no bucketing is used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With num_buckets=-1, this directly + controls the max size of the bias parameter. When num_buckets > 0 is supported, this + will control the maximum distance for logarithmic bucketing after which all positions + are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias + params to distinguish L->R from R->L. This was found to be better for the encoder. + """ + + def __init__( + self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False + ): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError( + "T5 attention bias with bucketed positions is not yet tested" + ) + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange( + maxpos, device=x.device, dtype=torch.long + )[:, None] + memory_position = torch.arange( + maxpos, device=x.device, dtype=torch.long + )[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill( + relative_position < -self.max_distance, -self.max_distance + ) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze( + 0 + ) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference + # this also needs to be extended to support asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long + ) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. + dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. + + Args: + x: torch.Tensor + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +#### forward embedding layers starts here + + +@backoff.on_exception(backoff.expo, Exception, max_tries=10) +def np_loadtxt_with_retry(filepath): + """np.loadtxt with retry + + Args: + filepath: str + file path to the numpy array. + """ + result = np.loadtxt(filepath, dtype="f") + return result + + +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will substract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.register_buffer("global_mean", torch.zeros(input_size)) + self.register_buffer("global_invstd", torch.ones(input_size)) + self.global_mean: Optional[Tensor] + self.global_invstd: Optional[Tensor] + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False): + """Load feature mean and variance used for normalization. + + Args: + mean_file: str + path to the feature mean statistics file. + invstd_file: str + path to the features inverted standard deviation + statistics file. + cuside_features: bool + Boolean that indicates CUSIDE is being used. + The statistics of CUSIDE features are copied + from the normal features + """ + self.global_mean.data = torch.from_numpy( + np_loadtxt_with_retry(mean_file) + ) + self.global_invstd.data = torch.from_numpy( + np_loadtxt_with_retry(invstd_file) + ) + + if cuside_features: + self.global_mean.data = torch.cat( + (self.global_mean.data, self.global_mean.data), 0 + ) + self.global_invstd.data = torch.cat( + (self.global_invstd.data, self.global_invstd.data), 0 + ) + + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would have limited access to locations on its right or left + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError( + "No striding allowed for non-symmetric convolutions!" + ) + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, : -self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError( + "Argument padding should be set to None for CausalConv2D." + ) + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + + padding = 0 + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + def forward( + self, + x, + ): + if self.training: + x = F.pad( + x, + pad=( + self._left_padding, + self._right_padding, + self._left_padding, + self._right_padding, + ), + ) + else: + x = F.pad( + x, + pad=(self._left_padding, self._right_padding, 0, 0), + ) + x = super().forward(x) + return x + + +class NemoConvSubsampling(torch.nn.Module): + """Convlutional subsampling module, taken from NeMo ASR + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + + Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for + Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) + + + Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, + and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce + FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. + + `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions + after the first layer, whereas the former does not. + + Args: + subsampling_factor (int): Time reduction factor + feat_in (int): size of the input features + feat_out (int): size of the output features + subsampling (str): The subsampling technique, choose from + {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) + 1 (auto) or a power of 2. Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access + to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ( + "dw_striding", + "striding", + "striding_conv1d", + ) + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" + ) + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ) + ) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == 1 + else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 2 + else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + self.out = torch.nn.Linear( + conv_channels * int(out_length), feat_out + ) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """ + Forward method for NeMo subsampling. + + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) + """ + # Unsqueeze Channel Axis + if self.conv2d_subsampling: + x = x.unsqueeze(1) + # Transpose to Channel First mode + else: + x = x.transpose(1, 2) + + # split inputs if chunking_factor is set + if ( + self.subsampling_conv_chunking_factor != -1 + and self.conv2d_subsampling + ): + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only if needed + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = ( + 2**31 / self._conv_channels * self._stride * self._stride + ) + if torch.numel(x) > x_ceil: + need_to_split = True + else: + need_to_split = False + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + if mask is None: + return x, None + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + return x, pad_mask.unsqueeze(1) + + def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2) ** -0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_( + self.conv[idx].weight, -dw_max, dw_max + ) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_( + self.conv[idx + 1].weight, -pw_max, pw_max + ) + torch.nn.init.uniform_( + self.conv[idx + 1].bias, -pw_max, pw_max + ) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487 + fc_scale = ( + self._feat_out * self._feat_in / self._sampling_num + ) ** -0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run conv and concat results""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return ( + torch.cat( + [ + self.conv(chunk) + for chunk in torch.split(x, new_batch_size, 0) + ] + ), + True, + ) + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat results""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv( + self.conv[i * 3 + 2], new_c, x + ) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat( + [ + self.conv[i * 3 + 3](chunk) + for chunk in torch.split(x, new_t, 2) + ], + 2, + ) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor( + self, subsampling_conv_chunking_factor: int + ): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" + ) + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + +def calc_length( + lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1 +): + """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = ( + torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one + ) + if ceil_mode: + lengths = torch.ceil(lengths) + else: + lengths = torch.floor(lengths) + return lengths.to(dtype=torch.int) + + +#### multihead attention starts here +class AttModule(nn.Module): + """Attention abstraction module""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def set_export(self, mode=True): + """set the export mode""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward + + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. + """ + return x, memory, pos_emb, att_mask + + +class AttBlock(Block, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """memory dimensions""" + return (1, self.input_size) + + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. NOTE: this will NOT + be used in ONNX decoding due to a lack of support. In that case, we use the original + attention implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = ( + use_pt_scaled_dot_product_attention + ) + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. Note that these functions below are + # NOOPs and have very little impact on performance unless quantization is + # enabled. + self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + """ + n_batch = query.size(0) + + q = self.linear_q(query).view( + n_batch, -1, self.h, self.d_k + ) # (b, t, d) + k = self.linear_k(key).view( + n_batch, -1, self.h_k, self.d_k + ) # (b, t, d) + v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if ( + self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting() + ): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=True, enable_mem_efficient=True + ): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = ( + q.contiguous() + .view(n_batch * self.h, -1, self.d_k) + .transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul( + reshape_q, pos_k.transpose(-2, -1) + ) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view( + n_batch, self.h, pos_k.size(0), pos_k.size(1) + ) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = ( + p_attn.contiguous() + .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) + .transpose(0, 1) + ) # (t1, bh, t2) + + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) + x = x + attn_v + x = ( + x.transpose(1, 2) + .contiguous() + .view(n_batch, -1, self.h_k * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +def validate_checkpointing_config(activation_checkpointing): + """validate activation checkpointing configuration""" + if isinstance(activation_checkpointing, str): + assert activation_checkpointing in ( + "", + "checkpoint", + "offload", + ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')." + elif isinstance(activation_checkpointing, dict): + assert activation_checkpointing.get("module", "transformer") in ( + "transformer", + "attention", + ), "module in activation_checkpointing has to be in ('transformer', 'attention')." + else: + raise ValueError("activation_checkpointing has to be a str or dict.") + + +def embedding_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], +) -> Callable: + """return encoder embedding activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + enabled = activation_checkpointing.get("embed", False) + if enabled: + offloading = activation_checkpointing.get("offload", False) + if offloading: + return offload_wrapper + impl = ( + CheckpointImpl.REENTRANT + if activation_checkpointing.get("reentrant", False) + else CheckpointImpl.NO_REENTRANT + ) + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + raise ValueError("Invalid activation_checkpointing config") + + +def encoder_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], + layer_cls: type, + idx: int = 0, +) -> Callable: + """return encoder activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + if target_layer_cls.lower() == "transformer": + target_layer_cls = ( + "EncoderLayer", + "ConformerEncoderLayer", + ) + elif target_layer_cls.lower() == "attention": + target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") + checkpointing_interval = activation_checkpointing.get("interval", 1) + offloading = activation_checkpointing.get("offload", False) + impl = ( + CheckpointImpl.REENTRANT + if activation_checkpointing.get("reentrant", True) + else CheckpointImpl.NO_REENTRANT + ) + + if ( + idx % checkpointing_interval == 0 + and layer_cls.__name__ in target_layer_cls + ): + if offloading: + return offload_wrapper + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + + raise ValueError("Invalid activation_checkpointing config") + + +def attn_checkpointing( + activation_checkpointing: Union[str, Dict], i +) -> Union[str, Dict]: + """return activation checkpointing config for attention layer""" + if isinstance(activation_checkpointing, str): + return "" + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + checkpointing_interval = activation_checkpointing.get("interval", 1) + if target_layer_cls == "attention" and i % checkpointing_interval == 0: + return activation_checkpointing + return "" + + raise ValueError("Invalid activation_checkpointing config") + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential""" + + @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + + +def repeat(repeat_num, module_gen_fn): + """repeat module N times + + :param int repeat_num: repeat time + :param function module_gen_fn: function to generate module + :return: repeated modules + :rtype: MultiSequential + """ + return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_innner_dim: int, otional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_innner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_innner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = encoder_checkpoint_wrapper( + activation_checkpointing, + MultiHeadedAttention, + )( + MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. + Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal[ + "feat", "feat_time", "none", True + ] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding( + attention_dim, positional_dropout_rate + ) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") + if relative_attention_bias_args + else None + ) + if self.relative_attention_bias_type == "t5": + assert ( + self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get( + "t5_bias_max_distance", 1000 + ), + symmetric=relative_attention_bias_args.get( + "t5_bias_symmetric", False + ), + ) + else: + raise NotImplementedError + + def post_init(self, init_model_config): + + pretrained_speech_encoder_path = init_model_config.get( + "pretrained_speech_encoder_path", None + ) + if pretrained_speech_encoder_path: + model_state = torch.load( + pretrained_speech_encoder_path, map_location="cpu" + ) + encoder_state_dict = {} + for k, v in model_state.items(): + if "encoder." in k: + tmp_k = k.replace("encoder.", "") + encoder_state_dict[tmp_k] = v + + if hasattr(self, "encoder_embedding"): + del self.encoder_embedding + self.load_state_dict(encoder_state_dict) + + if not hasattr(self, "encoder_embedding"): + self.encoder_embedding = MeanVarianceNormLayer( + self.encoder_embedding_config["input_size"] + ) + + mean_file = init_model_config.get("mean_file", None) + invstd_file = init_model_config.get("invstd_file", None) + if mean_file is not None and invstd_file is not None: + self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that computed + the right thing. That does not work within Torchscript. If you really + need this to be faster, create nn.Module()-s for all the cases and return + one of them. Torchscript does support that. + """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get( + "subsampling", "dw_striding" + ) in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = ( + math.ceil if isinstance(feature_lens, int) else torch.ceil + ) + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int( + torch.randint(low=0, high=len(chunk_size), size=(1,)) + ) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError( + "Since chunk_size is a list, left_chunk must be a list" + ) + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of chunk_size." + ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb( + input_tensor + ) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = ( + adaptive_enc_mask( + seq_len, chunk_start_idx, left_window=left_chunk_train_eff + ) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings( + self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None + ): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The squence length after time reduction is invalid: {seq_len}. + Your input feature is too short. Consider filtering out the very + short sentence from data loader""", + ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a subsampled feature. + + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d",) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + +def unfold_tensor(xs_pad, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, + this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. + num_lang: int + This parameter is used to store the number of languages in the lang_dict, + only used for multiseed/multilingual models. default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_block: + number of Transformer layer. default 6 + dropout_rate: float, optional + dropout rate. default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channel for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only work for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channel for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, otional + activation used use glu in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_glu_type: str + only work for glu_in_attention !=0 + default "swish". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention + in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming decoding, use + "replication" padding for the cache at start of utterance. + Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: List[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal[ + "feat", "feat_time", "none", True + ] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.embed = embedding_checkpoint_wrapper(activation_checkpointing)( + self.embed + ) + self.replication_pad_for_subsample_embedding: bool = ( + replication_pad_for_subsample_embedding + ) + assert ( + self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = repeat( + num_blocks, + lambda i: encoder_checkpoint_wrapper( + activation_checkpointing, ConformerEncoderLayer, i + )( + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=attn_checkpointing( + activation_checkpointing, i + ), + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + ), + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = ( + torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) + < padding_length.unsqueeze(1) + ) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( + xs_pad, masks + ) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 #maxium position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) + input_tensor = input_tensor_pad.to(input_tensor.device) + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask did not consider extra pad + subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) # calculate hs_mask based on the unfolded pad mask + + layer_emb = None + + relative_attention_bias = self.init_relative_attention_bias( + input_tensor + ) + + _simplified_path = ( + self.extra_layer_output_idx == -1 + and relative_attention_bias is None + ) + + if _simplified_path: + input_tensor, *_ = self.encoders( + input_tensor, pos_k, pos_v, hs_mask + ) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + if i == self.extra_layer_output_idx: + layer_emb = input_tensor + + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + + return input_tensor, masks # , layer_emb + + def gradient_checkpointing_enable(self): + pass + + +class WindowQformer(nn.Module): + """Window-level Qformer""" + + def __init__( + self, + window_size: int = 8, + num_queries: int = 1, + num_blocks: int = 2, + attention_dim: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + + self.decoders = nn.ModuleList( + [ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) + for _ in range(num_blocks) + ] + ) + + self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) + self.after_norm = ( + nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None + ) + self.window_size = window_size + self.gradient_checkpointing_enable = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing_enable = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing_enable = False + + def forward(self, audio_embed, mask, embed_len=None): + """forward decoder""" + # audio_embed: N x T x D => N x D x T + + audio_embed = audio_embed.transpose(1, 2) + # audio_embed: N x D x 1 x T => N x DK x T' + padding = audio_embed.shape[-1] % self.window_size + if padding > 0: + audio_embed = F.pad( + audio_embed, (0, self.window_size - padding), "constant", 0 + ) + + embed_chunk = F.unfold( + audio_embed[..., None, :], + kernel_size=(1, self.window_size), + stride=(1, self.window_size), + ) + bsz, _, slen = embed_chunk.shape + # N x D x K x T' + embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) + # N x T' x K x D + embed_chunk = embed_chunk.transpose(1, 3).contiguous() + # NT' x K x D + embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) + # NT' x 1 x D + q = self.queries.expand(bsz * slen, -1, -1) + for layer in self.decoders: + if self.gradient_checkpointing_enable and self.training: + q = checkpoint( + layer.__call__, + q, + embed_chunk, + None, + mask, + use_reentrant=True, + ) + else: + q = layer( + tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask + ) + + if self.after_norm is not None: + q = self.after_norm(q) + + if embed_len is not None: + embed_len = embed_len // self.window_size + # N x T' x D + out = q.view(bsz, slen, -1) + + return out, embed_len + + +class AudioEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = ( + config.n_embd if hasattr(config, "n_embd") else config.hidden_size + ) + + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = ( + config.embd_pdrop + if hasattr(config, "embd_pdrop") + else config.embed_pdrop + ) + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + # self.wte = nn.Embedding(config.vocab_size, hidden_size) + + audio_dim_out = ( + None # Set this variable according to the actual audio processor + ) + self.layer_idx = -2 + + # if isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == 'whisper': + # model_path = config.audio_processor.get('pretrained_model_path', None) + # whisper_model = WhisperModel.from_pretrained(model_path) + + # self.encoder = whisper_model.encoder + # n_mels = self.encoder.num_mel_bins + # audio_dim_out = self.encoder.layers[0].embed_dim + # elif isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == "w2vbert2": + # audio_processor_path = config.audio_processor.get("model_path", "facebook/w2v-bert-2.0") + # self.encoder = Wav2Vec2BertModel.from_pretrained(audio_processor_path) + # audio_dim_out = self.encoder.config.hidden_size + # self.layer_idx = config.audio_processor.get("layer", 18) + # self.encoder.config.apply_spec_augment = False + # self.encoder.config.mask_time_prob = 0 + # self.encoder.config.output_hidden_states = True + # n_mels = 160 + if ( + isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades" + ): + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + # fake initialization, create encoder_embedding layer only so that + # in decoding, all parameters can be loaded in from_pretrained_function + # in training, we do post init after from_pretrained function to make sure the correct initialization + self.encoder.post_init({}) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError(f"") + + assert ( + audio_dim_out is not None + ), "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get( + "freeze_audio_processor", False + ) + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + if kwargs.get("use_qformer", False): + qformer_config = kwargs.get("qformer_config", {}) + qformer_config["attention_dim"] = audio_dim_out + self.qformer = WindowQformer(**qformer_config) + else: + self.qformer = None + + if kwargs.get("use_conv_downsample", False): + assert ( + self.qformer is None + ), "don't support use qformer and conv downsample together" + nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.downsample_rate, + "feat_in": audio_dim_out, + "feat_out": audio_dim_out, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.conv_ds = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + self.conv_ds = None + + enable_gradient_checkpointing = kwargs.get( + "enable_gradient_checkpointing", False + ) + if enable_gradient_checkpointing: + self.encoder.gradient_checkpointing_enable() + + if self.qformer: + self.qformer.enable_gradient_checkpointing() + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = ( + 1 if (self.qformer or self.conv_ds) else self.downsample_rate + ) + layers = [ + nn.Linear( + audio_dim_out * self.linear_downsample_rate, dim_projection + ) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), nn.Linear(dim_projection, dim_projection)] + ) + self.audio_projection = nn.Sequential(*layers) + # NOTE vision-speech tasks use a seperate projection layer + layers = [ + nn.Linear( + audio_dim_out * self.linear_downsample_rate, dim_projection + ) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), nn.Linear(dim_projection, dim_projection)] + ) + self.audio_projection_for_vision = nn.Sequential(*layers) + else: + raise NotImplementedError( + f"projection_cls = {projection_cls}, not implemented" + ) + + # TODO: audio sequence compression - Qformer + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def post_init(self, audio_config): + # execute after the from_pretrained() initialization of the phi3 model + if audio_config.get("name", None) == "cascades": + init_model_config = audio_config.get("init_model", {}) + self.encoder.post_init(init_model_config) + # remove the init model in config so it is not saved in the config. + # This might affect the model loading in resuming training and decoding. + if "init_model" in audio_config: + audio_config.pop("init_model") + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes( + self, audio_embed_sizes: torch.LongTensor + ) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, + input_embeds: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ): + + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder( + input_embeds, audio_attention_mask + ) + else: + audio_features, masks = self.encoder( + input_embeds, audio_attention_mask + ) + + if self.qformer is not None: + audio_features, _ = self.qformer(audio_features, mask=None) + + if self.conv_ds is not None: + if masks is not None: + masks = masks.squeeze(1) + + audio_features, masks = self.conv_ds(audio_features, mask=masks) + + if self.linear_downsample_rate != 1: + bs, seq_len, feat_dim = audio_features.size() + padding = seq_len % self.linear_downsample_rate + if padding > 0: + audio_features = F.pad( + audio_features, + (0, 0, 0, self.linear_downsample_rate - padding), + "constant", + 0, + ) + + seq_len = audio_features.size(1) + audio_features = audio_features.view( + bs, + seq_len // self.linear_downsample_rate, + feat_dim * self.linear_downsample_rate, + ) + + if audio_projection_mode == 'speech': + audio_set_tensor = self.audio_projection(audio_features) + elif audio_projection_mode == 'vision': + audio_set_tensor = self.audio_projection_for_vision(audio_features) + else: + raise ValueError(f"audio_projection_mode = {audio_projection_mode} not implemented") + + return audio_set_tensor + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds: torch.FloatTensor, + audio_embed_sizes, + **kwargs, + ) -> torch.FloatTensor: + """ + arguments: + input_ids: input text ids (B, U) + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + assert input_embeds is not None and len(input_embeds) == len( + audio_embed_sizes + ) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + with torch.no_grad(): + positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(as_tuple=False) + + if not isinstance(input_embeds, list): + input_embeds = [input_embeds] + + audio_projection_mode = kwargs.get("audio_projection_mode", "speech") + audio_set_tensor = [ + self.get_audio_features(input_embed, audio_projection_mode=audio_projection_mode) + for input_embed in input_embeds + ] + + with torch.no_grad(): + input_ids.clamp_min_(0).clamp_max_(self.vocab_size) + + if "wte" in kwargs: + # we use the token embedding layer from the huggingface model, this is REQUIRED to make sure we are using the loaded weights. + hidden_states = kwargs["wte"](input_ids) + else: + # otherwise, we use token embedding in pretrained mixformer from phi team + hidden_states = self.wte(input_ids) + + if len(positions.tolist()) > 0: + assert sum(audio_embed_sizes) == len( + positions + ), "please ensure the encoder outputs have the same length as defined in input_ids!" + idx = 0 + for i in range(len(audio_embed_sizes)): + cnt = audio_embed_sizes[i] + assert audio_set_tensor[i].shape[0] == 1 + hidden_states[ + positions[idx, 0], + positions[idx, 1] : positions[idx, 1] + cnt, + ] = ( + audio_set_tensor[i][0, : audio_embed_sizes[i], :] + .to(hidden_states.dtype) + .to(hidden_states.device) + ) + idx += cnt + + else: + if self.training: + # hidden_states[:, 0:img_set_tensor.shape[0]] = hidden_states[:, 0:img_set_tensor.shape[0]] + 0 * img_set_tensor.to(hidden_states.dtype).to(hidden_states.device) + hidden_states[:, 0:1] = hidden_states[ + :, 0:1 + ] + 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype).to( + hidden_states.device + ) + + if self.drop is not None: + hidden_states = self.drop(hidden_states) + return hidden_states diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py new file mode 100644 index 000000000000..ed65c7f7e747 --- /dev/null +++ b/vllm/model_executor/models/phi4o.py @@ -0,0 +1,1761 @@ +import itertools +import math +import os +from functools import lru_cache +import re +from typing import ( + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Tuple, + TypedDict, + Union, +) + +import numpy as np +import scipy.signal +import torch +import torch.nn as nn +import torchvision.transforms as T +from safetensors.torch import load_file +from transformers import PretrainedConfig +from PIL import Image + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, MultiModalConfig, LoRAConfig +from vllm.distributed import get_pp_group +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs.data import token_inputs, TokenInputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE +) +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import IntermediateTensors, SequenceData +from transformers.utils import logging + +from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA +from .vision_siglip_navit import get_siglip_vision_model +from .phi3s_utils import AudioEmbedding +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 # <|endoftext10|> (see vocab.json in hf model) +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + +_AUDIO_MAX_SOUNDFILE_SIZE = 241_000 +DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz + +DYNAMIC_HD = 16 +AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>" +IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>" + +SIGLIP_NAME = "siglip-so400m-patch14-448" +VISION_ENCODER_TO_PROCESSING_CONFIG = { + 'siglip-so400m-patch14-448': { + 'dynamic_hd': 16, + 'vit_image_size': 448, + 'vit_patch_size': 14, + 'token_compression_factor': 2, + }, +} +logger = logging.get_logger(__name__) +# This is a workaround to prevent text (user input) + audio + image from being used in +# the same prompt. +# It includes token ids for "/n" and tokens in added_tokens_decoder from the +# tokenizer_confg.json file. +NON_USER_INPUT_TOKENS = {198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, 200023, 200024, 200025, 200026, 200027, 200028} + +def get_max_dummy_image(ctx: InputContext): + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + + max_side = vit_image_size * dynamic_hd_size + dummy_image = dummy_image_for_phi3v(vit_image_size, max_side) + return dummy_image + + +# image token length +def get_max_phi3v_image_tokens(ctx: InputContext): + dummy_image = get_max_dummy_image(ctx) + + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + image_num_tokens = _compute_num_image_tokens( + dummy_image, + dynamic_hd_size, + vit_image_size, + vit_patch_size, + token_compression_factor + ) + return image_num_tokens + + +# image processor +def parity_check_image_processor(): + return + import requests + from PIL import Image + url = 'https://www.ilankelman.org/stopsigns/australia.jpg' + image = Image.open(requests.get(url, stream=True).raw) + image_inputs = preprocess( + [image], dynamic_hd_size=16, vit_resolution=448, vit_patch_size=14 + ) + image_inputs['input_image_embeds'] = image_inputs['pixel_values'] + + gt_dict = torch.load("examples/parity_processor.pt") + + print('image preprocessing parity check') + for k in gt_dict: + print(f"checking {k} ...") + gt = gt_dict[k] + pt = image_inputs[k] + if isinstance(gt_dict[k], torch.Tensor): + gt = gt.cpu() + pt = pt.cpu() + error = pt - gt + print(f"max difference: {torch.max(torch.abs(error))}") + else: + print(f"pt: {pt}") + print(f"gt: {gt}") + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def _find_target_aspect_ratio(image, image_size, max_num, min_num): + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + logger.debug("target_aspect_ratio: %s", target_aspect_ratio) + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width + + +def _get_padding_size(image, target_height, target_width): + orig_width, orig_height = image.size + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + + if ratio_width < ratio_height: + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + return padding_height, padding_width + + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, mask_size=27): + target_aspect_ratio, target_height, target_width = _find_target_aspect_ratio( + image, image_size, max_num, min_num + ) + padding_height, padding_width = _get_padding_size(image, target_height, target_width) + + # Calculate the ratio + orig_width, orig_height = image.size + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + else: + new_size = (int(orig_width * ratio_height), target_height) + + attention_mask = torch.ones((int(mask_size*target_aspect_ratio[1]), int(mask_size*target_aspect_ratio[0]))) + if padding_width >= 14: + attention_mask[:, -math.floor(padding_width/14):] = 0 + if padding_height >= 14: + attention_mask[-math.floor(padding_height/14):,:] = 0 + assert attention_mask.sum() > 0, f'attention mask is empty {attention_mask}' + + if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + raise ValueError(f'the aspect ratio is very extreme {new_size}') + + image = T.functional.resize(image, [new_size[1], new_size[0]],) + + resized_img = T.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255,255,255]) + + return resized_img, attention_mask + +def pad_to_max_num_crops(images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if max_crops > B: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + +def pad_mask_to_max_num_crops(masks, max_crops=5): + B, H, W = masks.shape + if max_crops > B: + pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) + masks = torch.cat([masks, pad], dim=0) + return masks + +def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): + + # Basic settings. + img_processor = T.Compose([ + T.ToTensor(), + T.Normalize( + (0.5, 0.5, 0.5), + (0.5, 0.5, 0.5) + ), + ]) + # Dynamic HD + base_resolution = vit_resolution + images = [image.convert('RGB') for image in images] + # cover 384 and 448 resolution + mask_resolution = base_resolution // vit_patch_size + elems, image_attention_masks = [], [] + for im in images: + elem, attention_mask = dynamic_preprocess(im, max_num=dynamic_hd_size, image_size=base_resolution, mask_size=mask_resolution) + elems.append(elem) + image_attention_masks.append(attention_mask) + hd_images = [img_processor(im) for im in elems] + global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic',).to(im.dtype) for im in hd_images] + shapes = [[im.size(1), im.size(2)] for im in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] + global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images] + hd_images_reshape = [im.reshape(1, 3, + h//base_resolution, + base_resolution, + w//base_resolution, + base_resolution + ).permute(0,2,4,1,3,5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes)] + attention_masks_reshape = [mask.reshape(1, + h//mask_resolution, + mask_resolution, + w//mask_resolution, + mask_resolution + ).permute(0,1,3,2,4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes)] + # NOTE token compression is hard coded here, and odd numbers seems to fail + downsample_attention_masks = [mask[:,0::2,0::2].reshape(1, + h//mask_resolution, + w//mask_resolution, + mask_resolution//2+mask_resolution%2, + mask_resolution//2+mask_resolution%2 + ).permute(0,1,3,2,4) for mask, (h,w) in zip(attention_masks_reshape, mask_shapes)] + downsample_attention_masks = [mask.reshape(mask.size(1)*mask.size(2), mask.size(3)*mask.size(4))for mask in downsample_attention_masks] + # NOTE hard coded number of tokens + num_img_tokens = [256 + 1 + int(mask.sum().item()) + int(mask[:,0].sum().item()) + 16 for mask in downsample_attention_masks] + + hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)] + hd_masks_reshape = [torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = [pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "pixel_values": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + # data = [returned_input_image_embeds, returned_image_sizes, returned_image_attention_mask, returned_num_img_tokens] + return data + + # return BatchFeature(data=data, tensor_type=return_tensors) + + + +class PhiOImageEncoder(nn.Module): + """Image embedding.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + model_dir: str = "") -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): + embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + # logger.info(f"create image tower {config.img_processor}") + + # layer_idx to output the img features + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get('type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + self.img_processor = get_siglip_vision_model(_flash_attn_2_enabled=True) + + pe_weight = self.img_processor.embeddings.position_embedding.weight + L, D = pe_weight.size() + H = int(math.sqrt(L)) + assert H**2 == L, f'position embedding size {L} is not square' + if H % 2 != 0: #and kwargs.get('image_token_compression_cls', None) is None: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + H += 1 + image_dim_out = D + # ((448/14)//2)**2 + self.num_img_tokens = (H//2)**2 + self.base_feat_height_target = H + + + self.image_dim_out = image_dim_out + self.img_sizes = None + self.image_attention_mask = None + + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = True # kwargs.get('use_hd_transform', False) + self.with_learnable_separator = True # kwargs.get('with_learnable_separator', False) + self.hd_transform_order = "sub_glb" # kwargs.get('hd_transform_order', 'glb_sub') + self.freeze_img_processor = False # kwargs.get('freeze_img_processor', False) + self.crop_size = 448 # kwargs.get('crop_size', 336) + # logger.info(f'freeze_img_processor = {self.freeze_img_processor}') + + # image token compression + self.image_token_compression_cls = 'avg_pool_2d' # kwargs.get('image_token_compression_cls', None) + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.base_feat_height_reduction = 1 + self.base_feat_height_target = self.base_feat_height_target // 2 + + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value' + assert self.use_hd_transform, 'learnable separator is only for hd transform' + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) + # logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') + + projection_cls = "mlp" # kwargs.get('projection_cls', 'linear') + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * self.base_feat_height_reduction**2, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + + + self.vocab_size = config.vocab_size + self.img_features = None + + self.use_out_place_operations = False # kwargs.get('use_out_place_operations', False) + + def get_img_features(self, + img_embeds: torch.FloatTensor, + attention_mask=None) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, + output_hidden_states=True, + patch_attention_mask=attention_mask) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature + + use_token_compression = self.image_token_compression is not None + use_padding = getattr(self, 'img_processor_padding', None) is not None + if use_token_compression or use_padding: + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + + if use_padding: + patch_feature = self.img_processor_padding(patch_feature) + if use_token_compression: + patch_feature = self.image_token_compression(patch_feature) + + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + + return patch_feature + + + # logger.info(f'processed img feature size = {img_feature.size()}') + raise NotImplementedError + + def forward(self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + image_attention_mask: torch.Tensor) -> torch.FloatTensor: + """ + process image and return vision embeddings. + + pixel_values: (num_images, num_crops, c, h, w) + image_sizes: [[h1, w1], [h2, w2]] + image_attention_mask: num_images x num_crops x 32 x 32 + output: (num_images, num_img_tokens, hidden_size) + """ + + # eg + # pixel_values: torch.Size([1, 7, 3, 448, 448]) + # image_sizes: tensor([[ 896, 1344]], device='cuda:0') + # output: torch.Size([1, 1841, 3072]) + + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + img_sizes = image_sizes + num_images, num_crops, c, h, w = pixel_values.shape + # assert num_images == 1, "Currently only support single image" # TODO debug multi-image + bs = num_images + pixel_values = pixel_values.flatten(0, 1) + + img_features = self.get_img_features(pixel_values, + image_attention_mask.type(torch.BoolTensor).flatten(0,1).to(target_device)) + + fake_image_forward = False + select = False + hd_transform = False + + base_feat_height_target = self.base_feat_height_target + base_resolution = self.crop_size + base_feat_height_reduction = self.base_feat_height_reduction + + base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1])) + assert base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform' + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // base_resolution + w = w // base_resolution + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C).contiguous() + temp_glb_GN = self.sub_GN.repeat(1, H//base_feat_height_reduction, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,base_feat_height_reduction*base_feat_height_reduction*C).contiguous() + sub_img = sub_img.reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction, -1).permute(0,1,3,2,4,5).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C) + + if image_attention_mask is not None and len(image_attention_mask) > 0: + reshaped_image_attention_mask = image_attention_mask[_bs,1:B_+1,0::2,0::2].reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction).permute(0,1,3,2,4).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction) + useful_height = int(reshaped_image_attention_mask[0,:,0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0,0,:].sum().item()) + sub_img = sub_img[:,:useful_height, :useful_width] + temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) + temp_len = int(image_attention_mask[_bs,:B_+1,0::2,0::2].sum().item()) + (useful_height+1) + base_feat_height//base_feat_height_reduction + else: + temp_sub_GN = self.sub_GN.repeat(1, h*base_feat_height//base_feat_height_reduction, 1, 1) + temp_len = int((h*w+1)*self.num_img_tokens+ 1 + (h+1)*base_feat_height//base_feat_height_reduction) + + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented') + + #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' + output_len.append(temp_len) + + num_img_tokens = output_len + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) + img_set_tensor.append(img_feature_proj) + + return img_set_tensor + + +class Phi3SAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Tuple[NestedTensors] + """Shape: `((batch_size, num_audios, 80, M), )""" + + +class Phi3SAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + +Phi3SAudioInputs = Union[Phi3SAudioFeatureInputs, Phi3SAudioEmbeddingInputs] + + +def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): + """Create a Mel filter-bank the same as SpeechLib FbankFC. + + Args: + sample_rate (int): Sample rate in Hz. number > 0 [scalar] + n_fft (int): FFT size. int > 0 [scalar] + n_mel (int): Mel filter size. int > 0 [scalar] + fmin (float): lowest frequency (in Hz). If None use 0.0. + float >= 0 [scalar] + fmax: highest frequency (in Hz). If None use sample_rate / 2. + float >= 0 [scalar] + + Returns + out (numpy.ndarray): Mel transform matrix + [shape=(n_mels, 1 + n_fft/2)] + """ + + bank_width = int(n_fft // 2 + 1) + if fmax is None: + fmax = sample_rate / 2 + if fmin is None: + fmin = 0 + assert fmin >= 0, "fmin cannot be negtive" + assert ( + fmin < fmax <= sample_rate / 2 + ), "fmax must be between (fmin, samplerate / 2]" + + def mel(f): + return 1127.0 * np.log(1.0 + f / 700.0) + + def bin2mel(fft_bin): + return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) + + def f2bin(f): + return int((f * n_fft / sample_rate) + 0.5) + + # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] + klo = f2bin(fmin) + 1 + khi = f2bin(fmax) + + khi = max(khi, klo) + + # Spec 2: SpeechLib uses trianges in Mel space + mlo = mel(fmin) + mhi = mel(fmax) + m_centers = np.linspace(mlo, mhi, n_mels + 2) + ms = (mhi - mlo) / (n_mels + 1) + + matrix = np.zeros((n_mels, bank_width), dtype=np.float32) + for m in range(0, n_mels): + left = m_centers[m] + center = m_centers[m + 1] + right = m_centers[m + 2] + for fft_bin in range(klo, khi): + mbin = bin2mel(fft_bin) + if left < mbin < right: + matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms + + return matrix + + +class LogFbankProcessor: + def __init__(self): + + self._eightk_method = "fillzero" + self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T + + self._hamming400 = np.hamming(400) # for 16k audio + self._hamming200 = np.hamming(200) # for 8k audio + + def extract_spectrogram(self, wav, fs): + """Extract spectrogram features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + if wav.ndim > 1: + wav = np.squeeze(wav) + + # by default, we extract the mean if stereo + if len(wav.shape) == 2: + wav = wav.mean(1) + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 16000) + fs = 16000 + elif 8000 < fs < 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 8000) + fs = 8000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + if fs == 8000: + if self._eightk_method == "resample": + # Input audio is 8 kHz. Convert to 16 kHz before feature + # extraction + wav = scipy.signal.resample_poly(wav, 2, 1) + fs = 16000 + # Do nothing here for fillzero method + elif fs != 16000: + # Input audio is not a supported sample rate. + raise RuntimeError( + f"Input data using an unsupported sample rate: {fs}" + ) + + preemphasis = 0.97 + + if fs == 8000: + n_fft = 256 + win_length = 200 + hop_length = 80 + fft_window = self._hamming200 + elif fs == 16000: + n_fft = 512 + win_length = 400 + hop_length = 160 + fft_window = self._hamming400 + + # Spec 1: SpeechLib cut remaining sample insufficient for a hop + n_batch = (wav.shape[0] - win_length) // hop_length + 1 + # Here we don't use stride_tricks since the input array may not satisfy + # memory layout requirement and we need writeable output + # Here we only use list of views before copy to desination + # so it is more efficient than broadcasting + y_frames = np.array( + [ + wav[_stride : _stride + win_length] + for _stride in range(0, hop_length * n_batch, hop_length) + ], + dtype=np.float32, + ) + + # Spec 2: SpeechLib applies preemphasis within each batch + y_frames_prev = np.roll(y_frames, 1, axis=1) + y_frames_prev[:, 0] = y_frames_prev[:, 1] + y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + + S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype( + np.complex64 + ) + + if fs == 8000: + # Need to pad the output to look like 16 kHz data but with zeros in + # the 4 to 8 kHz bins. + frames, bins = S.shape + padarray = np.zeros((frames, bins)) + S = np.concatenate( + (S[:, 0:-1], padarray), axis=1 + ) # Nyquist bin gets set to zero + + spec = np.abs(S).astype(np.float32) + return spec + + def extract_features(self, wav, fs): + """Extract log filterbank features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + spec = self.extract_spectrogram(wav, fs) + spec_power = spec**2 + + fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) + log_fbank = np.log(fbank_power).astype(np.float32) + + return log_fbank + + +@lru_cache +def audio_feature_extractor() -> LogFbankProcessor: + # Creates an instance of the audio processor, needed to extract the + # the audio featues from the sound file + # LRU cache ensures that we only make one copy + return LogFbankProcessor() + + +def _compute_num_image_tokens( + image, dynamic_hd_size, vit_image_size, vit_patch_size, token_compression_factor +): + """ + compute the number of tokens an image is expected to take up considering the image encoder + architecture and exclude output features containing only padding pixels + + for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map + NOTE right now, Phi-O uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, "vit_image_size must be divisible by vit_patch_size" + assert vit_image_size // vit_patch_size % token_compression_factor == 0, "vit_image_size // vit_patch_size must be divisible by token_compression_factor" + + target_aspect_ratio, target_height, target_width = ( + _find_target_aspect_ratio( + image, vit_image_size, dynamic_hd_size, min_num=1 + ) + ) + assert target_aspect_ratio[0] * vit_image_size == target_width, f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + assert target_aspect_ratio[1] * vit_image_size == target_height, f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + assert ( + target_height % vit_image_size == 0 + and target_width % vit_image_size == 0 + ) + + padding_height, padding_width = _get_padding_size( + image, target_height, target_width + ) + assert padding_width == 0 or padding_height == 0, "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size + ) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size + ) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height # FIXME bug: 1504, should be 1536 + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 + num_sep_tokens = 1 + num_global_image_newline_tokens = vit_feature_size // token_compression_factor + + return ( + num_global_image_tokens + + num_sep_tokens + + num_hd_patch_tokens + + num_hd_newline_tokens + + num_global_image_newline_tokens + ) + + +def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: + """ + Compute the output size of the `extract_features` method. + + Args: + wav_length (int): Length of the input waveform in samples. + fs (int): Sampling rate of the waveform, either 16000 or 8000. + + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). + """ + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav_length //= fs // 16000 + fs = 16000 + elif 8000 <= fs < 16000: + # We'll resample to 16K from 8K + wav_length *= 2 + fs = 16000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + mel_bins = 80 # Number of mel filterbank bins + + # Calculate number of frames (T) + T = (wav_length - win_length) // hop_length + 1 + if T < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) and mel bins (D) + return T, mel_bins + + +def _get_audio_embed_sizes(audios, ctx: InputContext): + audio_embed_sizes = [] + for audio in audios: + audio_data, sf = audio + audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) + audio_embed_size = _compute_audio_embed_size( + ctx.get_hf_config(), audio_frames + ) + audio_embed_sizes.append(audio_embed_size) + return audio_embed_sizes + + +def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): + if len(audios) == 0: + return {} + + audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) + + # The following logic will search for `<|audio_{idx}|>` tokens and + # insert the placeholder audio tokens that will be overwritten by the + # embedding in the audio tower + audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) + audio_ids = [int(audio_id) for audio_id in audio_ids] + assert len(audio_ids) == len(audio_embed_sizes), "Number of audio tokens and audio features do not match" + assert tuple(audio_ids) == tuple( + range(1, len(audio_ids) + 1) + ), "Audio ids are not in order!" + audio_id_to_input_ids = { + f"<|audio_{audio_id}|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) + } + + return audio_id_to_input_ids + + +def _count_image_tokens(images, ctx: InputContext): + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + image_token_counts = [ + _compute_num_image_tokens( + image, + dynamic_hd_size, + vit_image_size, + vit_patch_size, + token_compression_factor + ) + for image in images + ] + return image_token_counts + + +def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): + if len(images) == 0: + return {} + + image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) + image_ids = [int(image_id) for image_id in image_ids] + assert len(image_ids) == len(set(image_ids)), "Duplicate image tokens in prompt" + assert len(images) == len(image_ids), "Number of images and image tokens in prompt do not match" + + # NOTE the following assertion is not strictly necessary + assert tuple(image_ids) == tuple(range(1, len(image_ids) + 1)), "Image ids are not in order" + + image_token_counts = _count_image_tokens(images, ctx) + image_id_to_input_ids = { + f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens + for image_id, num_tokens in zip(image_ids, image_token_counts) + } + return image_id_to_input_ids + + +def input_processor_for_phio( + ctx: InputContext, inputs: DecoderOnlyInputs +) -> TokenInputs: + """ + Implements the input processor, which transforms the input prompt ids + to include the audio placeholder token. This will become the `input_ids` + in `forward` for the model. + + Args: + ctx (InputContext): Input context. + inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) to process. + + Returns: + TokenInputs: Processed inputs + """ + multi_modal_data = inputs.get("multi_modal_data") + # Check if audio is being used as a modality + if (multi_modal_data is None + or ("audio" not in multi_modal_data + and "image" not in multi_modal_data)): + # pure text input + return inputs + + prompt_str = inputs.get("prompt") + prompt_token_ids = inputs.get("prompt_token_ids") + # for offline_inference, we will get str input and we parse MM special tokens from it + # (ignore prompt_token_ids) + # for OAI server, we will get prompt_token_ids, where MM special tokens are already parsed + + if 'audio' in multi_modal_data: + audios = multi_modal_data["audio"] + + if not isinstance(audios, list): + audios = [audios] + if prompt_str is not None: + audio_id_to_input_ids = _get_audio_id_to_input_ids(audios, ctx, prompt_str=prompt_str) + audio_embed_sizes = [] + elif prompt_token_ids is not None: + audio_id_to_input_ids = {} + audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) + else: + audio_id_to_input_ids = {} + audio_embed_sizes = [] + + if 'image' in multi_modal_data: + # PIL Image or list of PIL Images + images = multi_modal_data["image"] + if not isinstance(images, list): + images = [images] + if prompt_str is not None: + image_id_to_input_ids = _get_image_id_to_input_ids(images, prompt_str, ctx) + image_token_counts = [] + elif prompt_token_ids is not None: + image_id_to_input_ids = {} + image_token_counts = _count_image_tokens(images, ctx) + else: + image_id_to_input_ids = {} + image_token_counts = [] + + # Handle the case where the prompt is a string and we need to manually tokenize it. + # In this case, the `audio_id_to_input_ids` dict will be mapping from an audio placeholder + # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the given audio length. + if prompt_str: + pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" + prompt_chunk_strings = re.split(pattern, prompt_str) + prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] + + # Create the new input_ids with the placholder image and audio tokens inserted + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + input_ids = [] + has_imag, has_audio, has_user_text_input = False, False, False + for prompt_chunk_string in prompt_chunk_strings: + if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string): + input_ids.extend(image_id_to_input_ids[prompt_chunk_string]) + has_imag = True + elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string): + input_ids.extend(audio_id_to_input_ids[prompt_chunk_string]) + has_audio = True + else: + curr_token_ids = tokenizer(prompt_chunk_string).input_ids + if not has_user_text_input: + for token_id in curr_token_ids: + if token_id not in NON_USER_INPUT_TOKENS: + has_user_text_input = True + break + input_ids.extend(curr_token_ids) + if has_audio and has_imag and has_user_text_input: + raise ValueError("PhiOForCausalLM does not support text + audio + image" + + " inputs in the same prompt") + # Handle the case where the prompt is already tokenized + else: + assert prompt_token_ids is not None, "If string prompt isn't provided, prompt_token_ids must be" + + i = 0 + input_ids = prompt_token_ids + img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 # only needed for later assertion + image_token_count_iter = iter(image_token_counts) + audio_embed_size_iter = iter(audio_embed_sizes) + while i < len(input_ids): + token_id = input_ids[i] + if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID: + token_count = next(audio_embed_size_iter) + audio_cnt += 1 + elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID: + token_count = next(image_token_count_iter) + img_cnt += 1 + else: + user_text_input_cnt += 1 if token_id not in NON_USER_INPUT_TOKENS else 0 + i += 1 + continue + tokens = [token_id] * token_count + input_ids = input_ids[:i] + tokens + input_ids[i + 1:] + i += token_count + + if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: + raise ValueError("PhiOForCausalLM does not support text + audio + image" + + " inputs in the same prompt") + # If the below assertion fails, it might be that input pure-text + # messages contain image/audio special tokens literally + # (<|endoftext10|>, <|endoftext11|>). + assert ( + img_cnt == len(image_token_counts) + ), ( + f"Number of image tokens in prompt_token_ids ({img_cnt}) " + f"does not match number of images ({len(image_token_counts)})" + ) + assert ( + audio_cnt == len(audio_embed_sizes) + ), ( + f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " + f"does not match number of audios ({len(audio_embed_sizes)})" + ) + + # NOTE: Create a defensive copy of the original inputs + return token_inputs( + prompt_token_ids=input_ids, + prompt=prompt_str, + multi_modal_data=multi_modal_data, + ) + + +def _compute_audio_embed_size(hf_config, audio_frames): + compression_rate = hf_config.embd_layer['audio_embd_layer']['compression_rate'] + # TODO: update this hard-coded value? + qformer_compression_rate = 1 + integer = audio_frames // compression_rate + remainder = audio_frames % compression_rate + + result = integer if remainder == 0 else integer + 1 + + integer = result // qformer_compression_rate + remainder = result % qformer_compression_rate + result = integer if remainder == 0 else integer + 1 # qformer compression + + return result + + +def get_max_phi3s_audio_tokens(ctx: InputContext): + # TODO + return 10000 + # return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) + + +def dummy_audio_for_phi3s(audio_count: int) -> dict: + """ + Create dummy audio data for the Phi-3.5-Speech model, which is used for profiling. + + Args: + audio_count (int): Number of audio samples. + + Returns: + dict: Dummy audio data. + """ + dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE,), 0.0) + return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count + + +def dummy_image_for_phi3v(width: int, height: int): + image = Image.new('RGB', (width, height), color='black') + return image + + +def dummy_data_for_phi3s( + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] +) -> Tuple: + """ + Create dummy sequence (input_ids) and audio data for the Phi-3.5-Speech model, which is used for + profiling. + + In this case, the sequence data is a bunch of 0s with a number of audio tokens that correspond + to the audio embed size of the _AUDIO_MAX_SOUNDFILE_SIZE. + + Args: + ctx (InputContext): Input context. + seq_len (int): Length of the sequence. + mm_counts (Mapping[str, int]): Multi-modal counts. + + Returns: + Tuple: Dummy sequence data and dummy audio data. + """ + audio_count = mm_counts["audio"] + audio_frames, _ = compute_logfbank_output_size( + _AUDIO_MAX_SOUNDFILE_SIZE, DUMMY_SAMPLING_FREQUENCY + ) + audio_feature_size = _compute_audio_embed_size( + ctx.get_hf_config(), audio_frames + ) + + image_count = mm_counts["image"] + dummy_image = get_max_dummy_image(ctx) + max_image_tokens = get_max_phi3v_image_tokens(ctx) + total_image_tokens = image_count * max_image_tokens + + if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: + raise RuntimeError( + f"Phi3O cannot process {audio_count} audios and {image_count} images in a prompt," + f"please increase max_model_len to be at larger than {audio_feature_size * audio_count + total_image_tokens}" + " or reduce audio/image limit by --limit-mm-per-prompt.") + + if audio_feature_size * audio_count > total_image_tokens: + seq_data = SequenceData.from_prompt_token_counts( + (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), + (0, seq_len - audio_feature_size * audio_count), + ) + mm_data = { + "audio": dummy_audio_for_phi3s(audio_count), + } + else: + seq_data = SequenceData.from_prompt_token_counts( + (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), + (0, seq_len - total_image_tokens), + ) + mm_data = { + "image": [dummy_image] * image_count, + } + return seq_data, mm_data + + +def input_mapper_for_phi3s(ctx: InputContext, data: object) -> MultiModalInputs: + """ + This function is used to create the MultiModalInputs for the Phi-3.5-Speech model. + Specifically, for audio, we extract the audio features from the sound file and create + pairs of audio features and audio embed lengths (the latter of which is used to repeat + the audio placeholder token in the input prompt IDs). + These pairs are used, downstream, in `_audio_features_to_embeddings` + (via `_process_audio_input`). + + Note that the incoming audio data (each entry in `data`) is a tuple of the audio data + and the sampling frequency (e.g. from soundfile.read). + + Args: + ctx (InputContext): Input context. + data (object): Audio data. + + Returns: + MultiModalInputs: Multi-modal inputs. + """ + if not isinstance(data, list): + data = [data] + + if len(data) == 0: + return MultiModalInputs() + + audio_features = [] + for audio_input in data: + if not isinstance(audio_input, tuple): + raise NotImplementedError( + f"Unsupported data type: {type(audio_input)}" + ) + + audio, sf = audio_input + feature_extractor = audio_feature_extractor() + single_audio_features = feature_extractor.extract_features(audio, sf) + feat_stride = ( + 1 + if not hasattr(feature_extractor, "stride") + else feature_extractor.stride + ) + audio_frames = len(single_audio_features) * feat_stride + single_audio_embed_size = _compute_audio_embed_size( + ctx.get_hf_config(), audio_frames + ) + single_audio_feature_audio_len_pair = ( + single_audio_features, + [single_audio_embed_size], + ) + audio_features.append(single_audio_feature_audio_len_pair) + return MultiModalInputs({"audio_features": audio_features}) + + +def input_mapper_for_phi3v(ctx: InputContext, data: object): + # data: list of PIL images + # assert isinstance(data, list), "Data must be a list of PIL images" + if not isinstance(data, list): + data = [data] + if len(data) == 0: + return MultiModalInputs() + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + + image_input_dict = preprocess( + data, dynamic_hd_size, vit_image_size, vit_patch_size + ) + return MultiModalInputs({ + "pixel_values": image_input_dict["pixel_values"], + "image_sizes": image_input_dict["image_sizes"], + "image_attention_mask": image_input_dict["image_attention_mask"], + "num_img_tokens": image_input_dict["num_img_tokens"], + }) + + +def cat_with_pad(tensors, dim, padding_value=0): + """ + cat along dim, while pad to max for all other dims + """ + ndim = tensors[0].dim() + assert all(t.dim() == ndim for t in tensors[1:]), "All tensors must have the same number of dimensions" + + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) + + index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) + + output[slices] = t + index += t.shape[dim] + + return output + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi3s) +@MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi3v) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_max_phi3s_audio_tokens +) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "image", get_max_phi3v_image_tokens +) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3s) # TODO dummy data for vision? +@INPUT_REGISTRY.register_input_processor(input_processor_for_phio) +class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP): + """ + Implements the Phi-3.5-Omni model in VLLM. + + Args: + config (PretrainedConfig): Pretrained model configuration. + multimodal_config (MultiModalConfig): Multi-modal configuration. + cache_config (Optional[CacheConfig]): Cache configuration. + quant_config (Optional[QuantizationConfig]): Quantization configuration. + """ + # LoRA specific attributes + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } + # QKVParallelLinear, RowParallelLinear, MergedColumnParallelLinear, RowParallelLinear + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj" + ] + # PhiOForCausalLM does not apply LoRA to the embedding layer. + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + self.config = config + self.multimodal_config = multimodal_config + assert self.multimodal_config, "multimodal_config is required" + self.quant_config = quant_config + self.lora_config = lora_config + + # parity check for image processor + parity_check_image_processor() + + self.vision_encoder = PhiOImageEncoder( + config, + quant_config, + prefix="model.vision_embed_tokens", + model_dir=config._name_or_path) + + + if isinstance(self.config.embd_layer["audio_embd_layer"], dict): + embedding_config = { + "embedding_cls": self.config.embd_layer["audio_embd_layer"][ + "embedding_cls" + ], + **self.config.embd_layer["audio_embd_layer"], + } + else: + embedding_config = { + "embedding_cls": self.config.embd_layer["embedding_cls"] + } + + self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) + # self.language_model = LlamaForCausalLM( + # config, cache_config, quant_config + # ) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _audio_features_to_embeddings( + self, + input_ids: torch.Tensor, + input_features: List[torch.Tensor], + audio_input_sizes: torch.Tensor, + audio_projection_mode: str, + ) -> torch.Tensor: + """ + Convert audio features to embeddings, which are used as input to the model (via + `inputs_embeds`). + + Args: + input_ids (torch.Tensor): Input IDs (the prompt in this case). + input_features (list[torch.Tensor]): Input features (the audio embeddings). + audio_input_sizes (list[torch.Tensor]): Audio input sizes (the audio embed lengths to use for + padding the audio placeholder token in the input prompt IDs). + """ + # The audio projection can either be a single linear or Sequential, so handle + # both cases + if isinstance(self.embed_tokens_extend.audio_projection, nn.Sequential): + target_dtype = self.embed_tokens_extend.audio_projection[ + 0 + ].bias.dtype + else: + target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype + + audio_input = [ + input.unsqueeze(0).to(target_dtype) for input in input_features + ] + kwargs = {"wte": self.model.embed_tokens, 'audio_projection_mode': audio_projection_mode} + audio_embeddings = self.embed_tokens_extend( + input_ids, audio_input, audio_input_sizes, **kwargs + ) + audio_embeddings = audio_embeddings.to(target_dtype) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> Optional[Phi3SAudioInputs]: + """ + Parse and validate the audio input to the model. This handles both audio features and + audio embeddings, but only the former is used for now. + + Args: + kwargs (object): Keyword arguments. + + Returns: + Optional[Phi3SAudioInputs]: Parsed and validated audio inputs. + """ + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of audio features. " + f"Got type: {type(audio_features)}" + ) + + return Phi3SAudioFeatureInputs( + type="audio_features", data=audio_features + ) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}" + ) + + return Phi3SAudioEmbeddingInputs( + type="audio_embeds", data=audio_embeds + ) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, input_ids: torch.Tensor, audio_input: Phi3SAudioInputs, audio_projection_mode: str + ) -> NestedTensors: + """ + Create the audio embeddings from the audio input, where the audio input is pairs of + audio features and audio embed lengths. The audio input is created by + `input_mapper_for_phi3s`. + + Args: + input_ids (torch.Tensor): Input IDs (the prompt in this case, before the audio token + replication). + audio_input (Phi3SAudioInputs): Audio input. + + Returns: + NestedTensors: Audio embeddings + """ + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + # (e.g. multiple examples) and the second dim is the multi-audio dim + # (e.g. multiple audios in the same example) + audio_feature = [i[0] for j in audio_features for i in j] + audio_feature_len = [i[1].item() for j in audio_features for i in j] + # Add the batch dim via `squeeze` + + return self._audio_features_to_embeddings( + input_ids.unsqueeze(0), + audio_feature, + audio_feature_len, + audio_projection_mode, + ).squeeze(0) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[Dict]: + pixel_values: Optional[Dict] = kwargs.get("pixel_values") + if pixel_values is None: + return None + + image_sizes = kwargs.get("image_sizes") + image_attention_mask = kwargs.get("image_attention_mask") + num_img_tokens = kwargs.get("num_img_tokens") + assert image_sizes is not None and image_attention_mask is not None and num_img_tokens is not None, "Missing image inputs" + + if isinstance(pixel_values, list): + assert pixel_values[0].dim() == 5, "Incorrect image inputs" + # list len is batch_size + # each tensor has dimension: num_img_per_example, num_hd_patches, channels, height, width + # need to pad along num_hd_patches + # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w + pixel_values = cat_with_pad(pixel_values, dim=0) + elif isinstance(pixel_values, torch.Tensor): + # dimension: batch_size, num_img_per_example, num_hd_patches, channels, height, width + # we flatten first 2 dims to make it a single large batch for SigLIP Encoder + assert pixel_values.dim() == 6, "Incorrect image inputs" + pixel_values = pixel_values.flatten(0, 1) + else: + raise ValueError("Incorrect pixel_values inputs") + + if isinstance(image_attention_mask, list): + image_attention_mask = cat_with_pad(image_attention_mask, dim=0) + elif isinstance(image_attention_mask, torch.Tensor): + image_attention_mask = image_attention_mask.flatten(0, 1) + else: + raise ValueError("Incorrect image_attention_mask inputs") + + if isinstance(image_sizes, list): + image_sizes = torch.cat(image_sizes, dim=0) + elif isinstance(image_sizes, torch.Tensor): + image_sizes = image_sizes.flatten(0, 1) + else: + raise ValueError("Incorrect image_attention_mask inputs") + + if isinstance(num_img_tokens, list): + num_img_tokens = [n for num_tensor in num_img_tokens for n in num_tensor.tolist()] + elif isinstance(num_img_tokens, torch.Tensor): + num_img_tokens = num_img_tokens.flatten(0, 1).tolist() + else: + raise ValueError("Incorrect image_attention_mask inputs") + + return { + 'pixel_values': pixel_values, + 'image_sizes': image_sizes, + 'image_attention_mask': image_attention_mask, + 'num_img_tokens': num_img_tokens, + } + + def merge_image_features_to_inputs_embeds( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_set_tensors: List[torch.Tensor], + ): + position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero(as_tuple=True) + + assert all([t.shape[0] == 1 for t in image_set_tensors]), 'img_set_tensor should have shape (1, N_tokens, C)' + # Shape: (merged_N_tokens, C) + image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) + image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to(inputs_embeds.device) + merged_embeds = inputs_embeds.index_put( + indices=position_tuple, + values=image_set_tensor, + accumulate=False, + ) + return merged_embeds + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + """ + Load in the weights for VLLM. + + NOTE: I highly recommend avoiding the VLLM WeightMapper/Autoloader approach + since it's needlessly complex. + + Args: + weights (Iterable[Tuple[str, torch.Tensor]]): Weights to load (usualy from a + Hugging Face model). + """ + weights = {name: weight for name, weight in weights} + adjusted_weights = {} + + for name, weight in weights.items(): + # NOTE vision-speech tasks use a seperate projection layer + audio_proj_4v = "model.embed_tokens_extend.audio_embed.audio_projection.vision" + if name.startswith(audio_proj_4v): + name = name.replace(audio_proj_4v, "embed_tokens_extend.audio_projection_for_vision") + + name = ( + # name.replace("model.embed_tokens.", "embed_tokens.") + name.replace( + "model.embed_tokens_extend.audio_embed.audio_projection.speech.", + "embed_tokens_extend.audio_projection.", + ) + .replace( + "model.embed_tokens_extend.audio_embed.", + "embed_tokens_extend.", + ) + .replace("model.embed_tokens_extend.image_embed.", "vision_encoder.") + ) + # NOTE: this is deal with LoRA injection, where `base_layer` remains as the original + # layer in the model + if name.endswith(".base_layer.weight"): + name = name.replace(".base_layer.weight", ".weight") + adjusted_weights[name] = weight + + # if name == "model.embed_tokens.weight": + # adjusted_weights["embed_tokens.weight"] = ( + # weight + # ) + + missing_keys, unexpected_keys = self.load_state_dict( + adjusted_weights, strict=False + ) + logger.debug("--------------- missing keys -----------------") + for key in missing_keys: + logger.debug(key) + logger.debug("--------------- unexpected keys ---------------") + for key in unexpected_keys: + logger.debug(key) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> torch.Tensor: + """ + Run the forward pass of the model. + + Args: + input_ids (torch.Tensor): Input IDs. + positions (torch.Tensor): Positions (handled by VLLM) + kv_caches (List[torch.Tensor]): Key-value caches (handled by VLLM) + attn_metadata (AttentionMetadata): Attention metadata (handled by VLLM) + intermediate_tensors (Optional[IntermediateTensors]): Intermediate tensors + (handled by VLLM) + kwargs (object): Keyword arguments. NOTE: this should contain the audio/MM input. + + Returns: + torch.Tensor: Hidden states / model output. + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + # Each entry in this is a pair of audio_features and audio_embed lengths + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_inputs = self._parse_and_validate_image_input(**kwargs) + + has_audio = audio_input is not None + has_image = image_inputs is not None + + if has_audio: + audio_projection_mode = 'vision' if has_image else 'speech' + inputs_embeds = self._process_audio_input( + input_ids, audio_input, audio_projection_mode + ) + + if has_image: + dtype = self.vision_encoder.img_processor.embeddings.patch_embedding.weight.dtype + pixel_values = image_inputs['pixel_values'].to(dtype) + image_sizes = image_inputs['image_sizes'] + image_attention_mask = image_inputs['image_attention_mask'] + image_set_tensors = self.vision_encoder( + pixel_values, image_sizes, image_attention_mask + ) + if not has_audio: + inputs_embeds = self.model.embed_tokens(input_ids) + + inputs_embeds = self.merge_image_features_to_inputs_embeds( + input_ids, inputs_embeds, image_set_tensors + ) + + if has_image or has_audio: + # multi-modal input, we have set inputs_embeds properly in previous steps + input_ids = None + else: + # text-only, we keep using original input_ids + inputs_embeds = None + + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """ + Boilerplate method for computing logits (needed by the sampler). + + Args: + hidden_states (torch.Tensor): Hidden states. + sampling_metadata (SamplingMetadata): Sampling metadata. + + Returns: + Optional[torch.Tensor]: Logits. + """ + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + """ + Boilerplate method for sampling by VLLM. + + Args: + logits (torch.Tensor): Logits. + sampling_metadata (SamplingMetadata): Sampling metadata. + + Returns: + Optional[SamplerOutput]: Sampler output. + """ + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4551d81e8a5d..84e6e4126d29 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -181,6 +181,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), + "PhiOForCausalLM": ("phi3o", "PhiOForCausalLM"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py new file mode 100644 index 000000000000..924836eee239 --- /dev/null +++ b/vllm/model_executor/models/vision_siglip_navit.py @@ -0,0 +1,1722 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Siglip model configuration""" + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", +} + + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + Example: + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + Example: + ```python + >>> from transformers import SiglipConfig, SiglipModel + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): + r""" + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Siglip model.""" + + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union +from safetensors.torch import load_model, save_model + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.linspace(0, 1 - 1 / nb_patches_h, nb_patches_h) + fractional_coords_w = torch.linspace(0, 1 - 1 / nb_patches_w, nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ( + SiglipAttention(config) + if not getattr(config, "_flash_attn_2_enabled", False) + else SiglipFlashAttention2(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.tensor(0.0) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask=None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._flash_attn_2_enabled + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise ValueError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.text_model = SiglipTextTransformer(text_config) + self.vision_model = SiglipVisionTransformer(vision_config) + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + raise NotImplementedError("SigLIP loss to be implemented") + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): + siglip_vision_config = { + "hidden_size": 1152, + "image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + model_config = SiglipVisionConfig( + **siglip_vision_config, + _flash_attn_2_enabled=_flash_attn_2_enabled, + **kwargs + ) + + vision_model = SiglipVisionModel(model_config).vision_model + + return vision_model From d0fa70fcfae25cf647612229d1804c1d0c0aff42 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 5 Feb 2025 16:44:03 -0800 Subject: [PATCH 02/27] Fix error related to interface changes from latest main Signed-off-by: Congcong Chen --- examples/offline_inference_phi3o.py | 12 +++---- vllm/model_executor/models/phi4o.py | 45 +++++++++++--------------- vllm/model_executor/models/registry.py | 2 +- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/examples/offline_inference_phi3o.py b/examples/offline_inference_phi3o.py index 668834168049..3488ff3d80f3 100644 --- a/examples/offline_inference_phi3o.py +++ b/examples/offline_inference_phi3o.py @@ -105,7 +105,7 @@ def main_with_lora_speech(args: dict, activate_lora_request=True) -> None: # ) sampling_params = SamplingParams( temperature=0, - max_tokens=1200, + max_tokens=200, ) outputs = llm.generate(generate_args, sampling_params=sampling_params, lora_request= [LoRARequest("speech_adapter", 3, args.speech_lora_path)] if activate_lora_request else None) @@ -460,7 +460,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--model-path", "-p", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/merged/speech-merged-lora-and-base_model-from-hf-unified-model", + default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/MoE_weijian_phio-final_2-newtxtsftmore-hf", help="Path to the (HuggingFace) model checkpoint.", ) @@ -468,7 +468,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--vision-lora-path", "-v", type=str, - default="/modelblob/users/weijianxu/phi-o/vision-speech-merged-pretraining/official_run/Phio-SFT-long-001-DPO-002/merged-vision-mframerc1.2abl1-2.1k-speech-shadow50k-postsr002-posttrain-vision12k-trial2/vllm_lora/vision-lora-only-from-hf-unified-model/", + default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/vision-lora-only-from-hf-unified-model", help="Path to the (HuggingFace) vision lora model checkpoint.", ) @@ -476,8 +476,8 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--speech-lora-path", "-s", type=str, - default="/modelblob/users/weijianxu/phi-o/vision-speech-merged-pretraining/official_run/Phio-SFT-long-001-DPO-002/merged-vision-mframerc1.2abl1-2.1k-speech-shadow50k-postsr002-posttrain-vision12k-trial2/vllm_lora/speech-lora-only-from-hf-unified-model/", - help="Path to the (HuggingFace) vision lora model checkpoint.", + default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/speech-lora-only-from-hf-unified-model", + help="Path to the (HuggingFace) speech lora model checkpoint.", ) parser.add_argument( @@ -485,7 +485,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "-w", type=str, default= - "30s_test_6.wav", + "/scratch/turing_westus3_prm_data/users/congcongchen/30s_test_6.wav", help="Path to the audio file.", ) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py index ed65c7f7e747..76dc41b2f5cd 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4o.py @@ -25,9 +25,9 @@ from PIL import Image from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, MultiModalConfig, LoRAConfig +from vllm.config import VllmConfig from vllm.distributed import get_pp_group -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext, DummyData from vllm.inputs.data import token_inputs, TokenInputs from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -37,8 +37,8 @@ ) from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs) +from vllm.multimodal.inputs import MultiModalInputs, NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData from transformers.utils import logging @@ -1141,7 +1141,7 @@ def dummy_image_for_phi3v(width: int, height: int): def dummy_data_for_phi3s( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -) -> Tuple: +) -> DummyData: """ Create dummy sequence (input_ids) and audio data for the Phi-3.5-Speech model, which is used for profiling. @@ -1192,7 +1192,7 @@ def dummy_data_for_phi3s( mm_data = { "image": [dummy_image] * image_count, } - return seq_data, mm_data + return DummyData(seq_data, mm_data) def input_mapper_for_phi3s(ctx: InputContext, data: object) -> MultiModalInputs: @@ -1335,21 +1335,19 @@ class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + assert multimodal_config, "multimodal_config is required" + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.multimodal_config = multimodal_config - assert self.multimodal_config, "multimodal_config is required" self.quant_config = quant_config self.lora_config = lora_config - + # parity check for image processor parity_check_image_processor() @@ -1360,12 +1358,12 @@ def __init__( model_dir=config._name_or_path) - if isinstance(self.config.embd_layer["audio_embd_layer"], dict): + if isinstance(config.embd_layer["audio_embd_layer"], dict): embedding_config = { - "embedding_cls": self.config.embd_layer["audio_embd_layer"][ + "embedding_cls": config.embd_layer["audio_embd_layer"][ "embedding_cls" ], - **self.config.embd_layer["audio_embd_layer"], + **config.embd_layer["audio_embd_layer"], } else: embedding_config = { @@ -1373,14 +1371,7 @@ def __init__( } self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) - # self.language_model = LlamaForCausalLM( - # config, cache_config, quant_config - # ) - self.model = LlamaModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + self.model = LlamaModel(vllm_config=vllm_config, prefix="model") if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 84e6e4126d29..76114f471ad3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -181,7 +181,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), - "PhiOForCausalLM": ("phi3o", "PhiOForCausalLM"), + "PhiOForCausalLM": ("phi4o", "PhiOForCausalLM"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 From 85ef077a8539466d057a48ed9bc42fd17ad7757d Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 11:01:30 -0800 Subject: [PATCH 03/27] rename and clean up code Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4o.py | 34 +------------------ .../models/{phi3s_utils.py => phi4o_utils.py} | 0 2 files changed, 1 insertion(+), 33 deletions(-) rename vllm/model_executor/models/{phi3s_utils.py => phi4o_utils.py} (100%) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py index 76dc41b2f5cd..c0a2f3c5056f 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4o.py @@ -45,7 +45,7 @@ from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .vision_siglip_navit import get_siglip_vision_model -from .phi3s_utils import AudioEmbedding +from .phi4o_utils import AudioEmbedding from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -114,35 +114,6 @@ def get_max_phi3v_image_tokens(ctx: InputContext): return image_num_tokens -# image processor -def parity_check_image_processor(): - return - import requests - from PIL import Image - url = 'https://www.ilankelman.org/stopsigns/australia.jpg' - image = Image.open(requests.get(url, stream=True).raw) - image_inputs = preprocess( - [image], dynamic_hd_size=16, vit_resolution=448, vit_patch_size=14 - ) - image_inputs['input_image_embeds'] = image_inputs['pixel_values'] - - gt_dict = torch.load("examples/parity_processor.pt") - - print('image preprocessing parity check') - for k in gt_dict: - print(f"checking {k} ...") - gt = gt_dict[k] - pt = image_inputs[k] - if isinstance(gt_dict[k], torch.Tensor): - gt = gt.cpu() - pt = pt.cpu() - error = pt - gt - print(f"max difference: {torch.max(torch.abs(error))}") - else: - print(f"pt: {pt}") - print(f"gt: {gt}") - - def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) @@ -1347,9 +1318,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self.quant_config = quant_config self.lora_config = lora_config - - # parity check for image processor - parity_check_image_processor() self.vision_encoder = PhiOImageEncoder( config, diff --git a/vllm/model_executor/models/phi3s_utils.py b/vllm/model_executor/models/phi4o_utils.py similarity index 100% rename from vllm/model_executor/models/phi3s_utils.py rename to vllm/model_executor/models/phi4o_utils.py From d8f40d8b100175533befbcebcd9caf7e75c55162 Mon Sep 17 00:00:00 2001 From: Jacob Platin Date: Fri, 7 Feb 2025 12:29:51 -0800 Subject: [PATCH 04/27] Minor clean-up Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4o.py | 105 ++++++++++++---------- vllm/model_executor/models/phi4o_utils.py | 1 - 2 files changed, 58 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py index c0a2f3c5056f..2fcfe7b930db 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4o.py @@ -1,6 +1,4 @@ -import itertools import math -import os from functools import lru_cache import re from typing import ( @@ -20,7 +18,6 @@ import torch import torch.nn as nn import torchvision.transforms as T -from safetensors.torch import load_file from transformers import PretrainedConfig from PIL import Image @@ -33,11 +30,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE + ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE ) -from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel +from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs) +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalInputs, NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData @@ -46,8 +43,7 @@ from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .vision_siglip_navit import get_siglip_vision_model from .phi4o_utils import AudioEmbedding -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .utils import PPMissingLayer _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 # <|endoftext10|> (see vocab.json in hf model) @@ -538,19 +534,19 @@ def forward(self, return img_set_tensor -class Phi3SAudioFeatureInputs(TypedDict): +class Phi4OAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: Tuple[NestedTensors] """Shape: `((batch_size, num_audios, 80, M), )""" -class Phi3SAudioEmbeddingInputs(TypedDict): +class Phi4OAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" -Phi3SAudioInputs = Union[Phi3SAudioFeatureInputs, Phi3SAudioEmbeddingInputs] +Phi4OAudioInputs = Union[Phi4OAudioFeatureInputs, Phi4OAudioEmbeddingInputs] def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): @@ -850,6 +846,17 @@ def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: def _get_audio_embed_sizes(audios, ctx: InputContext): + """ + Get the audio embedding sizes for each audio file. + + Args: + audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of + waveform and sample rate. + ctx (InputContext): Input context. + + Returns: + List[int]: List of audio embedding sizes. + """ audio_embed_sizes = [] for audio in audios: audio_data, sf = audio @@ -862,14 +869,25 @@ def _get_audio_embed_sizes(audios, ctx: InputContext): def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): + """ + The following will search for `<|audio_{idx}|>` tokens and + return a mapping of audio placeholder tokens to audio placeholder token ids + based on the size of the audio embeddings. + + Args: + audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of + waveform and sample rate. + ctx (InputContext): Input context. + prompt_str (str): The prompt string. + + Returns: + Dict[str, List[int]]: Mapping of audio placeholder tokens to audio placeholder token ids. + + """ if len(audios) == 0: return {} audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - - # The following logic will search for `<|audio_{idx}|>` tokens and - # insert the placeholder audio tokens that will be overwritten by the - # embedding in the audio tower audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) audio_ids = [int(audio_id) for audio_id in audio_ids] assert len(audio_ids) == len(audio_embed_sizes), "Number of audio tokens and audio features do not match" @@ -944,11 +962,10 @@ def input_processor_for_phio( TokenInputs: Processed inputs """ multi_modal_data = inputs.get("multi_modal_data") - # Check if audio is being used as a modality if (multi_modal_data is None or ("audio" not in multi_modal_data and "image" not in multi_modal_data)): - # pure text input + # pure text input, so no need to do pre-processing return inputs prompt_str = inputs.get("prompt") @@ -1070,8 +1087,11 @@ def input_processor_for_phio( def _compute_audio_embed_size(hf_config, audio_frames): + """ + Compute the audio embedding size based on the audio frames and compression rate. + """ compression_rate = hf_config.embd_layer['audio_embd_layer']['compression_rate'] - # TODO: update this hard-coded value? + # NOTE: this is a hard-coded value but might be configurable in the future qformer_compression_rate = 1 integer = audio_frames // compression_rate remainder = audio_frames % compression_rate @@ -1085,15 +1105,12 @@ def _compute_audio_embed_size(hf_config, audio_frames): return result -def get_max_phi3s_audio_tokens(ctx: InputContext): - # TODO +def get_max_phi4o_audio_tokens(ctx: InputContext) -> int: return 10000 - # return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) - -def dummy_audio_for_phi3s(audio_count: int) -> dict: +def dummy_audio_for_phi4o(audio_count: int) -> dict: """ - Create dummy audio data for the Phi-3.5-Speech model, which is used for profiling. + Create dummy audio data for the Phi-4O model, which is used for profiling. Args: audio_count (int): Number of audio samples. @@ -1110,11 +1127,11 @@ def dummy_image_for_phi3v(width: int, height: int): return image -def dummy_data_for_phi3s( +def dummy_data_for_phi4o( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ) -> DummyData: """ - Create dummy sequence (input_ids) and audio data for the Phi-3.5-Speech model, which is used for + Create dummy sequence (input_ids) and audio data for the Phi-4O model, which is used for profiling. In this case, the sequence data is a bunch of 0s with a number of audio tokens that correspond @@ -1143,7 +1160,7 @@ def dummy_data_for_phi3s( if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: raise RuntimeError( - f"Phi3O cannot process {audio_count} audios and {image_count} images in a prompt," + f"Phi4O cannot process {audio_count} audios and {image_count} images in a prompt," f"please increase max_model_len to be at larger than {audio_feature_size * audio_count + total_image_tokens}" " or reduce audio/image limit by --limit-mm-per-prompt.") @@ -1153,7 +1170,7 @@ def dummy_data_for_phi3s( (0, seq_len - audio_feature_size * audio_count), ) mm_data = { - "audio": dummy_audio_for_phi3s(audio_count), + "audio": dummy_audio_for_phi4o(audio_count), } else: seq_data = SequenceData.from_prompt_token_counts( @@ -1166,9 +1183,9 @@ def dummy_data_for_phi3s( return DummyData(seq_data, mm_data) -def input_mapper_for_phi3s(ctx: InputContext, data: object) -> MultiModalInputs: +def input_mapper_for_phi4o_audio(ctx: InputContext, data: object) -> MultiModalInputs: """ - This function is used to create the MultiModalInputs for the Phi-3.5-Speech model. + This function is used to create the MultiModalInputs for the Phi-4O (audio) model. Specifically, for audio, we extract the audio features from the sound file and create pairs of audio features and audio embed lengths (the latter of which is used to repeat the audio placeholder token in the input prompt IDs). @@ -1269,19 +1286,19 @@ def cat_with_pad(tensors, dim, padding_value=0): return output -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi3s) +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi4o_audio) @MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi3v) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_phi3s_audio_tokens + "audio", get_max_phi4o_audio_tokens ) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "image", get_max_phi3v_image_tokens ) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3s) # TODO dummy data for vision? +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4o) @INPUT_REGISTRY.register_input_processor(input_processor_for_phio) class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP): """ - Implements the Phi-3.5-Omni model in VLLM. + Implements the Phi-4-Omni model in VLLM. Args: config (PretrainedConfig): Pretrained model configuration. @@ -1407,7 +1424,7 @@ def _audio_features_to_embeddings( def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[Phi3SAudioInputs]: + ) -> Optional[Phi4OAudioInputs]: """ Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for now. @@ -1416,7 +1433,7 @@ def _parse_and_validate_audio_input( kwargs (object): Keyword arguments. Returns: - Optional[Phi3SAudioInputs]: Parsed and validated audio inputs. + Optional[Phi4OAudioInputs]: Parsed and validated audio inputs. """ audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) @@ -1431,7 +1448,7 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_features)}" ) - return Phi3SAudioFeatureInputs( + return Phi4OAudioFeatureInputs( type="audio_features", data=audio_features ) @@ -1442,24 +1459,24 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_embeds)}" ) - return Phi3SAudioEmbeddingInputs( + return Phi4OAudioEmbeddingInputs( type="audio_embeds", data=audio_embeds ) raise AssertionError("This line should be unreachable.") def _process_audio_input( - self, input_ids: torch.Tensor, audio_input: Phi3SAudioInputs, audio_projection_mode: str + self, input_ids: torch.Tensor, audio_input: Phi4OAudioInputs, audio_projection_mode: str ) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. The audio input is created by - `input_mapper_for_phi3s`. + `input_mapper_for_phi4o_audio`. Args: input_ids (torch.Tensor): Input IDs (the prompt in this case, before the audio token replication). - audio_input (Phi3SAudioInputs): Audio input. + audio_input (Phi4OAudioInputs): Audio input. Returns: NestedTensors: Audio embeddings @@ -1576,7 +1593,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: name = name.replace(audio_proj_4v, "embed_tokens_extend.audio_projection_for_vision") name = ( - # name.replace("model.embed_tokens.", "embed_tokens.") name.replace( "model.embed_tokens_extend.audio_embed.audio_projection.speech.", "embed_tokens_extend.audio_projection.", @@ -1593,11 +1609,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: name = name.replace(".base_layer.weight", ".weight") adjusted_weights[name] = weight - # if name == "model.embed_tokens.weight": - # adjusted_weights["embed_tokens.weight"] = ( - # weight - # ) - missing_keys, unexpected_keys = self.load_state_dict( adjusted_weights, strict=False ) diff --git a/vllm/model_executor/models/phi4o_utils.py b/vllm/model_executor/models/phi4o_utils.py index 7387063a8214..6b67aae5013c 100644 --- a/vllm/model_executor/models/phi4o_utils.py +++ b/vllm/model_executor/models/phi4o_utils.py @@ -3,7 +3,6 @@ # Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) # but implemented by the Phi-Speech team #!/usr/bin/env python3 -"""ConformerEncoder Module""" import abc import backoff from functools import partial From a569e4e8f6e4cbaabef9ef1e4e837c904c855458 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 13:28:03 -0800 Subject: [PATCH 05/27] Do not support Tensor Parallel and Pipeline Parallel Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4o.py | 72 +++++++++++++---------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py index 2fcfe7b930db..982c576d9274 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4o.py @@ -23,7 +23,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed import get_pp_group +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext, DummyData from vllm.inputs.data import token_inputs, TokenInputs from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -43,7 +43,7 @@ from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .vision_siglip_navit import get_siglip_vision_model from .phi4o_utils import AudioEmbedding -from .utils import PPMissingLayer +from .utils import PPMissingLayer, maybe_prefix _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 # <|endoftext10|> (see vocab.json in hf model) @@ -1296,15 +1296,9 @@ def cat_with_pad(tensors, dim, padding_value=0): ) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4o) @INPUT_REGISTRY.register_input_processor(input_processor_for_phio) -class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP): +class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ - Implements the Phi-4-Omni model in VLLM. - - Args: - config (PretrainedConfig): Pretrained model configuration. - multimodal_config (MultiModalConfig): Multi-modal configuration. - cache_config (Optional[CacheConfig]): Cache configuration. - quant_config (Optional[QuantizationConfig]): Quantization configuration. + Implements the Phi-4-multimodal-instruct model in VLLM. """ # LoRA specific attributes packed_modules_mapping = { @@ -1336,6 +1330,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.lora_config = lora_config + # Tensor/Pipeline parallel not supported for now. + assert get_tensor_model_parallel_world_size() == 1, "tensor parallel is not supported" + assert get_pp_group().world_size == 1, "pipeline parallel is not supported" + self.vision_encoder = PhiOImageEncoder( config, quant_config, @@ -1356,35 +1354,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): } self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) - self.model = LlamaModel(vllm_config=vllm_config, prefix="model") - if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), - quant_config=quant_config, - ) - if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) - self.sampler = Sampler() - else: - self.lm_head = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model = LlamaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() def _audio_features_to_embeddings( self, From f7b8579e0690d6e50f84ec1d75e196b5f2bd74c4 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 14:31:31 -0800 Subject: [PATCH 06/27] clean up code Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4o.py | 49 ++--------------------------- 1 file changed, 2 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py index 982c576d9274..76441b1fb062 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4o.py @@ -1567,16 +1567,6 @@ def merge_image_features_to_inputs_embeds( return merged_embeds def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: - """ - Load in the weights for VLLM. - - NOTE: I highly recommend avoiding the VLLM WeightMapper/Autoloader approach - since it's needlessly complex. - - Args: - weights (Iterable[Tuple[str, torch.Tensor]]): Weights to load (usualy from a - Hugging Face model). - """ weights = {name: weight for name, weight in weights} adjusted_weights = {} @@ -1606,10 +1596,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: missing_keys, unexpected_keys = self.load_state_dict( adjusted_weights, strict=False ) - logger.debug("--------------- missing keys -----------------") + logger.debug("*** missing keys:") for key in missing_keys: logger.debug(key) - logger.debug("--------------- unexpected keys ---------------") + logger.debug("**** unexpected keys:") for key in unexpected_keys: logger.debug(key) @@ -1622,21 +1612,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: - """ - Run the forward pass of the model. - - Args: - input_ids (torch.Tensor): Input IDs. - positions (torch.Tensor): Positions (handled by VLLM) - kv_caches (List[torch.Tensor]): Key-value caches (handled by VLLM) - attn_metadata (AttentionMetadata): Attention metadata (handled by VLLM) - intermediate_tensors (Optional[IntermediateTensors]): Intermediate tensors - (handled by VLLM) - kwargs (object): Keyword arguments. NOTE: this should contain the audio/MM input. - - Returns: - torch.Tensor: Hidden states / model output. - """ if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -1692,16 +1667,6 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - """ - Boilerplate method for computing logits (needed by the sampler). - - Args: - hidden_states (torch.Tensor): Hidden states. - sampling_metadata (SamplingMetadata): Sampling metadata. - - Returns: - Optional[torch.Tensor]: Logits. - """ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits @@ -1711,15 +1676,5 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - """ - Boilerplate method for sampling by VLLM. - - Args: - logits (torch.Tensor): Logits. - sampling_metadata (SamplingMetadata): Sampling metadata. - - Returns: - Optional[SamplerOutput]: Sampler output. - """ next_tokens = self.sampler(logits, sampling_metadata) return next_tokens From 375e2464886c8a002966469d848e2e14ad4c7e6a Mon Sep 17 00:00:00 2001 From: Yen-Chun Chen Date: Fri, 7 Feb 2025 13:53:51 -0800 Subject: [PATCH 07/27] code cleaning / renaming for the vision part Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4o.py | 69 +++++++++++------------------ 1 file changed, 26 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4o.py index 76441b1fb062..679972e30882 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4o.py @@ -66,9 +66,9 @@ }, } logger = logging.get_logger(__name__) -# This is a workaround to prevent text (user input) + audio + image from being used in +# This is a workaround to prevent text (user input) + audio + image from being used in # the same prompt. -# It includes token ids for "/n" and tokens in added_tokens_decoder from the +# It includes token ids for "/n" and tokens in added_tokens_decoder from the # tokenizer_confg.json file. NON_USER_INPUT_TOKENS = {198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, 200023, 200024, 200025, 200026, 200027, 200028} @@ -82,12 +82,12 @@ def get_max_dummy_image(ctx: InputContext): vit_image_size = prepro_config['vit_image_size'] max_side = vit_image_size * dynamic_hd_size - dummy_image = dummy_image_for_phi3v(vit_image_size, max_side) + dummy_image = dummy_image_for_phi4o(vit_image_size, max_side) return dummy_image # image token length -def get_max_phi3v_image_tokens(ctx: InputContext): +def get_max_phi4o_image_tokens(ctx: InputContext): dummy_image = get_max_dummy_image(ctx) hf_config = ctx.get_hf_config() @@ -288,12 +288,8 @@ def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): "image_attention_mask": returned_image_attention_mask, "num_img_tokens": returned_num_img_tokens, } - # data = [returned_input_image_embeds, returned_image_sizes, returned_image_attention_mask, returned_num_img_tokens] return data - # return BatchFeature(data=data, tensor_type=return_tensors) - - class PhiOImageEncoder(nn.Module): """Image embedding.""" @@ -313,8 +309,6 @@ def __init__(self, else: self.drop = None - # logger.info(f"create image tower {config.img_processor}") - # layer_idx to output the img features if isinstance(config.img_processor, dict): self.layer_idx = config.img_processor.get('layer_idx', -2) @@ -329,7 +323,7 @@ def __init__(self, L, D = pe_weight.size() H = int(math.sqrt(L)) assert H**2 == L, f'position embedding size {L} is not square' - if H % 2 != 0: #and kwargs.get('image_token_compression_cls', None) is None: + if H % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) H += 1 image_dim_out = D @@ -344,15 +338,14 @@ def __init__(self, # global_gn and sub_gn for hd transform, serves as line separator - self.use_hd_transform = True # kwargs.get('use_hd_transform', False) - self.with_learnable_separator = True # kwargs.get('with_learnable_separator', False) - self.hd_transform_order = "sub_glb" # kwargs.get('hd_transform_order', 'glb_sub') - self.freeze_img_processor = False # kwargs.get('freeze_img_processor', False) - self.crop_size = 448 # kwargs.get('crop_size', 336) - # logger.info(f'freeze_img_processor = {self.freeze_img_processor}') + self.use_hd_transform = True + self.with_learnable_separator = True + self.hd_transform_order = "sub_glb" + self.freeze_img_processor = False + self.crop_size = 448 # image token compression - self.image_token_compression_cls = 'avg_pool_2d' # kwargs.get('image_token_compression_cls', None) + self.image_token_compression_cls = 'avg_pool_2d' self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) self.base_feat_height_reduction = 1 self.base_feat_height_target = self.base_feat_height_target // 2 @@ -363,9 +356,7 @@ def __init__(self, # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) - # logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') - projection_cls = "mlp" # kwargs.get('projection_cls', 'linear') dim_projection = hidden_size depth = 2 layers = [nn.Linear(image_dim_out * self.base_feat_height_reduction**2, dim_projection)] @@ -378,7 +369,7 @@ def __init__(self, self.vocab_size = config.vocab_size self.img_features = None - self.use_out_place_operations = False # kwargs.get('use_out_place_operations', False) + self.use_out_place_operations = False def get_img_features(self, img_embeds: torch.FloatTensor, @@ -414,8 +405,6 @@ def get_img_features(self, return patch_feature - - # logger.info(f'processed img feature size = {img_feature.size()}') raise NotImplementedError def forward(self, @@ -452,10 +441,6 @@ def forward(self, img_features = self.get_img_features(pixel_values, image_attention_mask.type(torch.BoolTensor).flatten(0,1).to(target_device)) - fake_image_forward = False - select = False - hd_transform = False - base_feat_height_target = self.base_feat_height_target base_resolution = self.crop_size base_feat_height_reduction = self.base_feat_height_reduction @@ -525,7 +510,6 @@ def forward(self, assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' output_len.append(temp_len) - num_img_tokens = output_len img_set_tensor = [] for _output_img in output_imgs: img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) @@ -790,7 +774,7 @@ def _compute_num_image_tokens( feat_width += 1 if non_pad_feat_height % token_compression_factor != 0: feat_height += 1 - num_hd_patch_tokens = feat_width * feat_height # FIXME bug: 1504, should be 1536 + num_hd_patch_tokens = feat_width * feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 @@ -853,7 +837,7 @@ def _get_audio_embed_sizes(audios, ctx: InputContext): audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of waveform and sample rate. ctx (InputContext): Input context. - + Returns: List[int]: List of audio embedding sizes. """ @@ -879,10 +863,10 @@ def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): waveform and sample rate. ctx (InputContext): Input context. prompt_str (str): The prompt string. - + Returns: Dict[str, List[int]]: Mapping of audio placeholder tokens to audio placeholder token ids. - + """ if len(audios) == 0: return {} @@ -1088,7 +1072,7 @@ def input_processor_for_phio( def _compute_audio_embed_size(hf_config, audio_frames): """ - Compute the audio embedding size based on the audio frames and compression rate. + Compute the audio embedding size based on the audio frames and compression rate. """ compression_rate = hf_config.embd_layer['audio_embd_layer']['compression_rate'] # NOTE: this is a hard-coded value but might be configurable in the future @@ -1122,7 +1106,7 @@ def dummy_audio_for_phi4o(audio_count: int) -> dict: return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count -def dummy_image_for_phi3v(width: int, height: int): +def dummy_image_for_phi4o(width: int, height: int): image = Image.new('RGB', (width, height), color='black') return image @@ -1155,7 +1139,7 @@ def dummy_data_for_phi4o( image_count = mm_counts["image"] dummy_image = get_max_dummy_image(ctx) - max_image_tokens = get_max_phi3v_image_tokens(ctx) + max_image_tokens = get_max_phi4o_image_tokens(ctx) total_image_tokens = image_count * max_image_tokens if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: @@ -1163,7 +1147,7 @@ def dummy_data_for_phi4o( f"Phi4O cannot process {audio_count} audios and {image_count} images in a prompt," f"please increase max_model_len to be at larger than {audio_feature_size * audio_count + total_image_tokens}" " or reduce audio/image limit by --limit-mm-per-prompt.") - + if audio_feature_size * audio_count > total_image_tokens: seq_data = SequenceData.from_prompt_token_counts( (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), @@ -1235,11 +1219,10 @@ def input_mapper_for_phi4o_audio(ctx: InputContext, data: object) -> MultiModalI return MultiModalInputs({"audio_features": audio_features}) -def input_mapper_for_phi3v(ctx: InputContext, data: object): - # data: list of PIL images - # assert isinstance(data, list), "Data must be a list of PIL images" +def input_mapper_for_phi4o_image(ctx: InputContext, data: object): if not isinstance(data, list): data = [data] + # data: list of PIL images if len(data) == 0: return MultiModalInputs() hf_config = ctx.get_hf_config() @@ -1287,12 +1270,12 @@ def cat_with_pad(tensors, dim, padding_value=0): @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi4o_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi3v) +@MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi4o_image) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_max_phi4o_audio_tokens ) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "image", get_max_phi3v_image_tokens + "image", get_max_phi4o_image_tokens ) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4o) @INPUT_REGISTRY.register_input_processor(input_processor_for_phio) @@ -1324,7 +1307,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert multimodal_config, "multimodal_config is required" quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - + self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config @@ -1379,7 +1362,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, logit_scale) self.sampler = Sampler() - + def _audio_features_to_embeddings( self, input_ids: torch.Tensor, From f8a3373e539b764213f10d43b8d028548b0f8975 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 16:17:39 -0800 Subject: [PATCH 08/27] rename phi4o with phi4mm Signed-off-by: Congcong Chen --- vllm/model_executor/models/{phi4o.py => phi4mm.py} | 2 +- vllm/model_executor/models/{phi4o_utils.py => phi4mm_utils.py} | 0 vllm/model_executor/models/registry.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename vllm/model_executor/models/{phi4o.py => phi4mm.py} (99%) rename vllm/model_executor/models/{phi4o_utils.py => phi4mm_utils.py} (100%) diff --git a/vllm/model_executor/models/phi4o.py b/vllm/model_executor/models/phi4mm.py similarity index 99% rename from vllm/model_executor/models/phi4o.py rename to vllm/model_executor/models/phi4mm.py index 679972e30882..572f82bd0094 100644 --- a/vllm/model_executor/models/phi4o.py +++ b/vllm/model_executor/models/phi4mm.py @@ -42,7 +42,7 @@ from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .vision_siglip_navit import get_siglip_vision_model -from .phi4o_utils import AudioEmbedding +from .phi4mm_utils import AudioEmbedding from .utils import PPMissingLayer, maybe_prefix diff --git a/vllm/model_executor/models/phi4o_utils.py b/vllm/model_executor/models/phi4mm_utils.py similarity index 100% rename from vllm/model_executor/models/phi4o_utils.py rename to vllm/model_executor/models/phi4mm_utils.py diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 76114f471ad3..4775766ccc5a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -181,7 +181,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), - "PhiOForCausalLM": ("phi4o", "PhiOForCausalLM"), + "PhiOForCausalLM": ("phi4mm", "PhiOForCausalLM"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 From 7b73371c5ad82b88c56bb35482fa5a587a979535 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 16:24:22 -0800 Subject: [PATCH 09/27] refactor change phi4o to phi4mm Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4mm.py | 44 +++++++++++++------------- vllm/model_executor/models/registry.py | 2 +- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 572f82bd0094..9badc4fa1041 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -82,12 +82,12 @@ def get_max_dummy_image(ctx: InputContext): vit_image_size = prepro_config['vit_image_size'] max_side = vit_image_size * dynamic_hd_size - dummy_image = dummy_image_for_phi4o(vit_image_size, max_side) + dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) return dummy_image # image token length -def get_max_phi4o_image_tokens(ctx: InputContext): +def get_max_phi4mm_image_tokens(ctx: InputContext): dummy_image = get_max_dummy_image(ctx) hf_config = ctx.get_hf_config() @@ -930,7 +930,7 @@ def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): return image_id_to_input_ids -def input_processor_for_phio( +def input_processor_for_phi4mm( ctx: InputContext, inputs: DecoderOnlyInputs ) -> TokenInputs: """ @@ -1016,7 +1016,7 @@ def input_processor_for_phio( break input_ids.extend(curr_token_ids) if has_audio and has_imag and has_user_text_input: - raise ValueError("PhiOForCausalLM does not support text + audio + image" + + raise ValueError("Phi4MMForCausalLM does not support text + audio + image" + " inputs in the same prompt") # Handle the case where the prompt is already tokenized else: @@ -1044,7 +1044,7 @@ def input_processor_for_phio( i += token_count if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: - raise ValueError("PhiOForCausalLM does not support text + audio + image" + + raise ValueError("Phi4MMForCausalLM does not support text + audio + image" + " inputs in the same prompt") # If the below assertion fails, it might be that input pure-text # messages contain image/audio special tokens literally @@ -1089,10 +1089,10 @@ def _compute_audio_embed_size(hf_config, audio_frames): return result -def get_max_phi4o_audio_tokens(ctx: InputContext) -> int: +def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: return 10000 -def dummy_audio_for_phi4o(audio_count: int) -> dict: +def dummy_audio_for_phi4mm(audio_count: int) -> dict: """ Create dummy audio data for the Phi-4O model, which is used for profiling. @@ -1106,12 +1106,12 @@ def dummy_audio_for_phi4o(audio_count: int) -> dict: return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count -def dummy_image_for_phi4o(width: int, height: int): +def dummy_image_for_phi4mm(width: int, height: int): image = Image.new('RGB', (width, height), color='black') return image -def dummy_data_for_phi4o( +def dummy_data_for_phi4mm( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ) -> DummyData: """ @@ -1139,7 +1139,7 @@ def dummy_data_for_phi4o( image_count = mm_counts["image"] dummy_image = get_max_dummy_image(ctx) - max_image_tokens = get_max_phi4o_image_tokens(ctx) + max_image_tokens = get_max_phi4mm_image_tokens(ctx) total_image_tokens = image_count * max_image_tokens if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: @@ -1154,7 +1154,7 @@ def dummy_data_for_phi4o( (0, seq_len - audio_feature_size * audio_count), ) mm_data = { - "audio": dummy_audio_for_phi4o(audio_count), + "audio": dummy_audio_for_phi4mm(audio_count), } else: seq_data = SequenceData.from_prompt_token_counts( @@ -1167,7 +1167,7 @@ def dummy_data_for_phi4o( return DummyData(seq_data, mm_data) -def input_mapper_for_phi4o_audio(ctx: InputContext, data: object) -> MultiModalInputs: +def input_mapper_for_phi4mm_audio(ctx: InputContext, data: object) -> MultiModalInputs: """ This function is used to create the MultiModalInputs for the Phi-4O (audio) model. Specifically, for audio, we extract the audio features from the sound file and create @@ -1219,7 +1219,7 @@ def input_mapper_for_phi4o_audio(ctx: InputContext, data: object) -> MultiModalI return MultiModalInputs({"audio_features": audio_features}) -def input_mapper_for_phi4o_image(ctx: InputContext, data: object): +def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): if not isinstance(data, list): data = [data] # data: list of PIL images @@ -1269,17 +1269,17 @@ def cat_with_pad(tensors, dim, padding_value=0): return output -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi4o_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi4o_image) +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi4mm_audio) +@MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi4mm_image) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_phi4o_audio_tokens + "audio", get_max_phi4mm_audio_tokens ) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "image", get_max_phi4o_image_tokens + "image", get_max_phi4mm_image_tokens ) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4o) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phio) -class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) +@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) +class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in VLLM. """ @@ -1296,7 +1296,7 @@ class PhiOForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): supported_lora_modules = [ "qkv_proj", "o_proj", "gate_up_proj", "down_proj" ] - # PhiOForCausalLM does not apply LoRA to the embedding layer. + # Phi4MMForCausalLM does not apply LoRA to the embedding layer. embedding_modules = {} embedding_padding_modules = [] @@ -1448,7 +1448,7 @@ def _process_audio_input( """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. The audio input is created by - `input_mapper_for_phi4o_audio`. + `input_mapper_for_phi4mm_audio`. Args: input_ids (torch.Tensor): Input IDs (the prompt in this case, before the audio token diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4775766ccc5a..fbdbbbc0c4b5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -181,7 +181,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), - "PhiOForCausalLM": ("phi4mm", "PhiOForCausalLM"), + "PhiOForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 From 30a70d5412aac59da4bfd09d7e43707b8a4c1526 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 16:27:05 -0800 Subject: [PATCH 10/27] refactor change phi4o to phi4mm continued Signed-off-by: Congcong Chen --- vllm/lora/models.py | 3 -- vllm/model_executor/models/phi4mm.py | 32 +++++++++++----------- vllm/model_executor/models/phi4mm_utils.py | 2 +- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 26fe835d02e8..e1294884ac2a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -389,7 +389,6 @@ def activate_adapter( for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: - logger.debug(f"setting {module_name}") module_lora.optimize() # Bias is not explicitly enabled with the flag enable_lora_bias. bias = module_lora.bias @@ -405,7 +404,6 @@ def activate_adapter( module_lora.embeddings_tensor, module_lora.bias) else: - logger.debug(f"resetting {module_name}") module.reset_lora(index) return True @@ -507,7 +505,6 @@ def _create_lora_modules(self): # aims to prevent this error if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): - logger.debug("-------- Skipping %s --------", module_name) continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 9badc4fa1041..11cb97844583 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -291,7 +291,7 @@ def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): return data -class PhiOImageEncoder(nn.Module): +class Phi4MMImageEncoder(nn.Module): """Image embedding.""" def __init__(self, @@ -518,19 +518,19 @@ def forward(self, return img_set_tensor -class Phi4OAudioFeatureInputs(TypedDict): +class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: Tuple[NestedTensors] """Shape: `((batch_size, num_audios, 80, M), )""" -class Phi4OAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" -Phi4OAudioInputs = Union[Phi4OAudioFeatureInputs, Phi4OAudioEmbeddingInputs] +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): @@ -726,7 +726,7 @@ def _compute_num_image_tokens( architecture and exclude output features containing only padding pixels for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map - NOTE right now, Phi-O uses hard-coded token_compression_factor=2 + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 """ assert vit_image_size % vit_patch_size == 0, "vit_image_size must be divisible by vit_patch_size" assert vit_image_size // vit_patch_size % token_compression_factor == 0, "vit_image_size // vit_patch_size must be divisible by token_compression_factor" @@ -1094,7 +1094,7 @@ def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: def dummy_audio_for_phi4mm(audio_count: int) -> dict: """ - Create dummy audio data for the Phi-4O model, which is used for profiling. + Create dummy audio data for the Phi4MM model, which is used for profiling. Args: audio_count (int): Number of audio samples. @@ -1115,7 +1115,7 @@ def dummy_data_for_phi4mm( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ) -> DummyData: """ - Create dummy sequence (input_ids) and audio data for the Phi-4O model, which is used for + Create dummy sequence (input_ids) and audio data for the Phi4MM model, which is used for profiling. In this case, the sequence data is a bunch of 0s with a number of audio tokens that correspond @@ -1144,7 +1144,7 @@ def dummy_data_for_phi4mm( if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: raise RuntimeError( - f"Phi4O cannot process {audio_count} audios and {image_count} images in a prompt," + f"Phi4MM cannot process {audio_count} audios and {image_count} images in a prompt," f"please increase max_model_len to be at larger than {audio_feature_size * audio_count + total_image_tokens}" " or reduce audio/image limit by --limit-mm-per-prompt.") @@ -1169,7 +1169,7 @@ def dummy_data_for_phi4mm( def input_mapper_for_phi4mm_audio(ctx: InputContext, data: object) -> MultiModalInputs: """ - This function is used to create the MultiModalInputs for the Phi-4O (audio) model. + This function is used to create the MultiModalInputs for the Phi4MM (audio) model. Specifically, for audio, we extract the audio features from the sound file and create pairs of audio features and audio embed lengths (the latter of which is used to repeat the audio placeholder token in the input prompt IDs). @@ -1317,7 +1317,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert get_tensor_model_parallel_world_size() == 1, "tensor parallel is not supported" assert get_pp_group().world_size == 1, "pipeline parallel is not supported" - self.vision_encoder = PhiOImageEncoder( + self.vision_encoder = Phi4MMImageEncoder( config, quant_config, prefix="model.vision_embed_tokens", @@ -1401,7 +1401,7 @@ def _audio_features_to_embeddings( def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[Phi4OAudioInputs]: + ) -> Optional[Phi4MMAudioInputs]: """ Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for now. @@ -1410,7 +1410,7 @@ def _parse_and_validate_audio_input( kwargs (object): Keyword arguments. Returns: - Optional[Phi4OAudioInputs]: Parsed and validated audio inputs. + Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. """ audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) @@ -1425,7 +1425,7 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_features)}" ) - return Phi4OAudioFeatureInputs( + return Phi4MMAudioFeatureInputs( type="audio_features", data=audio_features ) @@ -1436,14 +1436,14 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_embeds)}" ) - return Phi4OAudioEmbeddingInputs( + return Phi4MMAudioEmbeddingInputs( type="audio_embeds", data=audio_embeds ) raise AssertionError("This line should be unreachable.") def _process_audio_input( - self, input_ids: torch.Tensor, audio_input: Phi4OAudioInputs, audio_projection_mode: str + self, input_ids: torch.Tensor, audio_input: Phi4MMAudioInputs, audio_projection_mode: str ) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of @@ -1453,7 +1453,7 @@ def _process_audio_input( Args: input_ids (torch.Tensor): Input IDs (the prompt in this case, before the audio token replication). - audio_input (Phi4OAudioInputs): Audio input. + audio_input (Phi4MMAudioInputs): Audio input. Returns: NestedTensors: Audio embeddings diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index 6b67aae5013c..e08a377f7d24 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -3354,7 +3354,7 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: self.audio_embed_sizes = None def post_init(self, audio_config): - # execute after the from_pretrained() initialization of the phi3 model + # execute after the from_pretrained() initialization of the phi model if audio_config.get("name", None) == "cascades": init_model_config = audio_config.get("init_model", {}) self.encoder.post_init(init_model_config) From 99a636dbc24480775833db5956b3a5fc4ab4a8fd Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 7 Feb 2025 16:41:22 -0800 Subject: [PATCH 11/27] final update to change phio to phi4mm Signed-off-by: Congcong Chen --- vllm/entrypoints/chat_utils.py | 4 ++-- vllm/model_executor/models/registry.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6a1706d828b2..8f906cf1d80b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -395,7 +395,7 @@ def _placeholder_str(self, modality: ModalityStr, if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return f"<|image_{current_count}|>" - if model_type == "phio": + if model_type == "phi4mm": return "<|endoftext10|>" # 200010 (see vocab.json in hf model) if model_type in ("minicpmo", "minicpmv"): return "(./)" @@ -426,7 +426,7 @@ def _placeholder_str(self, modality: ModalityStr, elif modality == "audio": if model_type == "ultravox": return "<|audio|>" - if model_type == "phio": + if model_type == "phi4mm": return "<|endoftext11|>" # 200011 (see vocab.json in hf model) if model_type == "qwen2_audio": return (f"Audio {current_count}: " diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index fbdbbbc0c4b5..dbdcf2e1cf4a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -181,7 +181,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), - "PhiOForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), + "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 From 5dcf783a388d71e9e0fb8c7a7fd982dee4d27168 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 25 Feb 2025 15:01:36 -0800 Subject: [PATCH 12/27] Fix errors after rebasing to the top of the main Signed-off-by: Congcong Chen --- examples/offline_inference_phi3o.py | 6 +-- requirements-common.txt | 2 + vllm/model_executor/models/phi4mm.py | 8 +-- vllm/model_executor/models/phi4mm_utils.py | 60 ---------------------- 4 files changed, 7 insertions(+), 69 deletions(-) diff --git a/examples/offline_inference_phi3o.py b/examples/offline_inference_phi3o.py index 3488ff3d80f3..9df355a4be19 100644 --- a/examples/offline_inference_phi3o.py +++ b/examples/offline_inference_phi3o.py @@ -460,7 +460,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--model-path", "-p", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/MoE_weijian_phio-final_2-newtxtsftmore-hf", + default="/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm", help="Path to the (HuggingFace) model checkpoint.", ) @@ -468,7 +468,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--vision-lora-path", "-v", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/vision-lora-only-from-hf-unified-model", + default="/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/vision-lora", help="Path to the (HuggingFace) vision lora model checkpoint.", ) @@ -476,7 +476,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--speech-lora-path", "-s", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/final_checkpoint_new/speech-lora-only-from-hf-unified-model", + default="/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/speech-lora", help="Path to the (HuggingFace) speech lora model checkpoint.", ) diff --git a/requirements-common.txt b/requirements-common.txt index fb84d6d9e7b6..77a8ff701817 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -38,3 +38,5 @@ depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md +scipy # Required for phi-4-multimodal-instruct +flash_attn # Required for phi-4-multimodal-instruct \ No newline at end of file diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 11cb97844583..40b5c2401542 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -36,7 +36,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalInputs, NestedTensors -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.sequence import IntermediateTensors, SequenceData from transformers.utils import logging @@ -997,7 +997,7 @@ def input_processor_for_phi4mm( prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] # Create the new input_ids with the placholder image and audio tokens inserted - tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + tokenizer = cached_tokenizer_from_config(ctx.model_config) input_ids = [] has_imag, has_audio, has_user_text_input = False, False, False for prompt_chunk_string in prompt_chunk_strings: @@ -1590,8 +1590,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: @@ -1637,8 +1635,6 @@ def forward( hidden_states = self.model( input_ids, positions, - kv_caches, - attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds, ) diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index e08a377f7d24..edfa397a8fb2 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -4,7 +4,6 @@ # but implemented by the Phi-Speech team #!/usr/bin/env python3 import abc -import backoff from functools import partial import math from typing import Optional, Tuple, Union, List, Literal, Union, Dict, Callable @@ -876,20 +875,6 @@ def forward(self, x: torch.Tensor): #### forward embedding layers starts here - - -@backoff.on_exception(backoff.expo, Exception, max_tries=10) -def np_loadtxt_with_retry(filepath): - """np.loadtxt with retry - - Args: - filepath: str - file path to the numpy array. - """ - result = np.loadtxt(filepath, dtype="f") - return result - - class MeanVarianceNormLayer(nn.Module): """Mean/variance normalization layer. @@ -918,36 +903,6 @@ def forward(self, input_: Tensor) -> Tensor: """ return (input_ - self.global_mean) * self.global_invstd - def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False): - """Load feature mean and variance used for normalization. - - Args: - mean_file: str - path to the feature mean statistics file. - invstd_file: str - path to the features inverted standard deviation - statistics file. - cuside_features: bool - Boolean that indicates CUSIDE is being used. - The statistics of CUSIDE features are copied - from the normal features - """ - self.global_mean.data = torch.from_numpy( - np_loadtxt_with_retry(mean_file) - ) - self.global_invstd.data = torch.from_numpy( - np_loadtxt_with_retry(invstd_file) - ) - - if cuside_features: - self.global_mean.data = torch.cat( - (self.global_mean.data, self.global_mean.data), 0 - ) - self.global_invstd.data = torch.cat( - (self.global_invstd.data, self.global_invstd.data), 0 - ) - - class CausalConv1D(nn.Conv1d): """ A causal version of nn.Conv1d where each step would have limited access to locations on its right or left @@ -2472,11 +2427,6 @@ def post_init(self, init_model_config): self.encoder_embedding_config["input_size"] ) - mean_file = init_model_config.get("mean_file", None) - invstd_file = init_model_config.get("invstd_file", None) - if mean_file is not None and invstd_file is not None: - self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) - def compute_lens_change(self, feature_lens): """feature_lens: int return updated feature lens. @@ -3353,16 +3303,6 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: self.input_embeds = None self.audio_embed_sizes = None - def post_init(self, audio_config): - # execute after the from_pretrained() initialization of the phi model - if audio_config.get("name", None) == "cascades": - init_model_config = audio_config.get("init_model", {}) - self.encoder.post_init(init_model_config) - # remove the init model in config so it is not saved in the config. - # This might affect the model loading in resuming training and decoding. - if "init_model" in audio_config: - audio_config.pop("init_model") - def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: self.input_embeds = input_embeds From 707dfe15c2ebc977ab0912e3a9bc4df4c3c7f219 Mon Sep 17 00:00:00 2001 From: Vadim Mazalov Date: Wed, 26 Feb 2025 05:44:04 +0000 Subject: [PATCH 13/27] Refactor phimm_utils Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4mm.py | 2 +- vllm/model_executor/models/phi4mm_audio.py | 1434 ++++++++++++++++++++ vllm/model_executor/models/phi4mm_utils.py | 1423 +------------------ 3 files changed, 1436 insertions(+), 1423 deletions(-) create mode 100644 vllm/model_executor/models/phi4mm_audio.py diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 40b5c2401542..d6313bf6ebfa 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -42,7 +42,7 @@ from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA from .vision_siglip_navit import get_siglip_vision_model -from .phi4mm_utils import AudioEmbedding +from .phi4mm_audio import AudioEmbedding from .utils import PPMissingLayer, maybe_prefix diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py new file mode 100644 index 000000000000..c9a47ffb7e73 --- /dev/null +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -0,0 +1,1434 @@ + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) +# but implemented by the Phi-Speech team +#!/usr/bin/env python3 +import abc +from functools import partial +import math +import torch +import numpy as np +from typing import Callable, Dict, List, Literal, Optional, Union +import torch.nn.functional as F +from torch import nn, Tensor +from transformers import PretrainedConfig +from torch.utils.checkpoint import checkpoint +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + checkpoint_wrapper, + offload_wrapper, + CheckpointImpl, +) +from vllm.model_executor.models.phi4mm_utils import AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias, adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper, get_offset, repeat, unfold_tensor, validate_checkpointing_config +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel, +) + +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + +def encoder_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], + layer_cls: type, + idx: int = 0, +) -> Callable: + """return encoder activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + if target_layer_cls.lower() == "transformer": + target_layer_cls = ( + "EncoderLayer", + "ConformerEncoderLayer", + ) + elif target_layer_cls.lower() == "attention": + target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") + checkpointing_interval = activation_checkpointing.get("interval", 1) + offloading = activation_checkpointing.get("offload", False) + impl = ( + CheckpointImpl.REENTRANT + if activation_checkpointing.get("reentrant", True) + else CheckpointImpl.NO_REENTRANT + ) + + if ( + idx % checkpointing_interval == 0 + and layer_cls.__name__ in target_layer_cls + ): + if offloading: + return offload_wrapper + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + + raise ValueError("Invalid activation_checkpointing config") + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_innner_dim: int, otional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_innner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_innner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = encoder_checkpoint_wrapper( + activation_checkpointing, + MultiHeadedAttention, + )( + MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. + Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal[ + "feat", "feat_time", "none", True + ] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding( + attention_dim, positional_dropout_rate + ) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") + if relative_attention_bias_args + else None + ) + if self.relative_attention_bias_type == "t5": + assert ( + self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get( + "t5_bias_max_distance", 1000 + ), + symmetric=relative_attention_bias_args.get( + "t5_bias_symmetric", False + ), + ) + else: + raise NotImplementedError + + def post_init(self, init_model_config): + + pretrained_speech_encoder_path = init_model_config.get( + "pretrained_speech_encoder_path", None + ) + if pretrained_speech_encoder_path: + model_state = torch.load( + pretrained_speech_encoder_path, map_location="cpu" + ) + encoder_state_dict = {} + for k, v in model_state.items(): + if "encoder." in k: + tmp_k = k.replace("encoder.", "") + encoder_state_dict[tmp_k] = v + + if hasattr(self, "encoder_embedding"): + del self.encoder_embedding + self.load_state_dict(encoder_state_dict) + + if not hasattr(self, "encoder_embedding"): + self.encoder_embedding = MeanVarianceNormLayer( + self.encoder_embedding_config["input_size"] + ) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that computed + the right thing. That does not work within Torchscript. If you really + need this to be faster, create nn.Module()-s for all the cases and return + one of them. Torchscript does support that. + """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get( + "subsampling", "dw_striding" + ) in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = ( + math.ceil if isinstance(feature_lens, int) else torch.ceil + ) + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int( + torch.randint(low=0, high=len(chunk_size), size=(1,)) + ) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError( + "Since chunk_size is a list, left_chunk must be a list" + ) + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of chunk_size." + ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb( + input_tensor + ) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = ( + adaptive_enc_mask( + seq_len, chunk_start_idx, left_window=left_chunk_train_eff + ) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings( + self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None + ): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The squence length after time reduction is invalid: {seq_len}. + Your input feature is too short. Consider filtering out the very + short sentence from data loader""", + ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. + num_lang: int + This parameter is used to store the number of languages in the lang_dict, + only used for multiseed/multilingual models. default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_block: + number of Transformer layer. default 6 + dropout_rate: float, optional + dropout rate. default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channel for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only work for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channel for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, otional + activation used use glu in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_glu_type: str + only work for glu_in_attention !=0 + default "swish". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention + in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming decoding, use + "replication" padding for the cache at start of utterance. + Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: List[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal[ + "feat", "feat_time", "none", True + ] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.embed = embedding_checkpoint_wrapper(activation_checkpointing)( + self.embed + ) + self.replication_pad_for_subsample_embedding: bool = ( + replication_pad_for_subsample_embedding + ) + assert ( + self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = repeat( + num_blocks, + lambda i: encoder_checkpoint_wrapper( + activation_checkpointing, ConformerEncoderLayer, i + )( + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=attn_checkpointing( + activation_checkpointing, i + ), + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + ), + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = ( + torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) + < padding_length.unsqueeze(1) + ) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( + xs_pad, masks + ) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 #maxium position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) + input_tensor = input_tensor_pad.to(input_tensor.device) + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask did not consider extra pad + subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) # calculate hs_mask based on the unfolded pad mask + + layer_emb = None + + relative_attention_bias = self.init_relative_attention_bias( + input_tensor + ) + + _simplified_path = ( + self.extra_layer_output_idx == -1 + and relative_attention_bias is None + ) + + if _simplified_path: + input_tensor, *_ = self.encoders( + input_tensor, pos_k, pos_v, hs_mask + ) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + if i == self.extra_layer_output_idx: + layer_emb = input_tensor + + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + + return input_tensor, masks # , layer_emb + + def gradient_checkpointing_enable(self): + pass + +class WindowQformer(nn.Module): + """Window-level Qformer""" + + def __init__( + self, + window_size: int = 8, + num_queries: int = 1, + num_blocks: int = 2, + attention_dim: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + + self.decoders = nn.ModuleList( + [ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) + for _ in range(num_blocks) + ] + ) + + self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) + self.after_norm = ( + nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None + ) + self.window_size = window_size + self.gradient_checkpointing_enable = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing_enable = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing_enable = False + + def forward(self, audio_embed, mask, embed_len=None): + """forward decoder""" + # audio_embed: N x T x D => N x D x T + + audio_embed = audio_embed.transpose(1, 2) + # audio_embed: N x D x 1 x T => N x DK x T' + padding = audio_embed.shape[-1] % self.window_size + if padding > 0: + audio_embed = F.pad( + audio_embed, (0, self.window_size - padding), "constant", 0 + ) + + embed_chunk = F.unfold( + audio_embed[..., None, :], + kernel_size=(1, self.window_size), + stride=(1, self.window_size), + ) + bsz, _, slen = embed_chunk.shape + # N x D x K x T' + embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) + # N x T' x K x D + embed_chunk = embed_chunk.transpose(1, 3).contiguous() + # NT' x K x D + embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) + # NT' x 1 x D + q = self.queries.expand(bsz * slen, -1, -1) + for layer in self.decoders: + if self.gradient_checkpointing_enable and self.training: + q = checkpoint( + layer.__call__, + q, + embed_chunk, + None, + mask, + use_reentrant=True, + ) + else: + q = layer( + tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask + ) + + if self.after_norm is not None: + q = self.after_norm(q) + + if embed_len is not None: + embed_len = embed_len // self.window_size + # N x T' x D + out = q.view(bsz, slen, -1) + + return out, embed_len + +class AudioEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = ( + config.n_embd if hasattr(config, "n_embd") else config.hidden_size + ) + + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = ( + config.embd_pdrop + if hasattr(config, "embd_pdrop") + else config.embed_pdrop + ) + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + # self.wte = nn.Embedding(config.vocab_size, hidden_size) + + audio_dim_out = ( + None # Set this variable according to the actual audio processor + ) + self.layer_idx = -2 + + # if isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == 'whisper': + # model_path = config.audio_processor.get('pretrained_model_path', None) + # whisper_model = WhisperModel.from_pretrained(model_path) + + # self.encoder = whisper_model.encoder + # n_mels = self.encoder.num_mel_bins + # audio_dim_out = self.encoder.layers[0].embed_dim + # elif isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == "w2vbert2": + # audio_processor_path = config.audio_processor.get("model_path", "facebook/w2v-bert-2.0") + # self.encoder = Wav2Vec2BertModel.from_pretrained(audio_processor_path) + # audio_dim_out = self.encoder.config.hidden_size + # self.layer_idx = config.audio_processor.get("layer", 18) + # self.encoder.config.apply_spec_augment = False + # self.encoder.config.mask_time_prob = 0 + # self.encoder.config.output_hidden_states = True + # n_mels = 160 + if ( + isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades" + ): + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + # fake initialization, create encoder_embedding layer only so that + # in decoding, all parameters can be loaded in from_pretrained_function + # in training, we do post init after from_pretrained function to make sure the correct initialization + self.encoder.post_init({}) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError(f"") + + assert ( + audio_dim_out is not None + ), "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get( + "freeze_audio_processor", False + ) + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + if kwargs.get("use_qformer", False): + qformer_config = kwargs.get("qformer_config", {}) + qformer_config["attention_dim"] = audio_dim_out + self.qformer = WindowQformer(**qformer_config) + else: + self.qformer = None + + if kwargs.get("use_conv_downsample", False): + assert ( + self.qformer is None + ), "don't support use qformer and conv downsample together" + nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.downsample_rate, + "feat_in": audio_dim_out, + "feat_out": audio_dim_out, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.conv_ds = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + self.conv_ds = None + + enable_gradient_checkpointing = kwargs.get( + "enable_gradient_checkpointing", False + ) + if enable_gradient_checkpointing: + self.encoder.gradient_checkpointing_enable() + + if self.qformer: + self.qformer.enable_gradient_checkpointing() + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = ( + 1 if (self.qformer or self.conv_ds) else self.downsample_rate + ) + layers = [ + nn.Linear( + audio_dim_out * self.linear_downsample_rate, dim_projection + ) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), nn.Linear(dim_projection, dim_projection)] + ) + self.audio_projection = nn.Sequential(*layers) + # NOTE vision-speech tasks use a seperate projection layer + layers = [ + nn.Linear( + audio_dim_out * self.linear_downsample_rate, dim_projection + ) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), nn.Linear(dim_projection, dim_projection)] + ) + self.audio_projection_for_vision = nn.Sequential(*layers) + else: + raise NotImplementedError( + f"projection_cls = {projection_cls}, not implemented" + ) + + # TODO: audio sequence compression - Qformer + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes( + self, audio_embed_sizes: torch.LongTensor + ) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, + input_embeds: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ): + + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder( + input_embeds, audio_attention_mask + ) + else: + audio_features, masks = self.encoder( + input_embeds, audio_attention_mask + ) + + if self.qformer is not None: + audio_features, _ = self.qformer(audio_features, mask=None) + + if self.conv_ds is not None: + if masks is not None: + masks = masks.squeeze(1) + + audio_features, masks = self.conv_ds(audio_features, mask=masks) + + if self.linear_downsample_rate != 1: + bs, seq_len, feat_dim = audio_features.size() + padding = seq_len % self.linear_downsample_rate + if padding > 0: + audio_features = F.pad( + audio_features, + (0, 0, 0, self.linear_downsample_rate - padding), + "constant", + 0, + ) + + seq_len = audio_features.size(1) + audio_features = audio_features.view( + bs, + seq_len // self.linear_downsample_rate, + feat_dim * self.linear_downsample_rate, + ) + + if audio_projection_mode == 'speech': + audio_set_tensor = self.audio_projection(audio_features) + elif audio_projection_mode == 'vision': + audio_set_tensor = self.audio_projection_for_vision(audio_features) + else: + raise ValueError(f"audio_projection_mode = {audio_projection_mode} not implemented") + + return audio_set_tensor + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds: torch.FloatTensor, + audio_embed_sizes, + **kwargs, + ) -> torch.FloatTensor: + """ + arguments: + input_ids: input text ids (B, U) + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + assert input_embeds is not None and len(input_embeds) == len( + audio_embed_sizes + ) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + with torch.no_grad(): + positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(as_tuple=False) + + if not isinstance(input_embeds, list): + input_embeds = [input_embeds] + + audio_projection_mode = kwargs.get("audio_projection_mode", "speech") + audio_set_tensor = [ + self.get_audio_features(input_embed, audio_projection_mode=audio_projection_mode) + for input_embed in input_embeds + ] + + with torch.no_grad(): + input_ids.clamp_min_(0).clamp_max_(self.vocab_size) + + if "wte" in kwargs: + # we use the token embedding layer from the huggingface model, this is REQUIRED to make sure we are using the loaded weights. + hidden_states = kwargs["wte"](input_ids) + else: + # otherwise, we use token embedding in pretrained mixformer from phi team + hidden_states = self.wte(input_ids) + + if len(positions.tolist()) > 0: + assert sum(audio_embed_sizes) == len( + positions + ), "please ensure the encoder outputs have the same length as defined in input_ids!" + idx = 0 + for i in range(len(audio_embed_sizes)): + cnt = audio_embed_sizes[i] + assert audio_set_tensor[i].shape[0] == 1 + hidden_states[ + positions[idx, 0], + positions[idx, 1] : positions[idx, 1] + cnt, + ] = ( + audio_set_tensor[i][0, : audio_embed_sizes[i], :] + .to(hidden_states.dtype) + .to(hidden_states.device) + ) + idx += cnt + + else: + if self.training: + # hidden_states[:, 0:img_set_tensor.shape[0]] = hidden_states[:, 0:img_set_tensor.shape[0]] + 0 * img_set_tensor.to(hidden_states.dtype).to(hidden_states.device) + hidden_states[:, 0:1] = hidden_states[ + :, 0:1 + ] + 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype).to( + hidden_states.device + ) + + if self.drop is not None: + hidden_states = self.drop(hidden_states) + return hidden_states diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index edfa397a8fb2..787e4508419d 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -3,29 +3,18 @@ # Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) # but implemented by the Phi-Speech team #!/usr/bin/env python3 -import abc from functools import partial import math -from typing import Optional, Tuple, Union, List, Literal, Union, Dict, Callable +from typing import Optional, Tuple, Union, Union, Dict, Callable -import numpy as np import torch import torch.nn.functional as F from torch import nn, Tensor from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointWrapper, checkpoint_wrapper, offload_wrapper, CheckpointImpl, ) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel, -) -from torch.utils.checkpoint import checkpoint -from transformers import PretrainedConfig - -_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> - class Block(nn.Module): """Block abstract module""" @@ -1946,50 +1935,6 @@ def embedding_checkpoint_wrapper( raise ValueError("Invalid activation_checkpointing config") -def encoder_checkpoint_wrapper( - activation_checkpointing: Union[str, Dict], - layer_cls: type, - idx: int = 0, -) -> Callable: - """return encoder activation checkpoint wrapper""" - validate_checkpointing_config(activation_checkpointing) - - if isinstance(activation_checkpointing, str): - if activation_checkpointing: - if activation_checkpointing == "offload": - return offload_wrapper - return partial(checkpoint_wrapper) - return lambda x: x - - if isinstance(activation_checkpointing, dict): - target_layer_cls = activation_checkpointing.get("module", "transformer") - if target_layer_cls.lower() == "transformer": - target_layer_cls = ( - "EncoderLayer", - "ConformerEncoderLayer", - ) - elif target_layer_cls.lower() == "attention": - target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") - checkpointing_interval = activation_checkpointing.get("interval", 1) - offloading = activation_checkpointing.get("offload", False) - impl = ( - CheckpointImpl.REENTRANT - if activation_checkpointing.get("reentrant", True) - else CheckpointImpl.NO_REENTRANT - ) - - if ( - idx % checkpointing_interval == 0 - and layer_cls.__name__ in target_layer_cls - ): - if offloading: - return offload_wrapper - return partial(checkpoint_wrapper, checkpoint_impl=impl) - return lambda x: x - - raise ValueError("Invalid activation_checkpointing config") - - def attn_checkpointing( activation_checkpointing: Union[str, Dict], i ) -> Union[str, Dict]: @@ -2028,598 +1973,6 @@ def repeat(repeat_num, module_gen_fn): """ return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) - -class ConformerEncoderLayer(nn.Module): - """ConformerEncoder Layer module. - for more details see conformer paper: - https://arxiv.org/abs/2005.08100 - This module implement the Conformer block layer. - - Args: - d_model: int - attention dim. - ext_pw_out_channel: int - if > 0, ext_pw_out_channel is a dim channel size - for the last pointwise conv after swish activation. - depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel - will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. - depthwise_multiplier: int - number of input_dim channels duplication. this value - will be used to compute the hidden channels of the Conv1D. - n_head: int - the number of heads for multihead attention module. - d_ffn: int - output size of the feed_forward blocks. - ext_pw_kernel_size: int - kernel size of the conv pointwise of the conformer. - kernel_size: int - kernel size. - dropout_rate: float - dropout rate. - causal: bool, optional - if set to True, convolution have no access - to future frames. default False. - batch_norm: bool, optional - if set to True, apply batchnorm before activation - in ConvModule layer of the conformer. - default False - activation: str, optional - activation function name, - one of ["relu", "swish", "sigmoid"], - sigmoid activation is only used with "glu_in_fnn=True", - default "relu". - chunk_se: int, optional - 0 for offline SE. - 1 for streaming SE, where mean is computed - by accumulated history until current chunk_se. - 2 for streaming SE, where mean is computed - by only the current chunk. - default 0. - chunk_size: int, optional - chunk_size for cnn. default 18 - conv_activation: str, optional - activation function used in ConvModule part - of the conformer, default "relu". - conv_glu_type: str, optional - activation function used for the glu inside - the ConvModule part of the conformer. - default: "sigmoid". - bias_in_glu: bool, optional - if set to True, use additive bias in the weight module - before GLU. - linear_glu_in_convm: bool, optional - if set to True, use GLULinear module, - otherwise, used GLUPointWiseConv module. - default to False. - attention_innner_dim: int, otional - if equal to -1, attention dim for linears k/q/v is - equal to d_model. otherwise attention_innner_dim is used. - default -1. - attention_glu_type: str, optional - activation function for glu used in the multihead attention, - default "swish". - activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where - "module": str - accept ["transformer", "attention"] to select - which module should do activation checkpointing. - "interval": int, default 1, - interval of applying activation checkpointing, - interval = 1 means that we apply checkpointing - on every layer (if activation), otherwise, - we apply it every x interval. - "offload": bool, default False, - if set to True, we offload activation to cpu and - reload it during backward, otherwise, - we recalculate activation in backward. - default "". - export: bool, optional - if set to True, it remove the padding from convolutional layers - and allow the onnx conversion for inference. - default False. - use_pt_scaled_dot_product_attention: bool, optional - if set to True, use pytorch's scaled dot product attention implementation in training. - attn_group_sizes: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), - 1 = typical Multi-Head Attention, - 1 < attn_group_sizes < attention_heads = Grouped-Query Attention - attn_group_sizes = attenion_heads = Multi-Query Attention - """ - - def __init__( - self, - d_model=512, - ext_pw_out_channel=0, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - n_head=4, - d_ffn=2048, - ext_pw_kernel_size=1, - kernel_size=3, - dropout_rate=0.1, - causal=False, - batch_norm=False, - activation="relu", - chunk_se=0, - chunk_size=18, - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_innner_dim=-1, - attention_glu_type="swish", - activation_checkpointing="", - export=False, - use_pt_scaled_dot_product_attention=False, - attn_group_sizes: int = 1, - ): - super().__init__() - - self.feed_forward_in = FeedForward( - d_model=d_model, - d_inner=d_ffn, - dropout_rate=dropout_rate, - activation=activation, - bias_in_glu=bias_in_glu, - ) - - self.self_attn = encoder_checkpoint_wrapper( - activation_checkpointing, - MultiHeadedAttention, - )( - MultiHeadedAttention( - n_head, - d_model, - dropout_rate, - attention_innner_dim, - attention_glu_type, - bias_in_glu, - use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, - group_size=attn_group_sizes, - ) - ) - self.conv = ConvModule( - d_model, - ext_pw_out_channel, - depthwise_seperable_out_channel, - ext_pw_kernel_size, - kernel_size, - depthwise_multiplier, - dropout_rate, - causal, - batch_norm, - chunk_se, - chunk_size, - conv_activation, - conv_glu_type, - bias_in_glu, - linear_glu_in_convm, - export=export, - ) - - self.feed_forward_out = FeedForward( - d_model=d_model, - d_inner=d_ffn, - dropout_rate=dropout_rate, - activation=activation, - bias_in_glu=bias_in_glu, - ) - - self.layer_norm_att = nn.LayerNorm(d_model) - self.layer_norm = nn.LayerNorm(d_model) - - def forward( - self, - x, - pos_k, - pos_v, - mask, - relative_attention_bias: Optional[Tensor] = None, - ): - """ConformerEncoder forward. - - Args: - x: torch.Tensor - input feature of shape (batch, max_time_in, size) - pos_k: torch.Tensor - positional key embedding. - mask: torch.Tensor - mask for x (batch, max_time_in) - relative_attention_bias: Optional[torch.Tensor] - bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) - """ - x = x + 0.5 * self.feed_forward_in(x) - norm_x = self.layer_norm_att(x) - - x = x + self.self_attn( - norm_x, - norm_x, - norm_x, - pos_k, - pos_v, - mask, - relative_attention_bias=relative_attention_bias, - ) - x = x + self.conv(x) - x = x + 0.5 * self.feed_forward_out(x) - - out = self.layer_norm(x) - - return out, pos_k, pos_v, mask - - -class TransformerEncoderBase(abc.ABC, nn.Module): - """The Base class for Transformer based encoders - - Please set causal = True in streaming model - Args: - input_size: int - input feature dimension. - chunk_size: int, list(int) - Number of frames for each chunk - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training - Some examples for the 2 cases: - chunk_size = 12 - chunk_size = [6, 8, 12, 24] - left_chunk: int, list(int) - Number of chunks used for masking in streaming mode. - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training. When - chunk_size is a list, left_chunk must be a list with same length. - Some examples for the 2 cases: - left_chunk = 6 - left_chunk = [12, 9, 6, 3] - attention_dim: int, optional - attention dimension. default 256. - attention_heads: int, optional - the number of heads. default 4 - input_layer: str, optional - input layer type before Conformer, - one of ["linear", "conv2d", "custom", "vgg2l", "embed"], - default "conv2d" - cnn_out: int, optional - the number of CNN channels before Conformer. - default -1. - cnn_layer_norm: bool, optional - layer norm between Conformer and the first CNN. - default False. - time_reduction: int, optional - time reduction factor - default 4 - dropout_rate: float, optional - dropout rate. default 0.1 - padding_idx: int, optional - padding index for input_layer=embed - default -1 - relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention (Q*K^T + B) - implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias - usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see transformer_base.py) - positional_dropout_rate: float, optional - dropout rate after positional encoding. default 0.0 - nemo_conv_settings: dict, optional - A dictionary of settings for NeMo Subsampling. - default None - conv2d_extra_padding: str, optional - Add extra padding in conv2d subsampling layers. Choices are - (feat, feat_time, none, True). - if True or feat_time, the extra padding is added into non full - supraframe utts in batch. - Default: none - attention_group_size: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), - 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query Attention - attention_group_size = attenion_heads = Multi-Query Attention - """ - - def __init__( - self, - input_size, - chunk_size, - left_chunk, - attention_dim=256, - attention_heads=4, - input_layer="nemo_conv", - cnn_out=-1, - cnn_layer_norm=False, - time_reduction=4, - dropout_rate=0.0, - padding_idx=-1, - relative_attention_bias_args=None, - positional_dropout_rate=0.0, - nemo_conv_settings=None, - conv2d_extra_padding: Literal[ - "feat", "feat_time", "none", True - ] = "none", - attention_group_size=1, - encoder_embedding_config=None, - ): - super().__init__() - self.input_size = input_size - self.input_layer = input_layer - self.chunk_size = chunk_size - self.left_chunk = left_chunk - self.attention_dim = attention_dim - self.num_heads = attention_heads - self.attention_group_size = attention_group_size - self.time_reduction = time_reduction - self.nemo_conv_settings = nemo_conv_settings - self.encoder_embedding_config = encoder_embedding_config - - if self.input_layer == "nemo_conv": - default_nemo_conv_settings = { - "subsampling": "dw_striding", - "subsampling_factor": self.time_reduction, - "feat_in": input_size, - "feat_out": attention_dim, - "conv_channels": 256, - "subsampling_conv_chunking_factor": 1, - "activation": nn.ReLU(), - "is_causal": False, - } - # Override any of the defaults with the incoming, user settings - if nemo_conv_settings: - default_nemo_conv_settings.update(nemo_conv_settings) - for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" - - self.embed = NemoConvSubsampling( - **default_nemo_conv_settings, - ) - else: - raise ValueError("unknown input_layer: " + input_layer) - - self.pos_emb = AbsolutePositionalEncoding( - attention_dim, positional_dropout_rate - ) - - self.relative_attention_bias_type = ( - relative_attention_bias_args.get("type") - if relative_attention_bias_args - else None - ) - if self.relative_attention_bias_type == "t5": - assert ( - self.num_heads % self.attention_group_size == 0 - ), "attention_group_size must divide n_head" - self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( - self.num_heads // self.attention_group_size, - max_distance=relative_attention_bias_args.get( - "t5_bias_max_distance", 1000 - ), - symmetric=relative_attention_bias_args.get( - "t5_bias_symmetric", False - ), - ) - else: - raise NotImplementedError - - def post_init(self, init_model_config): - - pretrained_speech_encoder_path = init_model_config.get( - "pretrained_speech_encoder_path", None - ) - if pretrained_speech_encoder_path: - model_state = torch.load( - pretrained_speech_encoder_path, map_location="cpu" - ) - encoder_state_dict = {} - for k, v in model_state.items(): - if "encoder." in k: - tmp_k = k.replace("encoder.", "") - encoder_state_dict[tmp_k] = v - - if hasattr(self, "encoder_embedding"): - del self.encoder_embedding - self.load_state_dict(encoder_state_dict) - - if not hasattr(self, "encoder_embedding"): - self.encoder_embedding = MeanVarianceNormLayer( - self.encoder_embedding_config["input_size"] - ) - - def compute_lens_change(self, feature_lens): - """feature_lens: int - return updated feature lens. - - This used to return a different lambda function for each case that computed - the right thing. That does not work within Torchscript. If you really - need this to be faster, create nn.Module()-s for all the cases and return - one of them. Torchscript does support that. - """ - if self.input_layer == "nemo_conv": - # Handle the special causal case - subsampling_causal_cond = self.nemo_conv_settings.get( - "subsampling", "dw_striding" - ) in [ - "dw_striding", - "striding", - "striding_conv1d", - ] - is_causal = self.nemo_conv_settings.get("is_causal", False) - if is_causal and subsampling_causal_cond: - lens_change = ( - torch.ceil(feature_lens / self.time_reduction).long() - if isinstance(feature_lens, Tensor) - else math.ceil(feature_lens / self.time_reduction) - ) - feature_lens_remainder = feature_lens % self.time_reduction - if isinstance(feature_lens, Tensor): - lens_change[feature_lens_remainder != 1] += 1 - elif feature_lens_remainder != 1: - lens_change += 1 - return lens_change - ceil_func = ( - math.ceil if isinstance(feature_lens, int) else torch.ceil - ) - return ceil_func(feature_lens / self.time_reduction) - - @abc.abstractmethod - def forward(self): - """Abstract forward method implementation.""" - - def _chunk_size_selection(self, chunk_size=None, left_chunk=None): - """If chunk size is a list, we will randomly select a chunk size.""" - - if chunk_size is None: - chunk_size = self.chunk_size - if left_chunk is None: - left_chunk = self.left_chunk - if isinstance(chunk_size, list): - # Variable chunk size during training - chunk_size_index = int( - torch.randint(low=0, high=len(chunk_size), size=(1,)) - ) - chunk_size_train_eff = chunk_size[chunk_size_index] - if not isinstance(left_chunk, list): - raise ValueError( - "Since chunk_size is a list, left_chunk must be a list" - ) - if len(left_chunk) != len(chunk_size): - raise ValueError( - "The length of left_chunk must be the same as length of chunk_size." - ) - left_chunk_train_eff = left_chunk[chunk_size_index] - else: - chunk_size_train_eff = chunk_size - left_chunk_train_eff = left_chunk - - return chunk_size_train_eff, left_chunk_train_eff - - def _get_embed_class(self, embed): - # pylint: disable=protected-access - is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) - is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) - embed_class = embed - if is_embed_using_act_chkpt: - embed_class = embed._checkpoint_wrapped_module - if is_embed_fsdp_wrapped: - embed_class = embed.module - return embed_class - - def _forward_embeddings_core(self, input_tensor, masks): - embed_class = self._get_embed_class(self.embed) - assert isinstance(embed_class, NemoConvSubsampling) - input_tensor, masks = self.embed(input_tensor, masks) - return input_tensor, masks - - def _position_embedding(self, input_tensor): - pos_k = None - pos_v = None - if self.relative_attention_bias_layer is None: - input_tensor = self.pos_emb( - input_tensor - ) # default to add abs sinusoid embedding - return pos_k, pos_v - - def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): - chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( - chunk_size, left_chunk - ) - - # Create mask matrix for streaming - # S stores start index. if chunksize is 18, s is [0,18,36,....] - chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) - # avoid randomness when run evaluation or decoding - if self.training and np.random.rand() > 0.5: - # Either first or last chunk is not complete. - # If only the last one is not complete, EOS is not effective - chunk_start_idx = seq_len - chunk_start_idx - chunk_start_idx = chunk_start_idx[::-1] - chunk_start_idx = chunk_start_idx[:-1] - chunk_start_idx = np.insert(chunk_start_idx, 0, 0) - - enc_streaming_mask = ( - adaptive_enc_mask( - seq_len, chunk_start_idx, left_window=left_chunk_train_eff - ) - .unsqueeze(0) - .expand([batch_size, -1, -1]) - ) - return enc_streaming_mask - - def forward_embeddings( - self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None - ): - """Forwarding the inputs through the top embedding layers - - Args: - xs_pad: torch.Tensor - input tensor - masks: torch.Tensor - input mask - chunk_size_nc: (optional, default is None) chunk size for non-causal layers - left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers - """ - # pylint: disable=R0915 - # get new lens. - seq_len = int(self.compute_lens_change(xs_pad.shape[1])) - if seq_len <= 0: - raise ValueError( - f"""The squence length after time reduction is invalid: {seq_len}. - Your input feature is too short. Consider filtering out the very - short sentence from data loader""", - ) - - batch_size = xs_pad.shape[0] - - enc_streaming_mask = self._streaming_mask( - seq_len, batch_size, self.chunk_size, self.left_chunk - ) - - if xs_pad.is_cuda: - enc_streaming_mask = enc_streaming_mask.cuda() - xs_pad = xs_pad.cuda() - - input_tensor = xs_pad - input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) - - streaming_mask = enc_streaming_mask - if streaming_mask is not None and masks is not None: - hs_mask = masks & streaming_mask - elif masks is not None: - hs_mask = masks - else: - hs_mask = streaming_mask - - if chunk_size_nc is not None: - enc_streaming_mask_nc = self._streaming_mask( - seq_len, batch_size, chunk_size_nc, left_chunk_nc - ) - if xs_pad.is_cuda: - enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() - if masks is not None: - hs_mask_nc = masks & enc_streaming_mask_nc - else: - hs_mask_nc = enc_streaming_mask_nc - else: - hs_mask_nc = None - - pos_k, pos_v = self._position_embedding(input_tensor) - - if chunk_size_nc is None: - return input_tensor, pos_k, pos_v, hs_mask, masks - return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc - - def get_offset(self): - """Returns offset used when retaining inputs for decoding. - - This is essentially, how many additional frames have to be added to - the front-end CNN input to ensure it can produce a single output. - So if the "padding" parameter is 0, typically offset will be > 0. - """ - return get_offset(self.input_layer, self.time_reduction) - - def get_offset(input_layer: str, time_reduction: int): """Get an offset. We will use the offset for determining #frames of a subsampled feature. @@ -2661,777 +2014,3 @@ def unfold_tensor(xs_pad, max_seq_len): xs_pad = xs_pad.view(-1, max_seq_len, D) return xs_pad -class ConformerEncoder(TransformerEncoderBase): - """ConformerEncoder module. - see original paper for more details: - https://arxiv.org/abs/2005.08100 - - Please set causal = True in streaming model - Args: - input_size: int - input feature dimension. - chunk_size: int, list(int) - Number of frames for each chunk - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training - Some examples for the 2 cases: - chunk_size = 12 - chunk_size = [6, 8, 12, 24] - left_chunk: int, list(int) - Number of chunks used for masking in streaming mode. - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training. When - chunk_size is a list, left_chunk must be a list with same length. - Some examples for the 2 cases: - left_chunk = 6 - left_chunk = [12, 9, 6, 3] - left_chunk: int - number of chunks used for masking in streaming mode. - num_lang: int - This parameter is used to store the number of languages in the lang_dict, - only used for multiseed/multilingual models. default None. - attention_dim: int, optional - attention dimension. default 256. - attention_heads: int, optional - the number of heads. default 4 - linear_units: - the number of units of position-wise feed forward. - default 2048 - num_block: - number of Transformer layer. default 6 - dropout_rate: float, optional - dropout rate. default 0.1 - input_layer: str, optional - input layer type before Conformer, - one of ["linear", "conv2d", "custom", "vgg2l", "embed"], - default "conv2d" - causal: bool, optional - if set to True, convolution have no access - to future frames. default False. - batch_norm: bool, optional - if set to True, apply batchnorm before activation - in ConvModule layer of the conformer. - default False - cnn_out: int, optional - the number of CNN channels before Conformer. - default -1. - cnn_layer_norm: bool, optional - layer norm between Conformer and the first CNN. - default False. - ext_pw_out_channel: int, optional - the number of channel for CNN - before depthwise_seperable_CNN. - If 0 then use linear. default 0. - ext_pw_kernel_size: int, optional - kernel size of N before depthwise_seperable_CNN. - only work for ext_pw_out_channel > 0. - default 1 - depthwise_seperable_out_channel: int, optional - the number of channel for - depthwise_seperable_CNN. - default 256. - depthwise_multiplier: int, optional - the number of multiplier for - depthwise_seperable_CNN. - default 1. - chunk_se: int, optional - 0 for offline SE. - 1 for streaming SE, where mean is computed - by accumulated history until current chunk_se. - 2 for streaming SE, where mean is computed - by only the current chunk. - default 0. - kernel_size: int, optional - the number of kernels for depthwise_seperable_CNN. - default 3. - activation: str, optional - FeedForward block activation. - one of ["relu", "swish", "sigmoid"] - default "relu". - conv_activation: str, optional - activation function used in ConvModule part - of the conformer, default "relu". - conv_glu_type: str, otional - activation used use glu in depthwise_seperable_CNN, - default "sigmoid" - bias_in_glu: bool, optional - if set to True, use additive bias in the weight module - before GLU. default True - linear_glu_in_convm: bool, optional - if set to True, use GLULinear module, - otherwise, used GLUPointWiseConv module. - default to False. - attention_glu_type: str - only work for glu_in_attention !=0 - default "swish". - export: bool, optional - if set to True, it remove the padding from convolutional layers - and allow the onnx conversion for inference. - default False. - activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where - "module": str - accept ["transformer", "attention"] to select - which module should do activation checkpointing. - "interval": int, default 1, - interval of applying activation checkpointing, - interval = 1 means that we apply checkpointing - on every layer (if activation), otherwise, - we apply it every x interval. - "offload": bool, default False, - if set to True, we offload activation to cpu and - reload it during backward, otherwise, - we recalculate activation in backward. - default "". - extra_layer_output_idx: int - the layer index to be exposed. - relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention (Q*K^T + B) - implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias - usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see transformer_base.py) - time_reduction: int optional - time reduction factor - default 4 - use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention - in training. - Default: False - nemo_conv_settings: dict, optional - A dictionary of settings for NeMo Subsampling. - default: None - usage: nemo_conv_settings= - { - "subsampling": - dw_striding/striding/dw_striding_conv1d/striding_conv1d, - "conv_channels": int, - "subsampling_conv_chunking_factor": int, - "is_causal": True/False - } - conv2d_extra_padding: str, optional - Add extra padding in conv2d subsampling layers. Choices are - (feat, feat_time, none, True) - Default: none - replication_pad_for_subsample_embedding: For batched-streaming decoding, use - "replication" padding for the cache at start of utterance. - Default: False - attention_group_size: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), - 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query Attention - attention_group_size = attenion_heads = Multi-Query Attention - """ - - extra_multi_layer_output_idxs: List[int] - - def __init__( # pylint: disable-all - self, - input_size, - chunk_size, - left_chunk, - num_lang=None, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - input_layer="nemo_conv", - causal=True, - batch_norm=False, - cnn_out=-1, - cnn_layer_norm=False, - ext_pw_out_channel=0, - ext_pw_kernel_size=1, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - chunk_se=0, - kernel_size=3, - activation="relu", - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_glu_type="swish", - export=False, - extra_layer_output_idx=-1, - extra_multi_layer_output_idxs=[], - activation_checkpointing="", - relative_attention_bias_args=None, - time_reduction=4, - use_pt_scaled_dot_product_attention=False, - nemo_conv_settings=None, - conv2d_extra_padding: Literal[ - "feat", "feat_time", "none", True - ] = "none", - replication_pad_for_subsample_embedding=False, - attention_group_size=1, - encoder_embedding_config=None, - ): - super().__init__( - input_size, - chunk_size, - left_chunk, - attention_dim, - attention_heads, - input_layer, - cnn_out, - cnn_layer_norm, - time_reduction, - dropout_rate=dropout_rate, - relative_attention_bias_args=relative_attention_bias_args, - positional_dropout_rate=0.0, - nemo_conv_settings=nemo_conv_settings, - conv2d_extra_padding=conv2d_extra_padding, - attention_group_size=attention_group_size, - encoder_embedding_config=encoder_embedding_config, - ) - self.num_blocks = num_blocks - self.num_lang = num_lang - self.kernel_size = kernel_size - self.embed = embedding_checkpoint_wrapper(activation_checkpointing)( - self.embed - ) - self.replication_pad_for_subsample_embedding: bool = ( - replication_pad_for_subsample_embedding - ) - assert ( - self.num_heads % attention_group_size == 0 - ), "attention_group_size must divide n_head" - self.num_heads_k = self.num_heads // attention_group_size - - self.encoders = repeat( - num_blocks, - lambda i: encoder_checkpoint_wrapper( - activation_checkpointing, ConformerEncoderLayer, i - )( - ConformerEncoderLayer( - d_model=attention_dim, - ext_pw_out_channel=ext_pw_out_channel, - depthwise_seperable_out_channel=depthwise_seperable_out_channel, - depthwise_multiplier=depthwise_multiplier, - n_head=attention_heads, - d_ffn=linear_units, - ext_pw_kernel_size=ext_pw_kernel_size, - kernel_size=kernel_size, - dropout_rate=dropout_rate, - causal=causal, - batch_norm=batch_norm, - activation=activation, - chunk_se=chunk_se, - chunk_size=chunk_size, - conv_activation=conv_activation, - conv_glu_type=conv_glu_type, - bias_in_glu=bias_in_glu, - linear_glu_in_convm=linear_glu_in_convm, - attention_glu_type=attention_glu_type, - activation_checkpointing=attn_checkpointing( - activation_checkpointing, i - ), - export=export, - use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, - attn_group_sizes=attention_group_size, - ) - ), - ) - self.extra_layer_output_idx = extra_layer_output_idx - self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs - # Make a zeros scalar we can use in get_initial_state to determine - # the device and the needed dtype: - self.register_buffer("dev_type", torch.zeros(()), persistent=False) - - def init_relative_attention_bias(self, input_tensor): - if self.relative_attention_bias_layer: - return self.relative_attention_bias_layer(input_tensor) - - def calculate_hs_mask(self, xs_pad, device, mask): - max_audio_length = xs_pad.shape[1] - batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask( - max_audio_length, batch_size, self.chunk_size, self.left_chunk - ) - enc_streaming_mask = enc_streaming_mask.to(device) - if mask is None: - return enc_streaming_mask - - feature_lens = mask.sum(1) - padding_length = feature_lens - pad_mask = ( - torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) - < padding_length.unsqueeze(1) - ) - pad_mask = pad_mask.unsqueeze(1) - pad_mask = pad_mask & enc_streaming_mask - return pad_mask - - @torch.jit.ignore - def forward(self, xs_pad, masks): - """Conformer Forward function - - Args: - xs_pad: torch.Tensor - input tensor - masks: torch.Tensor - post-embedding input lengths - """ - xs_pad = self.encoder_embedding(xs_pad) - input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( - xs_pad, masks - ) - - unfolded = False - ori_bz, seq_len, D = input_tensor.shape - max_seq_len = 500 #maxium position for absolute positional encoding - if seq_len > max_seq_len: - # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len - unfolded = True - # the unfold op will drop residual frames, pad it to the multiple of max_seq_len - if seq_len % max_seq_len > 0: - chunk_pad_size = max_seq_len - (seq_len % max_seq_len) - else: - chunk_pad_size = 0 - if chunk_pad_size > 0: - input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) - input_tensor = input_tensor_pad.to(input_tensor.device) - input_tensor = unfold_tensor(input_tensor, max_seq_len) - if masks is not None: - # revise hs_mask here because the previous calculated hs_mask did not consider extra pad - subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] - extra_padded_subsamlped_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() - masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len) # unfold the pad mask like we did to the input tensor - masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor - else: - masks_unfold = None - hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) # calculate hs_mask based on the unfolded pad mask - - layer_emb = None - - relative_attention_bias = self.init_relative_attention_bias( - input_tensor - ) - - _simplified_path = ( - self.extra_layer_output_idx == -1 - and relative_attention_bias is None - ) - - if _simplified_path: - input_tensor, *_ = self.encoders( - input_tensor, pos_k, pos_v, hs_mask - ) - else: - for i, layer in enumerate(self.encoders): - input_tensor, _, _, _ = layer( - input_tensor, - pos_k, - pos_v, - hs_mask, - relative_attention_bias=relative_attention_bias, - ) - - if i == self.extra_layer_output_idx: - layer_emb = input_tensor - - if unfolded: - embed_dim = input_tensor.shape[-1] - input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) - # if we ever padded before unfolding, we need to remove the padding - if chunk_pad_size > 0: - input_tensor = input_tensor[:, :-chunk_pad_size, :] - - return input_tensor, masks # , layer_emb - - def gradient_checkpointing_enable(self): - pass - - -class WindowQformer(nn.Module): - """Window-level Qformer""" - - def __init__( - self, - window_size: int = 8, - num_queries: int = 1, - num_blocks: int = 2, - attention_dim: int = 512, - attention_heads: int = 8, - linear_units: int = 2048, - dropout_rate: float = 0.0, - normalize_before: bool = True, - ): - super().__init__() - - self.decoders = nn.ModuleList( - [ - nn.TransformerDecoderLayer( - d_model=attention_dim, - nhead=attention_heads, - dim_feedforward=linear_units, - dropout=dropout_rate, - activation="relu", - batch_first=True, - norm_first=normalize_before, # TODO need to verify - ) - for _ in range(num_blocks) - ] - ) - - self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) - self.after_norm = ( - nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None - ) - self.window_size = window_size - self.gradient_checkpointing_enable = False - - def enable_gradient_checkpointing(self): - self.gradient_checkpointing_enable = True - - def disable_gradient_checkpointing(self): - self.gradient_checkpointing_enable = False - - def forward(self, audio_embed, mask, embed_len=None): - """forward decoder""" - # audio_embed: N x T x D => N x D x T - - audio_embed = audio_embed.transpose(1, 2) - # audio_embed: N x D x 1 x T => N x DK x T' - padding = audio_embed.shape[-1] % self.window_size - if padding > 0: - audio_embed = F.pad( - audio_embed, (0, self.window_size - padding), "constant", 0 - ) - - embed_chunk = F.unfold( - audio_embed[..., None, :], - kernel_size=(1, self.window_size), - stride=(1, self.window_size), - ) - bsz, _, slen = embed_chunk.shape - # N x D x K x T' - embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) - # N x T' x K x D - embed_chunk = embed_chunk.transpose(1, 3).contiguous() - # NT' x K x D - embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) - # NT' x 1 x D - q = self.queries.expand(bsz * slen, -1, -1) - for layer in self.decoders: - if self.gradient_checkpointing_enable and self.training: - q = checkpoint( - layer.__call__, - q, - embed_chunk, - None, - mask, - use_reentrant=True, - ) - else: - q = layer( - tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask - ) - - if self.after_norm is not None: - q = self.after_norm(q) - - if embed_len is not None: - embed_len = embed_len // self.window_size - # N x T' x D - out = q.view(bsz, slen, -1) - - return out, embed_len - - -class AudioEmbedding(nn.Module): - """Image embedding.""" - - def __init__(self, config: PretrainedConfig, **kwargs) -> None: - super().__init__() - self.config = config - # n_embed or hidden_size for text LM - hidden_size = ( - config.n_embd if hasattr(config, "n_embd") else config.hidden_size - ) - - if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): - embd_drop = ( - config.embd_pdrop - if hasattr(config, "embd_pdrop") - else config.embed_pdrop - ) - self.drop = nn.Dropout(embd_drop) - else: - self.drop = None - - # self.wte = nn.Embedding(config.vocab_size, hidden_size) - - audio_dim_out = ( - None # Set this variable according to the actual audio processor - ) - self.layer_idx = -2 - - # if isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == 'whisper': - # model_path = config.audio_processor.get('pretrained_model_path', None) - # whisper_model = WhisperModel.from_pretrained(model_path) - - # self.encoder = whisper_model.encoder - # n_mels = self.encoder.num_mel_bins - # audio_dim_out = self.encoder.layers[0].embed_dim - # elif isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == "w2vbert2": - # audio_processor_path = config.audio_processor.get("model_path", "facebook/w2v-bert-2.0") - # self.encoder = Wav2Vec2BertModel.from_pretrained(audio_processor_path) - # audio_dim_out = self.encoder.config.hidden_size - # self.layer_idx = config.audio_processor.get("layer", 18) - # self.encoder.config.apply_spec_augment = False - # self.encoder.config.mask_time_prob = 0 - # self.encoder.config.output_hidden_states = True - # n_mels = 160 - if ( - isinstance(config.audio_processor, dict) - and config.audio_processor.get("name", None) == "cascades" - ): - encoder_config = config.audio_processor.get("config", None) - assert encoder_config is not None - self.encoder = ConformerEncoder(**encoder_config) - - # fake initialization, create encoder_embedding layer only so that - # in decoding, all parameters can be loaded in from_pretrained_function - # in training, we do post init after from_pretrained function to make sure the correct initialization - self.encoder.post_init({}) - - audio_dim_out = encoder_config["attention_dim"] - n_mels = encoder_config["input_size"] - else: - raise NotImplementedError(f"") - - assert ( - audio_dim_out is not None - ), "Remember to set values for audio_dim_out" - self.audio_dim_out = audio_dim_out - self.audio_dim_in = n_mels - - self.freeze_audio_processor = kwargs.get( - "freeze_audio_processor", False - ) - - self.downsample_rate = kwargs.get("downsample_rate", 1) - - if kwargs.get("use_qformer", False): - qformer_config = kwargs.get("qformer_config", {}) - qformer_config["attention_dim"] = audio_dim_out - self.qformer = WindowQformer(**qformer_config) - else: - self.qformer = None - - if kwargs.get("use_conv_downsample", False): - assert ( - self.qformer is None - ), "don't support use qformer and conv downsample together" - nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) - default_nemo_conv_settings = { - "subsampling": "dw_striding", - "subsampling_factor": self.downsample_rate, - "feat_in": audio_dim_out, - "feat_out": audio_dim_out, - "conv_channels": 256, - "subsampling_conv_chunking_factor": 1, - "activation": nn.ReLU(), - "is_causal": False, - } - # Override any of the defaults with the incoming, user settings - if nemo_conv_settings: - default_nemo_conv_settings.update(nemo_conv_settings) - for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" - - self.conv_ds = NemoConvSubsampling( - **default_nemo_conv_settings, - ) - else: - self.conv_ds = None - - enable_gradient_checkpointing = kwargs.get( - "enable_gradient_checkpointing", False - ) - if enable_gradient_checkpointing: - self.encoder.gradient_checkpointing_enable() - - if self.qformer: - self.qformer.enable_gradient_checkpointing() - - projection_cls = kwargs.get("projection_cls", "linear") - if projection_cls == "linear": - self.audio_projection = nn.Linear(audio_dim_out, hidden_size) - elif projection_cls == "mlp": - # follow llava-v1.5's implementation - # (do not use image_projection and image_proj_norm) - dim_projection = hidden_size - depth = 2 - self.linear_downsample_rate = ( - 1 if (self.qformer or self.conv_ds) else self.downsample_rate - ) - layers = [ - nn.Linear( - audio_dim_out * self.linear_downsample_rate, dim_projection - ) - ] - for _ in range(1, depth): - layers.extend( - [nn.GELU(), nn.Linear(dim_projection, dim_projection)] - ) - self.audio_projection = nn.Sequential(*layers) - # NOTE vision-speech tasks use a seperate projection layer - layers = [ - nn.Linear( - audio_dim_out * self.linear_downsample_rate, dim_projection - ) - ] - for _ in range(1, depth): - layers.extend( - [nn.GELU(), nn.Linear(dim_projection, dim_projection)] - ) - self.audio_projection_for_vision = nn.Sequential(*layers) - else: - raise NotImplementedError( - f"projection_cls = {projection_cls}, not implemented" - ) - - # TODO: audio sequence compression - Qformer - self.vocab_size = config.vocab_size - self.input_embeds = None - self.audio_embed_sizes = None - - def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: - self.input_embeds = input_embeds - - def set_audio_embed_sizes( - self, audio_embed_sizes: torch.LongTensor - ) -> None: - self.audio_embed_sizes = audio_embed_sizes - - def get_audio_features( - self, - input_embeds: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, - audio_projection_mode: str = "speech", - ): - - if self.freeze_audio_processor: - with torch.no_grad(): - audio_features, masks = self.encoder( - input_embeds, audio_attention_mask - ) - else: - audio_features, masks = self.encoder( - input_embeds, audio_attention_mask - ) - - if self.qformer is not None: - audio_features, _ = self.qformer(audio_features, mask=None) - - if self.conv_ds is not None: - if masks is not None: - masks = masks.squeeze(1) - - audio_features, masks = self.conv_ds(audio_features, mask=masks) - - if self.linear_downsample_rate != 1: - bs, seq_len, feat_dim = audio_features.size() - padding = seq_len % self.linear_downsample_rate - if padding > 0: - audio_features = F.pad( - audio_features, - (0, 0, 0, self.linear_downsample_rate - padding), - "constant", - 0, - ) - - seq_len = audio_features.size(1) - audio_features = audio_features.view( - bs, - seq_len // self.linear_downsample_rate, - feat_dim * self.linear_downsample_rate, - ) - - if audio_projection_mode == 'speech': - audio_set_tensor = self.audio_projection(audio_features) - elif audio_projection_mode == 'vision': - audio_set_tensor = self.audio_projection_for_vision(audio_features) - else: - raise ValueError(f"audio_projection_mode = {audio_projection_mode} not implemented") - - return audio_set_tensor - - def forward( - self, - input_ids: torch.LongTensor, - input_embeds: torch.FloatTensor, - audio_embed_sizes, - **kwargs, - ) -> torch.FloatTensor: - """ - arguments: - input_ids: input text ids (B, U) - input_embeds: audio features (B, T, D) B: num audios in a sequence - """ - assert input_embeds is not None and len(input_embeds) == len( - audio_embed_sizes - ) - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - with torch.no_grad(): - positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(as_tuple=False) - - if not isinstance(input_embeds, list): - input_embeds = [input_embeds] - - audio_projection_mode = kwargs.get("audio_projection_mode", "speech") - audio_set_tensor = [ - self.get_audio_features(input_embed, audio_projection_mode=audio_projection_mode) - for input_embed in input_embeds - ] - - with torch.no_grad(): - input_ids.clamp_min_(0).clamp_max_(self.vocab_size) - - if "wte" in kwargs: - # we use the token embedding layer from the huggingface model, this is REQUIRED to make sure we are using the loaded weights. - hidden_states = kwargs["wte"](input_ids) - else: - # otherwise, we use token embedding in pretrained mixformer from phi team - hidden_states = self.wte(input_ids) - - if len(positions.tolist()) > 0: - assert sum(audio_embed_sizes) == len( - positions - ), "please ensure the encoder outputs have the same length as defined in input_ids!" - idx = 0 - for i in range(len(audio_embed_sizes)): - cnt = audio_embed_sizes[i] - assert audio_set_tensor[i].shape[0] == 1 - hidden_states[ - positions[idx, 0], - positions[idx, 1] : positions[idx, 1] + cnt, - ] = ( - audio_set_tensor[i][0, : audio_embed_sizes[i], :] - .to(hidden_states.dtype) - .to(hidden_states.device) - ) - idx += cnt - - else: - if self.training: - # hidden_states[:, 0:img_set_tensor.shape[0]] = hidden_states[:, 0:img_set_tensor.shape[0]] + 0 * img_set_tensor.to(hidden_states.dtype).to(hidden_states.device) - hidden_states[:, 0:1] = hidden_states[ - :, 0:1 - ] + 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype).to( - hidden_states.device - ) - - if self.drop is not None: - hidden_states = self.drop(hidden_states) - return hidden_states From 40011bb8992f2942ca15aa3d0cf3abb190cf481e Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Sun, 2 Mar 2025 23:59:13 -0800 Subject: [PATCH 14/27] remove flash_attn from requirements-common.txt Signed-off-by: Congcong Chen --- requirements-common.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 77a8ff701817..690484d7ecef 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -38,5 +38,4 @@ depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md -scipy # Required for phi-4-multimodal-instruct -flash_attn # Required for phi-4-multimodal-instruct \ No newline at end of file +scipy # Required for phi-4-multimodal-instruct \ No newline at end of file From 120013288bbe8c4617c03ddaeac1da84368c89d7 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 00:01:28 -0800 Subject: [PATCH 15/27] Add more max LoRA rank support Signed-off-by: Congcong Chen --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 8714b50c6c1a..3f1bff498129 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2286,7 +2286,7 @@ def compute_hash(self) -> str: def __post_init__(self): # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. - possible_max_ranks = (8, 16, 32, 64, 128, 256, 512) + possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError( From 4b11f1045aa85a0ca98c2eba01e61a406798e4dc Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 15:33:47 -0800 Subject: [PATCH 16/27] format code Signed-off-by: Congcong Chen --- examples/offline_inference_phi3o.py | 210 ++-- requirements-test.txt | 22 +- vllm/model_executor/models/phi4mm.py | 942 ++++++++------ vllm/model_executor/models/phi4mm_audio.py | 619 +++++----- vllm/model_executor/models/phi4mm_utils.py | 735 ++++++----- .../models/vision_siglip_navit.py | 1080 ++++++++++------- 6 files changed, 1973 insertions(+), 1635 deletions(-) diff --git a/examples/offline_inference_phi3o.py b/examples/offline_inference_phi3o.py index 9df355a4be19..6166254db213 100644 --- a/examples/offline_inference_phi3o.py +++ b/examples/offline_inference_phi3o.py @@ -1,45 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 # Implements a simple offline inference script for the Phi 3.5 Speech model. # Code implemented by Jacob Platin (jacobplatin@microsoft.com) import soundfile from vllm import LLM, SamplingParams -from vllm.utils import FlexibleArgumentParser from vllm.lora.request import LoRARequest from vllm.multimodal.utils import fetch_image +from vllm.utils import FlexibleArgumentParser + -""" -Model file: vllm/model_executor/models/phi3o.py - -Step 1: Download the following model weights to some location. -* Base Model Weight: https://github.com/microsoft/MoE/tree/weijian/phio-hf -* Vision Lora Model Weight: https://llmpretrainingwus3.blob.core.windows.net/users/weijianxu/phio-004-sft-vision-lora-only-from-hf-unified-model -* Speech Lora Model Weight: https://llmpretrainingwus3.blob.core.windows.net/users/weijianxu/phio-004-sft-speech-lora-only-from-hf-unified-model - -Step 2: Run the test -* Run the followling command with the commandline parameters you want to pass into the script. - python examples/offline_inference_phi3s.py -* You should expect to see the output like: - Prompt: '<|user|>\n<|image_1|>\n<|audio_1|>\ntry your best to answer the question<|end|>\n<|assistant|>\n' - Generated text: 'Stop' -""" def main_pure_text(args: dict) -> None: """ Main function for the offline inference script. """ - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enforce_eager=True) + llm = LLM(model=args.model_path, + trust_remote_code=True, + enforce_eager=True) user_prompt = '<|user|>\n' assistant_prompt = '<|assistant|>\n' prompt_suffix = '<|end|>\n' - prompt = f'{user_prompt}what is the answer for 1+1? Explain it.{prompt_suffix}{assistant_prompt}' + prompt = f'{user_prompt}what is the answer for 1+1? Explain'\ + f' it.{prompt_suffix}{assistant_prompt}' print(f'>>> Prompt\n{prompt}') # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = { - "prompt": prompt - } + generate_args = {"prompt": prompt} # NOTE: you should use the following settings to ensure parity in HF # generate_ids = model.generate( # **inputs, @@ -68,22 +53,22 @@ def main_with_lora_speech(args: dict, activate_lora_request=True) -> None: Main function for the offline inference script. """ wav_paths = [args.wav_path] - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - limit_mm_per_prompt={"audio": len(wav_paths)}, - max_loras=5) + llm = LLM(model=args.model_path, + trust_remote_code=True, + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + limit_mm_per_prompt={"audio": len(wav_paths)}, + max_loras=5) # assert len(wav_paths) == 1, "Only support single audio files for now!" - prompt = "Generate a comprehensive text transcription of the spoken content." - placeholders = "\n".join( - f"<|audio_{i}|>" for i in range(1, len(wav_paths) + 1) - ) + prompt = "Generate a comprehensive text transcription of the "\ + "spoken content." + placeholders = "\n".join(f"<|audio_{i}|>" + for i in range(1, + len(wav_paths) + 1)) prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" # NOTE: soundfile.read will return the audio feature and the sampling rate @@ -108,36 +93,41 @@ def main_with_lora_speech(args: dict, activate_lora_request=True) -> None: max_tokens=200, ) - outputs = llm.generate(generate_args, sampling_params=sampling_params, lora_request= [LoRARequest("speech_adapter", 3, args.speech_lora_path)] if activate_lora_request else None) + outputs = llm.generate( + generate_args, + sampling_params=sampling_params, + lora_request=[LoRARequest("speech_adapter", 3, args.speech_lora_path)] + if activate_lora_request else None) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}") print(f"Generated text: {generated_text!r}\n\n") -def main_with_lora_speech_batch(args: dict, activate_lora_request=True) -> None: + +def main_with_lora_speech_batch(args: dict, + activate_lora_request=True) -> None: """ Main function for the offline inference script. """ wav_paths = [args.wav_path, args.wav_path] - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - limit_mm_per_prompt={"audio": len(wav_paths)}, - max_loras=5) - + llm = LLM(model=args.model_path, + trust_remote_code=True, + enable_lora=activate_lora_request, + enforce_eager=True, + max_lora_rank=512, + lora_extra_vocab_size=0, + limit_mm_per_prompt={"audio": len(wav_paths)}, + max_loras=5) # assert len(wav_paths) == 1, "Only support single audio files for now!" - prompt = "Based on the attached audio, generate a comprehensive text transcription of the spoken content." - placeholders = "\n".join( - f"<|audio_{i}|>" for i in range(1, len(wav_paths) + 1) - ) + prompt = "Based on the attached audio, generate a comprehensive text "\ + "transcription of the spoken content." + placeholders = "\n".join(f"<|audio_{i}|>" + for i in range(1, + len(wav_paths) + 1)) prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" # NOTE: soundfile.read will return the audio feature and the sampling rate @@ -173,23 +163,23 @@ def main_with_lora_speech_batch(args: dict, activate_lora_request=True) -> None: outputs = llm.generate( generate_args, sampling_params=sampling_params, - lora_request= LoRARequest("speech_adapter", 3, args.speech_lora_path) - if activate_lora_request else None) + lora_request=LoRARequest("speech_adapter", 3, args.speech_lora_path) + if activate_lora_request else None) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}") print(f"Generated text: {generated_text!r}\n\n") + def main_with_lora_vision(args: dict, activate_lora_request=True) -> None: """ Main function for the offline inference script. """ - image_urls=[args.image_url] + image_urls = [args.image_url] llm = LLM( model=args.model_path, trust_remote_code=True, - enable_lora=activate_lora_request, enforce_eager=True, max_lora_rank=512, @@ -206,7 +196,7 @@ def main_with_lora_vision(args: dict, activate_lora_request=True) -> None: for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" - image_data=[fetch_image(url) for url in image_urls] + image_data = [fetch_image(url) for url in image_urls] # NOTE: soundfile.read will return the audio feature and the sampling rate generate_args = { @@ -233,8 +223,8 @@ def main_with_lora_vision(args: dict, activate_lora_request=True) -> None: outputs = llm.generate( generate_args, sampling_params=sampling_params, - lora_request= [LoRARequest("vision_adapter", 3, args.vision_lora_path)] if activate_lora_request else None - ) + lora_request=[LoRARequest("vision_adapter", 3, args.vision_lora_path)] + if activate_lora_request else None) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text @@ -242,15 +232,18 @@ def main_with_lora_vision(args: dict, activate_lora_request=True) -> None: print(f"Generated text: {generated_text!r}\n\n") -def main_with_lora_vision_batch(args: dict, activate_lora_request=True) -> None: +def main_with_lora_vision_batch(args: dict, + activate_lora_request=True) -> None: """ Main function for the offline inference script. """ - image_urls=[args.image_url, "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg"] + image_urls = [ + args.image_url, + "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg" + ] llm = LLM( model=args.model_path, trust_remote_code=True, - enable_lora=activate_lora_request, enforce_eager=True, max_lora_rank=512, @@ -275,7 +268,13 @@ def main_with_lora_vision_batch(args: dict, activate_lora_request=True) -> None: { "prompt": prompt, "multi_modal_data": { - "image": [fetch_image(url) for url in ["https://www.ilankelman.org/stopsigns/australia.jpg", "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg"]], + "image": [ + fetch_image(url) for url in [ + "https://www.ilankelman.org/stopsigns/australia.jpg", + "https://alinasayre.com/wp-content/uploads/2013/10/"\ + "d67cd-dsc01646.jpg" + ] + ], }, }, { @@ -303,9 +302,8 @@ def main_with_lora_vision_batch(args: dict, activate_lora_request=True) -> None: outputs = llm.generate( generate_args, sampling_params=sampling_params, - lora_request= LoRARequest("vision_adapter", 3, args.vision_lora_path) - if activate_lora_request else None - ) + lora_request=LoRARequest("vision_adapter", 3, args.vision_lora_path) + if activate_lora_request else None) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text @@ -313,15 +311,15 @@ def main_with_lora_vision_batch(args: dict, activate_lora_request=True) -> None: print(f"Generated text: {generated_text!r}\n\n") -def main_with_lora_vision_speech(args: dict, activate_lora_request=True) -> None: +def main_with_lora_vision_speech(args: dict, + activate_lora_request=True) -> None: """ Main function for the offline inference script. """ - image_urls=[args.image_url] + image_urls = [args.image_url] llm = LLM( model=args.model_path, trust_remote_code=True, - enable_lora=activate_lora_request, enforce_eager=True, max_lora_rank=512, @@ -333,15 +331,19 @@ def main_with_lora_vision_speech(args: dict, activate_lora_request=True) -> None limit_mm_per_prompt={"image": len(image_urls)}, ) - prompt = "" + prompt = "" placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}<|end|>\n<|assistant|>\n" + prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}<|end|>"\ + "\n<|assistant|>\n" - image_data=[fetch_image(url) for url in image_urls] + image_data = [fetch_image(url) for url in image_urls] - wav_paths = ["/scratch/turing_westus3_prm_data/users/congcongchen/MoE_2/hf-models/phio/examples/what_is_the_traffic_sign_in_the_image.wav"] + wav_paths = [ + "/scratch/turing_westus3_prm_data/users/congcongchen/MoE_2/hf-models"\ + "/phio/examples/what_is_the_traffic_sign_in_the_image.wav" + ] # NOTE: soundfile.read will return the audio feature and the sampling rate generate_args = { "prompt": prompt, @@ -368,24 +370,28 @@ def main_with_lora_vision_speech(args: dict, activate_lora_request=True) -> None outputs = llm.generate( generate_args, sampling_params=sampling_params, - lora_request= [LoRARequest("vision_adapter", 3, args.vision_lora_path)] if activate_lora_request else None - ) + lora_request=[LoRARequest("vision_adapter", 3, args.vision_lora_path)] + if activate_lora_request else None) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}") print(f"Generated text: {generated_text!r}\n\n") -def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) -> None: + +def main_with_lora_vision_speech_batch(args: dict, + activate_lora_request=True) -> None: """ Main function for the offline inference script. """ - image_urls=[args.image_url, "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg"] + image_urls = [ + args.image_url, + "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg" + ] wav_paths = [args.wav_path] llm = LLM( model=args.model_path, trust_remote_code=True, - enable_lora=activate_lora_request, enforce_eager=True, max_lora_rank=512, @@ -394,18 +400,21 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - # max_model_len=40960, # max_num_seqs=5, - limit_mm_per_prompt={"image": len(image_urls), "audio": len(wav_paths)}, + limit_mm_per_prompt={ + "image": len(image_urls), + "audio": len(wav_paths) + }, ) - prompt = "try your best to answer the question" + prompt = "try your best to answer the question" placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}<|end|>\n<|assistant|>\n" + prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}"\ + "<|end|>\n<|assistant|>\n" # image_data=[fetch_image(url) for url in image_urls] - # NOTE: soundfile.read will return the audio feature and the sampling rate generate_args = [ { @@ -413,14 +422,21 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "multi_modal_data": { "image": [fetch_image(url) for url in image_urls], "audio": [soundfile.read(wav_path) for wav_path in wav_paths], - }, + }, }, { "prompt": prompt, "multi_modal_data": { - "image": [fetch_image(url) for url in ["https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg", "https://alinasayre.com/wp-content/uploads/2012/01/c3a7c-dsc01668.jpg"]], + "image": [ + fetch_image(url) for url in [ + "https://alinasayre.com/wp-content/uploads/"\ + "2013/10/d67cd-dsc01646.jpg", + "https://alinasayre.com/wp-content/uploads/"\ + "2012/01/c3a7c-dsc01668.jpg" + ] + ], "audio": [soundfile.read(wav_path) for wav_path in wav_paths], - }, + }, }, ] # NOTE: you should use the following settings to ensure parity in HF @@ -441,9 +457,8 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - outputs = llm.generate( generate_args, sampling_params=sampling_params, - lora_request= LoRARequest("vision_adapter", 3, args.vision_lora_path) - if activate_lora_request else None - ) + lora_request=LoRARequest("vision_adapter", 3, args.vision_lora_path) + if activate_lora_request else None) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text @@ -454,13 +469,13 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - if __name__ == "__main__": parser = FlexibleArgumentParser( description="Demo on using vLLM for offline inference with " - "vision language models that support multi-image input" - ) + "vision language models that support multi-image input") parser.add_argument( "--model-path", "-p", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm", + default= + "/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm", help="Path to the (HuggingFace) model checkpoint.", ) @@ -468,7 +483,8 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--vision-lora-path", "-v", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/vision-lora", + default= + "/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/vision-lora", help="Path to the (HuggingFace) vision lora model checkpoint.", ) @@ -476,7 +492,8 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--speech-lora-path", "-s", type=str, - default="/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/speech-lora", + default= + "/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/speech-lora", help="Path to the (HuggingFace) speech lora model checkpoint.", ) @@ -493,7 +510,8 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - "--image-url", "-i", type=str, - default="https://www.ilankelman.org/stopsigns/australia.jpg", + default= + "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg", ) parser.add_argument( @@ -508,7 +526,7 @@ def main_with_lora_vision_speech_batch(args: dict, activate_lora_request=True) - test_type = args.test_type if test_type == "language_only": main_pure_text(args) - ##### Speech + Language ##### + ##### Speech + Language ##### elif test_type == "speech_language_with_lora": main_with_lora_speech(args) elif test_type == "speech_language_with_lora_batch": diff --git a/requirements-test.txt b/requirements-test.txt index f5722c82e201..e5bf67e099e4 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -23,6 +23,10 @@ anyio==4.6.2.post1 # via httpx argcomplete==3.5.1 # via datamodel-code-generator +async-timeout==4.0.3 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -116,6 +120,10 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -544,9 +552,7 @@ sentence-transformers==3.2.1 sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 - # via - # pytablewriter - # torch + # via pytablewriter six==1.16.0 # via # python-dateutil @@ -591,6 +597,12 @@ timm==1.0.11 # via -r requirements-test.in tokenizers==0.21.0 # via transformers +toml==0.10.2 + # via datamodel-code-generator +tomli==2.2.1 + # via + # black + # pytest torch==2.5.1 # via # -r requirements-test.in @@ -651,13 +663,17 @@ typepy==1.3.2 # tabledata typing-extensions==4.12.2 # via + # anyio # bitsandbytes + # black # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch tzdata==2024.2 # via pandas diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index d6313bf6ebfa..27ae9bcca2e4 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1,53 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 import math -from functools import lru_cache import re -from typing import ( - Dict, - Iterable, - List, - Literal, - Mapping, - Optional, - Tuple, - TypedDict, - Union, -) +from functools import lru_cache +from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import numpy as np import scipy.signal import torch import torch.nn as nn import torchvision.transforms as T -from transformers import PretrainedConfig from PIL import Image +from transformers import PretrainedConfig +from transformers.utils import logging -from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext, DummyData -from vllm.inputs.data import token_inputs, TokenInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext) +from vllm.inputs.data import TokenInputs, token_inputs from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE -) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalInputs, NestedTensors -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.sequence import IntermediateTensors, SequenceData -from transformers.utils import logging +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA -from .vision_siglip_navit import get_siglip_vision_model +from .interfaces import SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import PPMissingLayer, maybe_prefix - +from .utils import maybe_prefix +from .vision_siglip_navit import get_siglip_vision_model -_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 # <|endoftext10|> (see vocab.json in hf model) -_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> +# <|endoftext10|> (see vocab.json in hf model) +_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 +# <|endoftext11|> +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz @@ -66,11 +58,15 @@ }, } logger = logging.get_logger(__name__) -# This is a workaround to prevent text (user input) + audio + image from being used in -# the same prompt. -# It includes token ids for "/n" and tokens in added_tokens_decoder from the -# tokenizer_confg.json file. -NON_USER_INPUT_TOKENS = {198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, 200023, 200024, 200025, 200026, 200027, 200028} +# This is a workaround to prevent text (user input) + audio + image +# from being used in the same prompt. +# It includes token ids for "/n" and tokens in added_tokens_decoder +# from the tokenizer_confg.json file. +NON_USER_INPUT_TOKENS = { + 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, + 200023, 200024, 200025, 200026, 200027, 200028 +} + def get_max_dummy_image(ctx: InputContext): hf_config = ctx.get_hf_config() @@ -100,17 +96,15 @@ def get_max_phi4mm_image_tokens(ctx: InputContext): vit_patch_size = prepro_config['vit_patch_size'] token_compression_factor = prepro_config['token_compression_factor'] - image_num_tokens = _compute_num_image_tokens( - dummy_image, - dynamic_hd_size, - vit_image_size, - vit_patch_size, - token_compression_factor - ) + image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, + vit_image_size, + vit_patch_size, + token_compression_factor) return image_num_tokens -def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, + image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height @@ -135,18 +129,14 @@ def _find_target_aspect_ratio(image, image_size, max_num, min_num): aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio - target_ratios = set( - (i, j) - for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num - ) + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, target_ratios, orig_width, orig_height, image_size - ) + aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] @@ -173,11 +163,16 @@ def _get_padding_size(image, target_height, target_width): return padding_height, padding_width -def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, mask_size=27): - target_aspect_ratio, target_height, target_width = _find_target_aspect_ratio( - image, image_size, max_num, min_num - ) - padding_height, padding_width = _get_padding_size(image, target_height, target_width) +def dynamic_preprocess(image, + min_num=1, + max_num=12, + image_size=384, + mask_size=27): + target_aspect_ratio, target_height, target_width =\ + _find_target_aspect_ratio( + image, image_size, max_num, min_num) + padding_height, padding_width = _get_padding_size(image, target_height, + target_width) # Calculate the ratio orig_width, orig_height = image.size @@ -188,48 +183,65 @@ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, mask_size=2 else: new_size = (int(orig_width * ratio_height), target_height) - attention_mask = torch.ones((int(mask_size*target_aspect_ratio[1]), int(mask_size*target_aspect_ratio[0]))) + attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), + int(mask_size * target_aspect_ratio[0]))) if padding_width >= 14: - attention_mask[:, -math.floor(padding_width/14):] = 0 + attention_mask[:, -math.floor(padding_width / 14):] = 0 if padding_height >= 14: - attention_mask[-math.floor(padding_height/14):,:] = 0 - assert attention_mask.sum() > 0, f'attention mask is empty {attention_mask}' + attention_mask[-math.floor(padding_height / 14):, :] = 0 + assert attention_mask.sum( + ) > 0, f'attention mask is empty {attention_mask}' - if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + if min(new_size[1], target_height) < 10 or min(new_size[0], + target_width) < 10: raise ValueError(f'the aspect ratio is very extreme {new_size}') - image = T.functional.resize(image, [new_size[1], new_size[0]],) + image = T.functional.resize( + image, + [new_size[1], new_size[0]], + ) - resized_img = T.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255,255,255]) + resized_img = T.functional.pad(image, + [0, 0, padding_width, padding_height], + fill=[255, 255, 255]) return resized_img, attention_mask + def pad_to_max_num_crops(images, max_crops=5): """ images: B x 3 x H x W, B<=max_crops """ B, _, H, W = images.shape if max_crops > B: - pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + pad = torch.zeros(max_crops - B, + 3, + H, + W, + dtype=images.dtype, + device=images.device) images = torch.cat([images, pad], dim=0) return images + def pad_mask_to_max_num_crops(masks, max_crops=5): B, H, W = masks.shape if max_crops > B: - pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) + pad = torch.ones(max_crops - B, + H, + W, + dtype=masks.dtype, + device=masks.device) masks = torch.cat([masks, pad], dim=0) return masks + def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): # Basic settings. img_processor = T.Compose([ T.ToTensor(), - T.Normalize( - (0.5, 0.5, 0.5), - (0.5, 0.5, 0.5) - ), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) # Dynamic HD base_resolution = vit_resolution @@ -238,43 +250,78 @@ def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): mask_resolution = base_resolution // vit_patch_size elems, image_attention_masks = [], [] for im in images: - elem, attention_mask = dynamic_preprocess(im, max_num=dynamic_hd_size, image_size=base_resolution, mask_size=mask_resolution) + elem, attention_mask = dynamic_preprocess(im, + max_num=dynamic_hd_size, + image_size=base_resolution, + mask_size=mask_resolution) elems.append(elem) image_attention_masks.append(attention_mask) hd_images = [img_processor(im) for im in elems] - global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic',).to(im.dtype) for im in hd_images] + global_image = [ + torch.nn.functional.interpolate( + im.unsqueeze(0).float(), + size=(base_resolution, base_resolution), + mode='bicubic', + ).to(im.dtype) for im in hd_images + ] shapes = [[im.size(1), im.size(2)] for im in hd_images] - mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] - global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images] - hd_images_reshape = [im.reshape(1, 3, - h//base_resolution, - base_resolution, - w//base_resolution, - base_resolution - ).permute(0,2,4,1,3,5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes)] - attention_masks_reshape = [mask.reshape(1, - h//mask_resolution, - mask_resolution, - w//mask_resolution, - mask_resolution - ).permute(0,1,3,2,4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes)] + mask_shapes = [[mask.size(0), mask.size(1)] + for mask in image_attention_masks] + global_attention_mask = [ + torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images + ] + hd_images_reshape = [ + im.reshape(1, 3, h // base_resolution, base_resolution, + w // base_resolution, base_resolution).permute( + 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, + base_resolution).contiguous() + for im, (h, w) in zip(hd_images, shapes) + ] + attention_masks_reshape = [ + mask.reshape(1, h // mask_resolution, mask_resolution, + w // mask_resolution, mask_resolution).permute( + 0, 1, 3, 2, 4).reshape(-1, mask_resolution, + mask_resolution).contiguous() + for mask, (h, w) in zip(image_attention_masks, mask_shapes) + ] # NOTE token compression is hard coded here, and odd numbers seems to fail - downsample_attention_masks = [mask[:,0::2,0::2].reshape(1, - h//mask_resolution, - w//mask_resolution, - mask_resolution//2+mask_resolution%2, - mask_resolution//2+mask_resolution%2 - ).permute(0,1,3,2,4) for mask, (h,w) in zip(attention_masks_reshape, mask_shapes)] - downsample_attention_masks = [mask.reshape(mask.size(1)*mask.size(2), mask.size(3)*mask.size(4))for mask in downsample_attention_masks] + downsample_attention_masks = [ + mask[:, 0::2, + 0::2].reshape(1, h // mask_resolution, w // mask_resolution, + mask_resolution // 2 + mask_resolution % 2, + mask_resolution // 2 + mask_resolution % 2).permute( + 0, 1, 3, 2, 4) + for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) + ] + downsample_attention_masks = [ + mask.reshape(mask.size(1) * mask.size(2), + mask.size(3) * mask.size(4)) + for mask in downsample_attention_masks + ] # NOTE hard coded number of tokens - num_img_tokens = [256 + 1 + int(mask.sum().item()) + int(mask[:,0].sum().item()) + 16 for mask in downsample_attention_masks] + num_img_tokens = [ + 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 + for mask in downsample_attention_masks + ] - hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)] - hd_masks_reshape = [torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)] + hd_images_reshape = [ + torch.cat([_global_image] + [_im], dim=0) + for _global_image, _im in zip(global_image, hd_images_reshape) + ] + hd_masks_reshape = [ + torch.cat([_global_mask] + [_mask], + dim=0) for _global_mask, _mask in zip( + global_attention_mask, attention_masks_reshape) + ] max_crops = max([img.size(0) for img in hd_images_reshape]) - image_transformed = [pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] + image_transformed = [ + pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape + ] image_transformed = torch.stack(image_transformed, dim=0) - mask_transformed = [pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] + mask_transformed = [ + pad_mask_to_max_num_crops(mask, max_crops) \ + for mask in hd_masks_reshape + ] mask_transformed = torch.stack(mask_transformed, dim=0) returned_input_image_embeds = image_transformed @@ -302,9 +349,11 @@ def __init__(self, super().__init__() # n_embed or hidden_size - hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): - embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop + embd_drop = config.embd_pdrop if hasattr( + config, 'embd_pdrop') else config.embed_pdrop self.drop = nn.Dropout(embd_drop) else: self.drop = None @@ -312,12 +361,14 @@ def __init__(self, # layer_idx to output the img features if isinstance(config.img_processor, dict): self.layer_idx = config.img_processor.get('layer_idx', -2) - self.type_feature = config.img_processor.get('type_feature', 'patch') + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') else: self.layer_idx = -2 self.type_feature = 'patch' - self.img_processor = get_siglip_vision_model(_flash_attn_2_enabled=True) + self.img_processor = get_siglip_vision_model( + _flash_attn_2_enabled=True) pe_weight = self.img_processor.embeddings.position_embedding.weight L, D = pe_weight.size() @@ -328,15 +379,13 @@ def __init__(self, H += 1 image_dim_out = D # ((448/14)//2)**2 - self.num_img_tokens = (H//2)**2 + self.num_img_tokens = (H // 2)**2 self.base_feat_height_target = H - self.image_dim_out = image_dim_out self.img_sizes = None self.image_attention_mask = None - # global_gn and sub_gn for hd transform, serves as line separator self.use_hd_transform = True self.with_learnable_separator = True @@ -351,21 +400,33 @@ def __init__(self, self.base_feat_height_target = self.base_feat_height_target // 2 # with_hd_transform and with_learnable_separator should have same value - assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value' - assert self.use_hd_transform, 'learnable separator is only for hd transform' + assert self.use_hd_transform == self.with_learnable_separator, \ + 'use_hd_transform and with_learnable_separator should have same value' + assert self.use_hd_transform, \ + 'learnable separator is only for hd transform' # 1024 * 4, merge spatial to channel dimension - self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) - self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) + self.glb_GN = nn.Parameter( + torch.zeros([ + 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 + ])) + self.sub_GN = nn.Parameter( + torch.zeros([ + 1, 1, 1, + self.image_dim_out * self.base_feat_height_reduction**2 + ])) dim_projection = hidden_size depth = 2 - layers = [nn.Linear(image_dim_out * self.base_feat_height_reduction**2, dim_projection)] + layers = [ + nn.Linear(image_dim_out * self.base_feat_height_reduction**2, + dim_projection) + ] for _ in range(1, depth): - layers.extend([nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) - self.vocab_size = config.vocab_size self.img_features = None @@ -377,20 +438,23 @@ def get_img_features(self, LAYER_IDX = self.layer_idx TYPE_FEATURE = self.type_feature - img_processor_output = self.img_processor(img_embeds, - output_hidden_states=True, - patch_attention_mask=attention_mask) + img_processor_output = self.img_processor( + img_embeds, + output_hidden_states=True, + patch_attention_mask=attention_mask) img_feature = img_processor_output.hidden_states[LAYER_IDX] if TYPE_FEATURE == "patch": patch_feature = img_feature use_token_compression = self.image_token_compression is not None - use_padding = getattr(self, 'img_processor_padding', None) is not None + use_padding = getattr(self, 'img_processor_padding', + None) is not None if use_token_compression or use_padding: # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + patch_feature = patch_feature.view(-1, width, width, + patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) @@ -401,14 +465,16 @@ def get_img_features(self, # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) - patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + patch_feature = patch_feature.view( + -1, + patch_feature.size(1) * patch_feature.size(2), + patch_feature.size(-1)) return patch_feature raise NotImplementedError - def forward(self, - pixel_values: torch.FloatTensor, + def forward(self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, image_attention_mask: torch.Tensor) -> torch.FloatTensor: """ @@ -434,22 +500,30 @@ def forward(self, img_sizes = image_sizes num_images, num_crops, c, h, w = pixel_values.shape - # assert num_images == 1, "Currently only support single image" # TODO debug multi-image bs = num_images pixel_values = pixel_values.flatten(0, 1) - img_features = self.get_img_features(pixel_values, - image_attention_mask.type(torch.BoolTensor).flatten(0,1).to(target_device)) + img_features = self.get_img_features( + pixel_values, + image_attention_mask.type(torch.BoolTensor).flatten( + 0, 1).to(target_device)) base_feat_height_target = self.base_feat_height_target base_resolution = self.crop_size base_feat_height_reduction = self.base_feat_height_reduction - base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1])) - assert base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform' + base_feat_height = base_feat_width = int(np.sqrt( + img_features.shape[1])) + assert base_feat_height == base_feat_height_target \ + and base_feat_width == base_feat_height_target, \ + f'base_feat_height: {base_feat_height},"\ + f" base_feat_width: {base_feat_width}, "\ + f"expect {base_feat_height_target} features for hd transform' # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + img_features = img_features.view(bs, -1, + base_feat_height * base_feat_width, + self.image_dim_out) C = self.image_dim_out H = base_feat_height @@ -468,11 +542,22 @@ def forward(self, global_img_feature = img_features[_bs, :1] # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, H//base_feat_height_reduction, 1, 1) + glb_img = global_img_feature.reshape(1, H, H, C).reshape( + 1, H // base_feat_height_reduction, base_feat_height_reduction, + H // base_feat_height_reduction, base_feat_height_reduction, + C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( + 1, H // base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * + C).contiguous() + temp_glb_GN = self.sub_GN.repeat(1, + H // base_feat_height_reduction, + 1, 1) # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C) + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( + 1, -1, + base_feat_height_reduction * base_feat_height_reduction * C) # (max_num_crops-1) x (12x12) x C sub_img = img_features[_bs, 1:] @@ -480,39 +565,81 @@ def forward(self, # get rid of padding sub_img sub_img = sub_img[:B_] - # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,base_feat_height_reduction*base_feat_height_reduction*C).contiguous() - sub_img = sub_img.reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction, -1).permute(0,1,3,2,4,5).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C) - - if image_attention_mask is not None and len(image_attention_mask) > 0: - reshaped_image_attention_mask = image_attention_mask[_bs,1:B_+1,0::2,0::2].reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction).permute(0,1,3,2,4).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction) - useful_height = int(reshaped_image_attention_mask[0,:,0].sum().item()) - useful_width = int(reshaped_image_attention_mask[0,0,:].sum().item()) - sub_img = sub_img[:,:useful_height, :useful_width] + # (num_crops, 12, 2, 12, 2, 1024) -> + # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = sub_img.reshape(B_, H, H, C).reshape( + B_, H // base_feat_height_reduction, + base_feat_height_reduction, H // base_feat_height_reduction, + base_feat_height_reduction, + C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( + B_, -1, base_feat_height_reduction * + base_feat_height_reduction * C).contiguous() + sub_img = sub_img.reshape( + 1, h, w, base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + -1).permute(0, 1, 3, 2, 4, 5).reshape( + 1, h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * + C) + + if image_attention_mask is not None and len( + image_attention_mask) > 0: + reshaped_image_attention_mask = image_attention_mask[ + _bs, 1:B_ + 1, 0::2, 0::2].reshape( + 1, h, w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction).permute( + 0, 1, 3, 2, 4).reshape( + 1, h * base_feat_height // + base_feat_height_reduction, w * + base_feat_width // base_feat_height_reduction) + useful_height = int( + reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int( + reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) - temp_len = int(image_attention_mask[_bs,:B_+1,0::2,0::2].sum().item()) + (useful_height+1) + base_feat_height//base_feat_height_reduction + temp_len = int( + image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item( + )) + (useful_height + + 1) + base_feat_height // base_feat_height_reduction else: - temp_sub_GN = self.sub_GN.repeat(1, h*base_feat_height//base_feat_height_reduction, 1, 1) - temp_len = int((h*w+1)*self.num_img_tokens+ 1 + (h+1)*base_feat_height//base_feat_height_reduction) - - sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C) + temp_sub_GN = self.sub_GN.repeat( + 1, h * base_feat_height // base_feat_height_reduction, 1, + 1) + temp_len = int((h * w + 1) * self.num_img_tokens + 1 + + (h + 1) * base_feat_height // + base_feat_height_reduction) + + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( + 1, -1, + base_feat_height_reduction * base_feat_height_reduction * C) # (1, num_img_tokens, 1024*4) # glb + sub if self.hd_transform_order == 'glb_sub': - output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) elif self.hd_transform_order == 'sub_glb': - output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) else: - raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented') + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order}, "\ + "not implemented') #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) - assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' + assert temp_len == output_imgs[-1].shape[ + 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ + "{output_imgs[-1].shape[1]}' + output_len.append(temp_len) img_set_tensor = [] for _output_img in output_imgs: - img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) img_set_tensor.append(img_feature_proj) return img_set_tensor @@ -555,10 +682,9 @@ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): fmax = sample_rate / 2 if fmin is None: fmin = 0 - assert fmin >= 0, "fmin cannot be negtive" - assert ( - fmin < fmax <= sample_rate / 2 - ), "fmax must be between (fmin, samplerate / 2]" + assert fmin >= 0, "fmin cannot be negative" + assert (fmin < fmax <= + sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" def mel(f): return 1127.0 * np.log(1.0 + f / 700.0) @@ -575,7 +701,7 @@ def f2bin(f): khi = max(khi, klo) - # Spec 2: SpeechLib uses trianges in Mel space + # Spec 2: SpeechLib uses triangles in Mel space mlo = mel(fmin) mhi = mel(fmax) m_centers = np.linspace(mlo, mhi, n_mels + 2) @@ -595,6 +721,7 @@ def f2bin(f): class LogFbankProcessor: + def __init__(self): self._eightk_method = "fillzero" @@ -640,8 +767,7 @@ def extract_spectrogram(self, wav, fs): elif fs != 16000: # Input audio is not a supported sample rate. raise RuntimeError( - f"Input data using an unsupported sample rate: {fs}" - ) + f"Input data using an unsupported sample rate: {fs}") preemphasis = 0.97 @@ -660,11 +786,11 @@ def extract_spectrogram(self, wav, fs): n_batch = (wav.shape[0] - win_length) // hop_length + 1 # Here we don't use stride_tricks since the input array may not satisfy # memory layout requirement and we need writeable output - # Here we only use list of views before copy to desination + # Here we only use list of views before copy to destination # so it is more efficient than broadcasting y_frames = np.array( [ - wav[_stride : _stride + win_length] + wav[_stride:_stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length) ], dtype=np.float32, @@ -675,18 +801,16 @@ def extract_spectrogram(self, wav, fs): y_frames_prev[:, 0] = y_frames_prev[:, 1] y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 - S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype( - np.complex64 - ) + S = np.fft.rfft(fft_window * y_frames, n=n_fft, + axis=1).astype(np.complex64) if fs == 8000: # Need to pad the output to look like 16 kHz data but with zeros in # the 4 to 8 kHz bins. frames, bins = S.shape padarray = np.zeros((frames, bins)) - S = np.concatenate( - (S[:, 0:-1], padarray), axis=1 - ) # Nyquist bin gets set to zero + S = np.concatenate((S[:, 0:-1], padarray), + axis=1) # Nyquist bin gets set to zero spec = np.abs(S).astype(np.float32) return spec @@ -713,54 +837,58 @@ def extract_features(self, wav, fs): @lru_cache def audio_feature_extractor() -> LogFbankProcessor: # Creates an instance of the audio processor, needed to extract the - # the audio featues from the sound file + # the audio features from the sound file # LRU cache ensures that we only make one copy return LogFbankProcessor() -def _compute_num_image_tokens( - image, dynamic_hd_size, vit_image_size, vit_patch_size, token_compression_factor -): +def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, + vit_patch_size, token_compression_factor): """ - compute the number of tokens an image is expected to take up considering the image encoder - architecture and exclude output features containing only padding pixels + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 """ - assert vit_image_size % vit_patch_size == 0, "vit_image_size must be divisible by vit_patch_size" - assert vit_image_size // vit_patch_size % token_compression_factor == 0, "vit_image_size // vit_patch_size must be divisible by token_compression_factor" + assert vit_image_size % vit_patch_size == 0, \ + "vit_image_size must be divisible by vit_patch_size" + assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ + "vit_image_size // vit_patch_size must be divisible by "\ + "token_compression_factor" target_aspect_ratio, target_height, target_width = ( - _find_target_aspect_ratio( - image, vit_image_size, dynamic_hd_size, min_num=1 - ) - ) - assert target_aspect_ratio[0] * vit_image_size == target_width, f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" - assert target_aspect_ratio[1] * vit_image_size == target_height, f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" - assert ( - target_height % vit_image_size == 0 - and target_width % vit_image_size == 0 - ) - - padding_height, padding_width = _get_padding_size( - image, target_height, target_width - ) - assert padding_width == 0 or padding_height == 0, "padding_width or padding_height must be 0" + _find_target_aspect_ratio(image, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[ + 0] * vit_image_size == target_width, \ + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + assert target_aspect_ratio[ + 1] * vit_image_size == target_height, \ + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size(image, target_height, + target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" target_feat_width = target_width // vit_patch_size target_feat_height = target_height // vit_patch_size if padding_width >= vit_patch_size: assert padding_height == 0, "padding_height not 0" non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size - ) + padding_width / vit_patch_size) non_pad_feat_height = target_feat_height elif padding_height >= vit_patch_size: assert padding_width == 0, "padding_width not 0" non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size - ) + padding_height / vit_patch_size) non_pad_feat_width = target_feat_width else: # small padding shorter than a vit patch @@ -777,17 +905,13 @@ def _compute_num_image_tokens( num_hd_patch_tokens = feat_width * feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 + num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 num_sep_tokens = 1 - num_global_image_newline_tokens = vit_feature_size // token_compression_factor - - return ( - num_global_image_tokens - + num_sep_tokens - + num_hd_patch_tokens - + num_hd_newline_tokens - + num_global_image_newline_tokens - ) + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + + num_hd_newline_tokens + num_global_image_newline_tokens) def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: @@ -843,12 +967,11 @@ def _get_audio_embed_sizes(audios, ctx: InputContext): """ audio_embed_sizes = [] for audio in audios: - audio_data, sf = audio - audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) - audio_embed_size = _compute_audio_embed_size( - ctx.get_hf_config(), audio_frames - ) - audio_embed_sizes.append(audio_embed_size) + audio_data, sf = audio + audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) + audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), + audio_frames) + audio_embed_sizes.append(audio_embed_size) return audio_embed_sizes @@ -865,7 +988,8 @@ def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): prompt_str (str): The prompt string. Returns: - Dict[str, List[int]]: Mapping of audio placeholder tokens to audio placeholder token ids. + Dict[str, List[int]]: Mapping of audio placeholder tokens to audio + placeholder token ids. """ if len(audios) == 0: @@ -874,12 +998,15 @@ def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) audio_ids = [int(audio_id) for audio_id in audio_ids] - assert len(audio_ids) == len(audio_embed_sizes), "Number of audio tokens and audio features do not match" - assert tuple(audio_ids) == tuple( - range(1, len(audio_ids) + 1) - ), "Audio ids are not in order!" + assert len(audio_ids) == len( + audio_embed_sizes + ), "Number of audio tokens and audio features do not match" + assert tuple(audio_ids) == tuple(range(1, + len(audio_ids) + + 1)), "Audio ids are not in order!" audio_id_to_input_ids = { - f"<|audio_{audio_id}|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + f"<|audio_{audio_id}|>": + [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) } @@ -898,13 +1025,8 @@ def _count_image_tokens(images, ctx: InputContext): token_compression_factor = prepro_config['token_compression_factor'] image_token_counts = [ - _compute_num_image_tokens( - image, - dynamic_hd_size, - vit_image_size, - vit_patch_size, - token_compression_factor - ) + _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, + vit_patch_size, token_compression_factor) for image in images ] return image_token_counts @@ -916,11 +1038,15 @@ def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) image_ids = [int(image_id) for image_id in image_ids] - assert len(image_ids) == len(set(image_ids)), "Duplicate image tokens in prompt" - assert len(images) == len(image_ids), "Number of images and image tokens in prompt do not match" + assert len(image_ids) == len( + set(image_ids)), "Duplicate image tokens in prompt" + assert len(images) == len( + image_ids), "Number of images and image tokens in prompt do not match" # NOTE the following assertion is not strictly necessary - assert tuple(image_ids) == tuple(range(1, len(image_ids) + 1)), "Image ids are not in order" + assert tuple(image_ids) == tuple(range(1, + len(image_ids) + + 1)), "Image ids are not in order" image_token_counts = _count_image_tokens(images, ctx) image_id_to_input_ids = { @@ -930,9 +1056,8 @@ def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): return image_id_to_input_ids -def input_processor_for_phi4mm( - ctx: InputContext, inputs: DecoderOnlyInputs -) -> TokenInputs: +def input_processor_for_phi4mm(ctx: InputContext, + inputs: DecoderOnlyInputs) -> TokenInputs: """ Implements the input processor, which transforms the input prompt ids to include the audio placeholder token. This will become the `input_ids` @@ -940,23 +1065,25 @@ def input_processor_for_phi4mm( Args: ctx (InputContext): Input context. - inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) to process. + inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) + to process. Returns: TokenInputs: Processed inputs """ multi_modal_data = inputs.get("multi_modal_data") - if (multi_modal_data is None - or ("audio" not in multi_modal_data - and "image" not in multi_modal_data)): + if (multi_modal_data is None or + ("audio" not in multi_modal_data and "image" not in multi_modal_data)): # pure text input, so no need to do pre-processing return inputs prompt_str = inputs.get("prompt") prompt_token_ids = inputs.get("prompt_token_ids") - # for offline_inference, we will get str input and we parse MM special tokens from it + # for offline_inference, we will get str input and we parse MM special + # tokens from it # (ignore prompt_token_ids) - # for OAI server, we will get prompt_token_ids, where MM special tokens are already parsed + # for OAI server, we will get prompt_token_ids, where MM special tokens + # are already parsed if 'audio' in multi_modal_data: audios = multi_modal_data["audio"] @@ -964,7 +1091,8 @@ def input_processor_for_phi4mm( if not isinstance(audios, list): audios = [audios] if prompt_str is not None: - audio_id_to_input_ids = _get_audio_id_to_input_ids(audios, ctx, prompt_str=prompt_str) + audio_id_to_input_ids = _get_audio_id_to_input_ids( + audios, ctx, prompt_str=prompt_str) audio_embed_sizes = [] elif prompt_token_ids is not None: audio_id_to_input_ids = {} @@ -979,7 +1107,8 @@ def input_processor_for_phi4mm( if not isinstance(images, list): images = [images] if prompt_str is not None: - image_id_to_input_ids = _get_image_id_to_input_ids(images, prompt_str, ctx) + image_id_to_input_ids = _get_image_id_to_input_ids( + images, prompt_str, ctx) image_token_counts = [] elif prompt_token_ids is not None: image_id_to_input_ids = {} @@ -988,15 +1117,19 @@ def input_processor_for_phi4mm( image_id_to_input_ids = {} image_token_counts = [] - # Handle the case where the prompt is a string and we need to manually tokenize it. - # In this case, the `audio_id_to_input_ids` dict will be mapping from an audio placeholder - # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the given audio length. + # Handle the case where the prompt is a string and we need to manually + # tokenize it. + # In this case, the `audio_id_to_input_ids` dict will be mapping from + # an audio placeholder + # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the + # given audio length. if prompt_str: pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" prompt_chunk_strings = re.split(pattern, prompt_str) prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] - # Create the new input_ids with the placholder image and audio tokens inserted + # Create the new input_ids with the placeholder image and audio + # tokens inserted tokenizer = cached_tokenizer_from_config(ctx.model_config) input_ids = [] has_imag, has_audio, has_user_text_input = False, False, False @@ -1016,15 +1149,18 @@ def input_processor_for_phi4mm( break input_ids.extend(curr_token_ids) if has_audio and has_imag and has_user_text_input: - raise ValueError("Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") + raise ValueError( + "Phi4MMForCausalLM does not support text + audio + image" + + " inputs in the same prompt") # Handle the case where the prompt is already tokenized else: - assert prompt_token_ids is not None, "If string prompt isn't provided, prompt_token_ids must be" + assert prompt_token_ids is not None, \ + "If string prompt isn't provided, prompt_token_ids must be" i = 0 input_ids = prompt_token_ids - img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 # only needed for later assertion + # only needed for later assertion + img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 image_token_count_iter = iter(image_token_counts) audio_embed_size_iter = iter(audio_embed_sizes) while i < len(input_ids): @@ -1036,7 +1172,8 @@ def input_processor_for_phi4mm( token_count = next(image_token_count_iter) img_cnt += 1 else: - user_text_input_cnt += 1 if token_id not in NON_USER_INPUT_TOKENS else 0 + user_text_input_cnt += 1 if token_id not in \ + NON_USER_INPUT_TOKENS else 0 i += 1 continue tokens = [token_id] * token_count @@ -1044,23 +1181,18 @@ def input_processor_for_phi4mm( i += token_count if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: - raise ValueError("Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") + raise ValueError( + "Phi4MMForCausalLM does not support text + audio + image" + + " inputs in the same prompt") # If the below assertion fails, it might be that input pure-text # messages contain image/audio special tokens literally # (<|endoftext10|>, <|endoftext11|>). - assert ( - img_cnt == len(image_token_counts) - ), ( + assert (img_cnt == len(image_token_counts)), ( f"Number of image tokens in prompt_token_ids ({img_cnt}) " - f"does not match number of images ({len(image_token_counts)})" - ) - assert ( - audio_cnt == len(audio_embed_sizes) - ), ( + f"does not match number of images ({len(image_token_counts)})") + assert (audio_cnt == len(audio_embed_sizes)), ( f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " - f"does not match number of audios ({len(audio_embed_sizes)})" - ) + f"does not match number of audios ({len(audio_embed_sizes)})") # NOTE: Create a defensive copy of the original inputs return token_inputs( @@ -1072,9 +1204,11 @@ def input_processor_for_phi4mm( def _compute_audio_embed_size(hf_config, audio_frames): """ - Compute the audio embedding size based on the audio frames and compression rate. + Compute the audio embedding size based on the audio frames and + compression rate. """ - compression_rate = hf_config.embd_layer['audio_embd_layer']['compression_rate'] + compression_rate = hf_config.embd_layer['audio_embd_layer'][ + 'compression_rate'] # NOTE: this is a hard-coded value but might be configurable in the future qformer_compression_rate = 1 integer = audio_frames // compression_rate @@ -1092,6 +1226,7 @@ def _compute_audio_embed_size(hf_config, audio_frames): def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: return 10000 + def dummy_audio_for_phi4mm(audio_count: int) -> dict: """ Create dummy audio data for the Phi4MM model, which is used for profiling. @@ -1102,7 +1237,7 @@ def dummy_audio_for_phi4mm(audio_count: int) -> dict: Returns: dict: Dummy audio data. """ - dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE,), 0.0) + dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count @@ -1111,15 +1246,15 @@ def dummy_image_for_phi4mm(width: int, height: int): return image -def dummy_data_for_phi4mm( - ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -) -> DummyData: +def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]) -> DummyData: """ - Create dummy sequence (input_ids) and audio data for the Phi4MM model, which is used for - profiling. + Create dummy sequence (input_ids) and audio data for the Phi4MM model, + which is used for profiling. - In this case, the sequence data is a bunch of 0s with a number of audio tokens that correspond - to the audio embed size of the _AUDIO_MAX_SOUNDFILE_SIZE. + In this case, the sequence data is a bunch of 0s with a number of audio + tokens that correspond to the audio embed size of the + _AUDIO_MAX_SOUNDFILE_SIZE. Args: ctx (InputContext): Input context. @@ -1130,12 +1265,10 @@ def dummy_data_for_phi4mm( Tuple: Dummy sequence data and dummy audio data. """ audio_count = mm_counts["audio"] - audio_frames, _ = compute_logfbank_output_size( - _AUDIO_MAX_SOUNDFILE_SIZE, DUMMY_SAMPLING_FREQUENCY - ) - audio_feature_size = _compute_audio_embed_size( - ctx.get_hf_config(), audio_frames - ) + audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, + DUMMY_SAMPLING_FREQUENCY) + audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), + audio_frames) image_count = mm_counts["image"] dummy_image = get_max_dummy_image(ctx) @@ -1144,9 +1277,11 @@ def dummy_data_for_phi4mm( if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: raise RuntimeError( - f"Phi4MM cannot process {audio_count} audios and {image_count} images in a prompt," - f"please increase max_model_len to be at larger than {audio_feature_size * audio_count + total_image_tokens}" - " or reduce audio/image limit by --limit-mm-per-prompt.") + f"Phi4MM cannot process {audio_count} audios and {image_count}" + f"images in a prompt, please increase max_model_len to be at" + f" larger than " + f"{audio_feature_size * audio_count + total_image_tokens}" + " or reduce audio/image limit by --limit-mm-per-prompt.") if audio_feature_size * audio_count > total_image_tokens: seq_data = SequenceData.from_prompt_token_counts( @@ -1167,17 +1302,20 @@ def dummy_data_for_phi4mm( return DummyData(seq_data, mm_data) -def input_mapper_for_phi4mm_audio(ctx: InputContext, data: object) -> MultiModalInputs: +def input_mapper_for_phi4mm_audio(ctx: InputContext, + data: object) -> MultiModalInputs: """ - This function is used to create the MultiModalInputs for the Phi4MM (audio) model. - Specifically, for audio, we extract the audio features from the sound file and create - pairs of audio features and audio embed lengths (the latter of which is used to repeat - the audio placeholder token in the input prompt IDs). + This function is used to create the MultiModalInputs for the Phi4MM + (audio) model. + Specifically, for audio, we extract the audio features from the sound + file and create pairs of audio features and audio embed lengths (the + latter of which is used to repeat the audio placeholder token in the + input prompt IDs). These pairs are used, downstream, in `_audio_features_to_embeddings` (via `_process_audio_input`). - Note that the incoming audio data (each entry in `data`) is a tuple of the audio data - and the sampling frequency (e.g. from soundfile.read). + Note that the incoming audio data (each entry in `data`) is a tuple of + the audio data and the sampling frequency (e.g. from soundfile.read). Args: ctx (InputContext): Input context. @@ -1196,21 +1334,16 @@ def input_mapper_for_phi4mm_audio(ctx: InputContext, data: object) -> MultiModal for audio_input in data: if not isinstance(audio_input, tuple): raise NotImplementedError( - f"Unsupported data type: {type(audio_input)}" - ) + f"Unsupported data type: {type(audio_input)}") audio, sf = audio_input feature_extractor = audio_feature_extractor() single_audio_features = feature_extractor.extract_features(audio, sf) - feat_stride = ( - 1 - if not hasattr(feature_extractor, "stride") - else feature_extractor.stride - ) + feat_stride = (1 if not hasattr(feature_extractor, "stride") else + feature_extractor.stride) audio_frames = len(single_audio_features) * feat_stride single_audio_embed_size = _compute_audio_embed_size( - ctx.get_hf_config(), audio_frames - ) + ctx.get_hf_config(), audio_frames) single_audio_feature_audio_len_pair = ( single_audio_features, [single_audio_embed_size], @@ -1234,14 +1367,17 @@ def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): vit_image_size = prepro_config['vit_image_size'] vit_patch_size = prepro_config['vit_patch_size'] - image_input_dict = preprocess( - data, dynamic_hd_size, vit_image_size, vit_patch_size - ) + image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, + vit_patch_size) return MultiModalInputs({ - "pixel_values": image_input_dict["pixel_values"], - "image_sizes": image_input_dict["image_sizes"], - "image_attention_mask": image_input_dict["image_attention_mask"], - "num_img_tokens": image_input_dict["num_img_tokens"], + "pixel_values": + image_input_dict["pixel_values"], + "image_sizes": + image_input_dict["image_sizes"], + "image_attention_mask": + image_input_dict["image_attention_mask"], + "num_img_tokens": + image_input_dict["num_img_tokens"], }) @@ -1250,7 +1386,9 @@ def cat_with_pad(tensors, dim, padding_value=0): cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() - assert all(t.dim() == ndim for t in tensors[1:]), "All tensors must have the same number of dimensions" + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) @@ -1269,14 +1407,14 @@ def cat_with_pad(tensors, dim, padding_value=0): return output -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_phi4mm_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("image", input_mapper_for_phi4mm_image) +@MULTIMODAL_REGISTRY.register_input_mapper("audio", + input_mapper_for_phi4mm_audio) +@MULTIMODAL_REGISTRY.register_input_mapper("image", + input_mapper_for_phi4mm_image) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_phi4mm_audio_tokens -) + "audio", get_max_phi4mm_audio_tokens) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "image", get_max_phi4mm_image_tokens -) + "image", get_max_phi4mm_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): @@ -1292,7 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): "gate_up_proj", ], } - # QKVParallelLinear, RowParallelLinear, MergedColumnParallelLinear, RowParallelLinear supported_lora_modules = [ "qkv_proj", "o_proj", "gate_up_proj", "down_proj" ] @@ -1314,8 +1451,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config # Tensor/Pipeline parallel not supported for now. - assert get_tensor_model_parallel_world_size() == 1, "tensor parallel is not supported" - assert get_pp_group().world_size == 1, "pipeline parallel is not supported" + assert get_tensor_model_parallel_world_size( + ) == 1, "tensor parallel is not supported" + assert get_pp_group( + ).world_size == 1, "pipeline parallel is not supported" self.vision_encoder = Phi4MMImageEncoder( config, @@ -1323,12 +1462,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix="model.vision_embed_tokens", model_dir=config._name_or_path) - if isinstance(config.embd_layer["audio_embd_layer"], dict): embedding_config = { - "embedding_cls": config.embd_layer["audio_embd_layer"][ - "embedding_cls" - ], + "embedding_cls": + config.embd_layer["audio_embd_layer"]["embedding_cls"], **config.embd_layer["audio_embd_layer"], } else: @@ -1337,7 +1474,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): } self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) - self.model = LlamaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.model = LlamaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -1350,17 +1488,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config else lora_config.lora_vocab_padding_size), quant_config=quant_config, ) if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + config.vocab_size, logit_scale) self.sampler = Sampler() def _audio_features_to_embeddings( @@ -1371,40 +1506,45 @@ def _audio_features_to_embeddings( audio_projection_mode: str, ) -> torch.Tensor: """ - Convert audio features to embeddings, which are used as input to the model (via - `inputs_embeds`). + Convert audio features to embeddings, which are used as input to the + model (via `inputs_embeds`). Args: input_ids (torch.Tensor): Input IDs (the prompt in this case). - input_features (list[torch.Tensor]): Input features (the audio embeddings). - audio_input_sizes (list[torch.Tensor]): Audio input sizes (the audio embed lengths to use for - padding the audio placeholder token in the input prompt IDs). + input_features (list[torch.Tensor]): Input features (the audio + embeddings). + audio_input_sizes (list[torch.Tensor]): Audio input sizes (the + audio embed lengths to use for padding the audio placeholder token + in the input prompt IDs). """ - # The audio projection can either be a single linear or Sequential, so handle - # both cases - if isinstance(self.embed_tokens_extend.audio_projection, nn.Sequential): + # The audio projection can either be a single linear or Sequential, + # so handle both cases + if isinstance(self.embed_tokens_extend.audio_projection, + nn.Sequential): target_dtype = self.embed_tokens_extend.audio_projection[ - 0 - ].bias.dtype + 0].bias.dtype else: target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype audio_input = [ input.unsqueeze(0).to(target_dtype) for input in input_features ] - kwargs = {"wte": self.model.embed_tokens, 'audio_projection_mode': audio_projection_mode} - audio_embeddings = self.embed_tokens_extend( - input_ids, audio_input, audio_input_sizes, **kwargs - ) + kwargs = { + "wte": self.model.embed_tokens, + 'audio_projection_mode': audio_projection_mode + } + audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, + audio_input_sizes, + **kwargs) audio_embeddings = audio_embeddings.to(target_dtype) return audio_embeddings def _parse_and_validate_audio_input( - self, **kwargs: object - ) -> Optional[Phi4MMAudioInputs]: + self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: """ - Parse and validate the audio input to the model. This handles both audio features and - audio embeddings, but only the former is used for now. + Parse and validate the audio input to the model. This handles both + audio features and audio embeddings, but only the former is used for + now. Args: kwargs (object): Keyword arguments. @@ -1420,39 +1560,33 @@ def _parse_and_validate_audio_input( if audio_features is not None: if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of audio features. " - f"Got type: {type(audio_features)}" - ) + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(audio_features)}") - return Phi4MMAudioFeatureInputs( - type="audio_features", data=audio_features - ) + return Phi4MMAudioFeatureInputs(type="audio_features", + data=audio_features) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}" - ) + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") - return Phi4MMAudioEmbeddingInputs( - type="audio_embeds", data=audio_embeds - ) + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) raise AssertionError("This line should be unreachable.") - def _process_audio_input( - self, input_ids: torch.Tensor, audio_input: Phi4MMAudioInputs, audio_projection_mode: str - ) -> NestedTensors: + def _process_audio_input(self, input_ids: torch.Tensor, + audio_input: Phi4MMAudioInputs, + audio_projection_mode: str) -> NestedTensors: """ - Create the audio embeddings from the audio input, where the audio input is pairs of - audio features and audio embed lengths. The audio input is created by - `input_mapper_for_phi4mm_audio`. + Create the audio embeddings from the audio input, where the audio input + is pairs of audio features and audio embed lengths. The audio input is + created by `input_mapper_for_phi4mm_audio`. Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case, before the audio token - replication). + input_ids (torch.Tensor): Input IDs (the prompt in this case, + before the audio token replication). audio_input (Phi4MMAudioInputs): Audio input. Returns: @@ -1475,9 +1609,8 @@ def _process_audio_input( audio_projection_mode, ).squeeze(0) - def _parse_and_validate_image_input( - self, **kwargs: object - ) -> Optional[Dict]: + def _parse_and_validate_image_input(self, + **kwargs: object) -> Optional[Dict]: pixel_values: Optional[Dict] = kwargs.get("pixel_values") if pixel_values is None: return None @@ -1485,18 +1618,22 @@ def _parse_and_validate_image_input( image_sizes = kwargs.get("image_sizes") image_attention_mask = kwargs.get("image_attention_mask") num_img_tokens = kwargs.get("num_img_tokens") - assert image_sizes is not None and image_attention_mask is not None and num_img_tokens is not None, "Missing image inputs" + assert image_sizes is not None and image_attention_mask is not None\ + and num_img_tokens is not None, "Missing image inputs" if isinstance(pixel_values, list): assert pixel_values[0].dim() == 5, "Incorrect image inputs" - # list len is batch_size - # each tensor has dimension: num_img_per_example, num_hd_patches, channels, height, width - # need to pad along num_hd_patches - # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w + # list len is batch_size. + # each tensor has dimension: num_img_per_example, num_hd_patches, + # channels, height, width. + # need to pad along num_hd_patches. + # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. pixel_values = cat_with_pad(pixel_values, dim=0) elif isinstance(pixel_values, torch.Tensor): - # dimension: batch_size, num_img_per_example, num_hd_patches, channels, height, width - # we flatten first 2 dims to make it a single large batch for SigLIP Encoder + # dimension: batch_size, num_img_per_example, num_hd_patches, + # channels, height, width. + # we flatten first 2 dims to make it a single large batch for + # SigLIP Encoder. assert pixel_values.dim() == 6, "Incorrect image inputs" pixel_values = pixel_values.flatten(0, 1) else: @@ -1517,7 +1654,10 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect image_attention_mask inputs") if isinstance(num_img_tokens, list): - num_img_tokens = [n for num_tensor in num_img_tokens for n in num_tensor.tolist()] + num_img_tokens = [ + n for num_tensor in num_img_tokens + for n in num_tensor.tolist() + ] elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: @@ -1536,12 +1676,15 @@ def merge_image_features_to_inputs_embeds( inputs_embeds: torch.Tensor, image_set_tensors: List[torch.Tensor], ): - position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero(as_tuple=True) + position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( + as_tuple=True) - assert all([t.shape[0] == 1 for t in image_set_tensors]), 'img_set_tensor should have shape (1, N_tokens, C)' + assert all([t.shape[0] == 1 for t in image_set_tensors + ]), 'img_set_tensor should have shape (1, N_tokens, C)' # Shape: (merged_N_tokens, C) image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) - image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to(inputs_embeds.device) + image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( + inputs_embeds.device) merged_embeds = inputs_embeds.index_put( indices=position_tuple, values=image_set_tensor, @@ -1549,36 +1692,37 @@ def merge_image_features_to_inputs_embeds( ) return merged_embeds - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> None: weights = {name: weight for name, weight in weights} adjusted_weights = {} for name, weight in weights.items(): - # NOTE vision-speech tasks use a seperate projection layer - audio_proj_4v = "model.embed_tokens_extend.audio_embed.audio_projection.vision" + # NOTE vision-speech tasks use a separate projection layer + audio_proj_4v = \ + "model.embed_tokens_extend.audio_embed.audio_projection.vision" if name.startswith(audio_proj_4v): - name = name.replace(audio_proj_4v, "embed_tokens_extend.audio_projection_for_vision") - - name = ( - name.replace( - "model.embed_tokens_extend.audio_embed.audio_projection.speech.", - "embed_tokens_extend.audio_projection.", - ) - .replace( - "model.embed_tokens_extend.audio_embed.", - "embed_tokens_extend.", - ) - .replace("model.embed_tokens_extend.image_embed.", "vision_encoder.") - ) - # NOTE: this is deal with LoRA injection, where `base_layer` remains as the original - # layer in the model + name = name.replace( + audio_proj_4v, + "embed_tokens_extend.audio_projection_for_vision") + + name = (name.replace( + "model.embed_tokens_extend.audio_embed."\ + "audio_projection.speech.", + "embed_tokens_extend.audio_projection.", + ).replace( + "model.embed_tokens_extend.audio_embed.", + "embed_tokens_extend.", + ).replace("model.embed_tokens_extend.image_embed.", + "vision_encoder.")) + # NOTE: this is deal with LoRA injection, where `base_layer` + # remains as the original layer in the model if name.endswith(".base_layer.weight"): name = name.replace(".base_layer.weight", ".weight") adjusted_weights[name] = weight - missing_keys, unexpected_keys = self.load_state_dict( - adjusted_weights, strict=False - ) + missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, + strict=False) logger.debug("*** missing keys:") for key in missing_keys: logger.debug(key) @@ -1597,7 +1741,8 @@ def forward( input_ids = None inputs_embeds = None else: - # Each entry in this is a pair of audio_features and audio_embed lengths + # Each entry in this is a pair of audio_features and audio_embed + # lengths audio_input = self._parse_and_validate_audio_input(**kwargs) image_inputs = self._parse_and_validate_image_input(**kwargs) @@ -1607,26 +1752,25 @@ def forward( if has_audio: audio_projection_mode = 'vision' if has_image else 'speech' inputs_embeds = self._process_audio_input( - input_ids, audio_input, audio_projection_mode - ) + input_ids, audio_input, audio_projection_mode) if has_image: - dtype = self.vision_encoder.img_processor.embeddings.patch_embedding.weight.dtype + dtype = self.vision_encoder.img_processor.embeddings.\ + patch_embedding.weight.dtype pixel_values = image_inputs['pixel_values'].to(dtype) image_sizes = image_inputs['image_sizes'] image_attention_mask = image_inputs['image_attention_mask'] image_set_tensors = self.vision_encoder( - pixel_values, image_sizes, image_attention_mask - ) + pixel_values, image_sizes, image_attention_mask) if not has_audio: inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.merge_image_features_to_inputs_embeds( - input_ids, inputs_embeds, image_set_tensors - ) + input_ids, inputs_embeds, image_set_tensors) if has_image or has_audio: - # multi-modal input, we have set inputs_embeds properly in previous steps + # multi-modal input, we have set inputs_embeds properly in + # previous steps input_ids = None else: # text-only, we keep using original input_ids diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index c9a47ffb7e73..f9d4881c55e2 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -1,32 +1,34 @@ - +# SPDX-License-Identifier: Apache-2.0 # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. # Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) # but implemented by the Phi-Speech team #!/usr/bin/env python3 import abc -from functools import partial import math -import torch -import numpy as np +from functools import partial from typing import Callable, Dict, List, Literal, Optional, Union + +import numpy as np +import torch import torch.nn.functional as F -from torch import nn, Tensor -from transformers import PretrainedConfig -from torch.utils.checkpoint import checkpoint +from torch import Tensor, nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointWrapper, - checkpoint_wrapper, - offload_wrapper, - CheckpointImpl, -) -from vllm.model_executor.models.phi4mm_utils import AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias, adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper, get_offset, repeat, unfold_tensor, validate_checkpointing_config + CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper) from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel, -) + FullyShardedDataParallel) +from torch.utils.checkpoint import checkpoint +from transformers import PretrainedConfig + +from vllm.model_executor.models.phi4mm_utils import ( + AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, + MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias, + adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper, + get_offset, repeat, unfold_tensor, validate_checkpointing_config) _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + def encoder_checkpoint_wrapper( activation_checkpointing: Union[str, Dict], layer_cls: type, @@ -43,7 +45,8 @@ def encoder_checkpoint_wrapper( return lambda x: x if isinstance(activation_checkpointing, dict): - target_layer_cls = activation_checkpointing.get("module", "transformer") + target_layer_cls = activation_checkpointing.get( + "module", "transformer") if target_layer_cls.lower() == "transformer": target_layer_cls = ( "EncoderLayer", @@ -53,16 +56,11 @@ def encoder_checkpoint_wrapper( target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") checkpointing_interval = activation_checkpointing.get("interval", 1) offloading = activation_checkpointing.get("offload", False) - impl = ( - CheckpointImpl.REENTRANT - if activation_checkpointing.get("reentrant", True) - else CheckpointImpl.NO_REENTRANT - ) + impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get( + "reentrant", True) else CheckpointImpl.NO_REENTRANT) - if ( - idx % checkpointing_interval == 0 - and layer_cls.__name__ in target_layer_cls - ): + if (idx % checkpointing_interval == 0 + and layer_cls.__name__ in target_layer_cls): if offloading: return offload_wrapper return partial(checkpoint_wrapper, checkpoint_impl=impl) @@ -84,8 +82,9 @@ class ConformerEncoderLayer(nn.Module): if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel - will be used as a channel_out of the second conv1d layer. + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a + channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. depthwise_multiplier: int number of input_dim channels duplication. this value @@ -135,7 +134,7 @@ class ConformerEncoderLayer(nn.Module): if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. - attention_innner_dim: int, otional + attention_innner_dim: int, optional if equal to -1, attention dim for linears k/q/v is equal to d_model. otherwise attention_innner_dim is used. default -1. @@ -162,9 +161,11 @@ class ConformerEncoderLayer(nn.Module): and allow the onnx conversion for inference. default False. use_pt_scaled_dot_product_attention: bool, optional - if set to True, use pytorch's scaled dot product attention implementation in training. + if set to True, use pytorch's scaled dot product attention + implementation in training. attn_group_sizes: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), + the number of groups to use for attention, default 1 + (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attn_group_sizes < attention_heads = Grouped-Query Attention attn_group_sizes = attenion_heads = Multi-Query Attention @@ -210,18 +211,17 @@ def __init__( self.self_attn = encoder_checkpoint_wrapper( activation_checkpointing, MultiHeadedAttention, - )( - MultiHeadedAttention( - n_head, - d_model, - dropout_rate, - attention_innner_dim, - attention_glu_type, - bias_in_glu, - use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, - group_size=attn_group_sizes, - ) - ) + )(MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention= + use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + )) self.conv = ConvModule( d_model, ext_pw_out_channel, @@ -270,7 +270,8 @@ def forward( mask: torch.Tensor mask for x (batch, max_time_in) relative_attention_bias: Optional[torch.Tensor] - bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) """ x = x + 0.5 * self.feed_forward_in(x) norm_x = self.layer_norm_att(x) @@ -291,6 +292,7 @@ def forward( return out, pos_k, pos_v, mask + class TransformerEncoderBase(abc.ABC, nn.Module): """The Base class for Transformer based encoders @@ -338,10 +340,12 @@ class TransformerEncoderBase(abc.ABC, nn.Module): padding index for input_layer=embed default -1 relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention (Q*K^T + B) - implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see transformer_base.py) + additional method-specific arguments can be provided (see + transformer_base.py) positional_dropout_rate: float, optional dropout rate after positional encoding. default 0.0 nemo_conv_settings: dict, optional @@ -354,9 +358,11 @@ class TransformerEncoderBase(abc.ABC, nn.Module): supraframe utts in batch. Default: none attention_group_size: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), + the number of groups to use for attention, default 1 + (Multi-Head Attention), 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query Attention + 1 < attention_group_size < attention_heads = Grouped-Query + Attention attention_group_size = attenion_heads = Multi-Query Attention """ @@ -376,9 +382,8 @@ def __init__( relative_attention_bias_args=None, positional_dropout_rate=0.0, nemo_conv_settings=None, - conv2d_extra_padding: Literal[ - "feat", "feat_time", "none", True - ] = "none", + conv2d_extra_padding: Literal["feat", "feat_time", "none", + True] = "none", attention_group_size=1, encoder_embedding_config=None, ): @@ -413,33 +418,25 @@ def __init__( i not in nemo_conv_settings ), "{i} should be specified outside of the NeMo dictionary" - self.embed = NemoConvSubsampling( - **default_nemo_conv_settings, - ) + self.embed = NemoConvSubsampling(**default_nemo_conv_settings, ) else: raise ValueError("unknown input_layer: " + input_layer) - self.pos_emb = AbsolutePositionalEncoding( - attention_dim, positional_dropout_rate - ) + self.pos_emb = AbsolutePositionalEncoding(attention_dim, + positional_dropout_rate) self.relative_attention_bias_type = ( relative_attention_bias_args.get("type") - if relative_attention_bias_args - else None - ) + if relative_attention_bias_args else None) if self.relative_attention_bias_type == "t5": - assert ( - self.num_heads % self.attention_group_size == 0 - ), "attention_group_size must divide n_head" + assert (self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( self.num_heads // self.attention_group_size, max_distance=relative_attention_bias_args.get( - "t5_bias_max_distance", 1000 - ), + "t5_bias_max_distance", 1000), symmetric=relative_attention_bias_args.get( - "t5_bias_symmetric", False - ), + "t5_bias_symmetric", False), ) else: raise NotImplementedError @@ -447,12 +444,10 @@ def __init__( def post_init(self, init_model_config): pretrained_speech_encoder_path = init_model_config.get( - "pretrained_speech_encoder_path", None - ) + "pretrained_speech_encoder_path", None) if pretrained_speech_encoder_path: - model_state = torch.load( - pretrained_speech_encoder_path, map_location="cpu" - ) + model_state = torch.load(pretrained_speech_encoder_path, + map_location="cpu") encoder_state_dict = {} for k, v in model_state.items(): if "encoder." in k: @@ -465,43 +460,39 @@ def post_init(self, init_model_config): if not hasattr(self, "encoder_embedding"): self.encoder_embedding = MeanVarianceNormLayer( - self.encoder_embedding_config["input_size"] - ) + self.encoder_embedding_config["input_size"]) def compute_lens_change(self, feature_lens): """feature_lens: int return updated feature lens. - This used to return a different lambda function for each case that computed - the right thing. That does not work within Torchscript. If you really - need this to be faster, create nn.Module()-s for all the cases and return - one of them. Torchscript does support that. + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. + If you really need this to be faster, create nn.Module()-s for all + the cases and return one of them. Torchscript does support that. """ if self.input_layer == "nemo_conv": # Handle the special causal case subsampling_causal_cond = self.nemo_conv_settings.get( - "subsampling", "dw_striding" - ) in [ - "dw_striding", - "striding", - "striding_conv1d", - ] + "subsampling", "dw_striding") in [ + "dw_striding", + "striding", + "striding_conv1d", + ] is_causal = self.nemo_conv_settings.get("is_causal", False) if is_causal and subsampling_causal_cond: - lens_change = ( - torch.ceil(feature_lens / self.time_reduction).long() - if isinstance(feature_lens, Tensor) - else math.ceil(feature_lens / self.time_reduction) - ) + lens_change = (torch.ceil(feature_lens / + self.time_reduction).long() + if isinstance(feature_lens, Tensor) else + math.ceil(feature_lens / self.time_reduction)) feature_lens_remainder = feature_lens % self.time_reduction if isinstance(feature_lens, Tensor): lens_change[feature_lens_remainder != 1] += 1 elif feature_lens_remainder != 1: lens_change += 1 return lens_change - ceil_func = ( - math.ceil if isinstance(feature_lens, int) else torch.ceil - ) + ceil_func = (math.ceil + if isinstance(feature_lens, int) else torch.ceil) return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod @@ -518,16 +509,15 @@ def _chunk_size_selection(self, chunk_size=None, left_chunk=None): if isinstance(chunk_size, list): # Variable chunk size during training chunk_size_index = int( - torch.randint(low=0, high=len(chunk_size), size=(1,)) - ) + torch.randint(low=0, high=len(chunk_size), size=(1, ))) chunk_size_train_eff = chunk_size[chunk_size_index] if not isinstance(left_chunk, list): raise ValueError( - "Since chunk_size is a list, left_chunk must be a list" - ) + "Since chunk_size is a list, left_chunk must be a list") if len(left_chunk) != len(chunk_size): raise ValueError( - "The length of left_chunk must be the same as length of chunk_size." + "The length of left_chunk must be the same as length of "\ + "chunk_size." ) left_chunk_train_eff = left_chunk[chunk_size_index] else: @@ -558,14 +548,12 @@ def _position_embedding(self, input_tensor): pos_v = None if self.relative_attention_bias_layer is None: input_tensor = self.pos_emb( - input_tensor - ) # default to add abs sinusoid embedding + input_tensor) # default to add abs sinusoid embedding return pos_k, pos_v def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): - chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( - chunk_size, left_chunk - ) + chunk_size_train_eff, left_chunk_train_eff = \ + self._chunk_size_selection(chunk_size, left_chunk) # Create mask matrix for streaming # S stores start index. if chunksize is 18, s is [0,18,36,....] @@ -579,18 +567,17 @@ def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): chunk_start_idx = chunk_start_idx[:-1] chunk_start_idx = np.insert(chunk_start_idx, 0, 0) - enc_streaming_mask = ( - adaptive_enc_mask( - seq_len, chunk_start_idx, left_window=left_chunk_train_eff - ) - .unsqueeze(0) - .expand([batch_size, -1, -1]) - ) + enc_streaming_mask = (adaptive_enc_mask( + seq_len, chunk_start_idx, + left_window=left_chunk_train_eff).unsqueeze(0).expand( + [batch_size, -1, -1])) return enc_streaming_mask - def forward_embeddings( - self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None - ): + def forward_embeddings(self, + xs_pad, + masks, + chunk_size_nc=None, + left_chunk_nc=None): """Forwarding the inputs through the top embedding layers Args: @@ -598,31 +585,34 @@ def forward_embeddings( input tensor masks: torch.Tensor input mask - chunk_size_nc: (optional, default is None) chunk size for non-causal layers - left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers + chunk_size_nc: (optional, default is None) chunk size for + non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for + non-causal layers """ # pylint: disable=R0915 # get new lens. seq_len = int(self.compute_lens_change(xs_pad.shape[1])) if seq_len <= 0: raise ValueError( - f"""The squence length after time reduction is invalid: {seq_len}. - Your input feature is too short. Consider filtering out the very - short sentence from data loader""", - ) + f"""The sequence length after time reduction is invalid: + {seq_len}. Your input feature is too short. Consider + filtering out the very short sentence from data + loader""", ) batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask( - seq_len, batch_size, self.chunk_size, self.left_chunk - ) + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, + self.chunk_size, + self.left_chunk) if xs_pad.is_cuda: enc_streaming_mask = enc_streaming_mask.cuda() xs_pad = xs_pad.cuda() input_tensor = xs_pad - input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + input_tensor, masks = self._forward_embeddings_core( + input_tensor, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: @@ -634,8 +624,7 @@ def forward_embeddings( if chunk_size_nc is not None: enc_streaming_mask_nc = self._streaming_mask( - seq_len, batch_size, chunk_size_nc, left_chunk_nc - ) + seq_len, batch_size, chunk_size_nc, left_chunk_nc) if xs_pad.is_cuda: enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() if masks is not None: @@ -660,6 +649,7 @@ def get_offset(self): """ return get_offset(self.input_layer, self.time_reduction) + class ConformerEncoder(TransformerEncoderBase): """ConformerEncoder module. see original paper for more details: @@ -689,8 +679,9 @@ class ConformerEncoder(TransformerEncoderBase): left_chunk: int number of chunks used for masking in streaming mode. num_lang: int - This parameter is used to store the number of languages in the lang_dict, - only used for multiseed/multilingual models. default None. + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. + default None. attention_dim: int, optional attention dimension. default 256. attention_heads: int, optional @@ -752,7 +743,7 @@ class ConformerEncoder(TransformerEncoderBase): conv_activation: str, optional activation function used in ConvModule part of the conformer, default "relu". - conv_glu_type: str, otional + conv_glu_type: str, optional activation used use glu in depthwise_seperable_CNN, default "sigmoid" bias_in_glu: bool, optional @@ -787,15 +778,17 @@ class ConformerEncoder(TransformerEncoderBase): extra_layer_output_idx: int the layer index to be exposed. relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention (Q*K^T + B) - implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see transformer_base.py) + additional method-specific arguments can be provided (see + transformer_base.py) time_reduction: int optional time reduction factor default 4 - use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention - in training. + use_pt_scaled_dot_product_attention: whether to use pytorch scaled + dot product attention in training. Default: False nemo_conv_settings: dict, optional A dictionary of settings for NeMo Subsampling. @@ -803,7 +796,7 @@ class ConformerEncoder(TransformerEncoderBase): usage: nemo_conv_settings= { "subsampling": - dw_striding/striding/dw_striding_conv1d/striding_conv1d, + dw_striding/striding/dw_striding_conv1d/striding_conv1d, "conv_channels": int, "subsampling_conv_chunking_factor": int, "is_causal": True/False @@ -812,13 +805,16 @@ class ConformerEncoder(TransformerEncoderBase): Add extra padding in conv2d subsampling layers. Choices are (feat, feat_time, none, True) Default: none - replication_pad_for_subsample_embedding: For batched-streaming decoding, use - "replication" padding for the cache at start of utterance. - Default: False + replication_pad_for_subsample_embedding: For batched-streaming + decoding, use "replication" padding for the cache at start of + utterance. + Default: False attention_group_size: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), + the number of groups to use for attention, default 1 + (Multi-Head Attention), 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query Attention + 1 < attention_group_size < attention_heads = Grouped-Query + Attention attention_group_size = attenion_heads = Multi-Query Attention """ @@ -854,15 +850,14 @@ def __init__( # pylint: disable-all attention_glu_type="swish", export=False, extra_layer_output_idx=-1, - extra_multi_layer_output_idxs=[], + extra_multi_layer_output_idxs=[], # noqa activation_checkpointing="", relative_attention_bias_args=None, time_reduction=4, use_pt_scaled_dot_product_attention=False, nemo_conv_settings=None, - conv2d_extra_padding: Literal[ - "feat", "feat_time", "none", True - ] = "none", + conv2d_extra_padding: Literal["feat", "feat_time", "none", + True] = "none", replication_pad_for_subsample_embedding=False, attention_group_size=1, encoder_embedding_config=None, @@ -889,49 +884,45 @@ def __init__( # pylint: disable-all self.num_lang = num_lang self.kernel_size = kernel_size self.embed = embedding_checkpoint_wrapper(activation_checkpointing)( - self.embed - ) + self.embed) self.replication_pad_for_subsample_embedding: bool = ( - replication_pad_for_subsample_embedding - ) - assert ( - self.num_heads % attention_group_size == 0 - ), "attention_group_size must divide n_head" + replication_pad_for_subsample_embedding) + assert (self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" self.num_heads_k = self.num_heads // attention_group_size self.encoders = repeat( num_blocks, - lambda i: encoder_checkpoint_wrapper( - activation_checkpointing, ConformerEncoderLayer, i - )( - ConformerEncoderLayer( - d_model=attention_dim, - ext_pw_out_channel=ext_pw_out_channel, - depthwise_seperable_out_channel=depthwise_seperable_out_channel, - depthwise_multiplier=depthwise_multiplier, - n_head=attention_heads, - d_ffn=linear_units, - ext_pw_kernel_size=ext_pw_kernel_size, - kernel_size=kernel_size, - dropout_rate=dropout_rate, - causal=causal, - batch_norm=batch_norm, - activation=activation, - chunk_se=chunk_se, - chunk_size=chunk_size, - conv_activation=conv_activation, - conv_glu_type=conv_glu_type, - bias_in_glu=bias_in_glu, - linear_glu_in_convm=linear_glu_in_convm, - attention_glu_type=attention_glu_type, - activation_checkpointing=attn_checkpointing( - activation_checkpointing, i - ), - export=export, - use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, - attn_group_sizes=attention_group_size, - ) - ), + lambda i: encoder_checkpoint_wrapper(activation_checkpointing, + ConformerEncoderLayer, i) + (ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel= + depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=attn_checkpointing( + activation_checkpointing, i), + export=export, + use_pt_scaled_dot_product_attention= + use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + )), ) self.extra_layer_output_idx = extra_layer_output_idx self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs @@ -946,19 +937,19 @@ def init_relative_attention_bias(self, input_tensor): def calculate_hs_mask(self, xs_pad, device, mask): max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask( - max_audio_length, batch_size, self.chunk_size, self.left_chunk - ) + enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, + self.chunk_size, + self.left_chunk) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens - pad_mask = ( - torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) - < padding_length.unsqueeze(1) - ) + pad_mask = (torch.arange(0, max_audio_length, + device=device).expand(padding_length.size(0), + -1) + < padding_length.unsqueeze(1)) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask @@ -975,50 +966,59 @@ def forward(self, xs_pad, masks): """ xs_pad = self.encoder_embedding(xs_pad) input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( - xs_pad, masks - ) + xs_pad, masks) unfolded = False ori_bz, seq_len, D = input_tensor.shape - max_seq_len = 500 #maxium position for absolute positional encoding + max_seq_len = 500 #maximum position for absolute positional encoding if seq_len > max_seq_len: - # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len + # audio sequence is longer than max_seq_len, unfold it into chunks + # of max_seq_len unfolded = True - # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + # the unfold op will drop residual frames, pad it to the multiple + # of max_seq_len if seq_len % max_seq_len > 0: chunk_pad_size = max_seq_len - (seq_len % max_seq_len) else: chunk_pad_size = 0 if chunk_pad_size > 0: - input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) + input_tensor_pad = F.pad(input_tensor, + (0, 0, 0, chunk_pad_size), "constant", + 0) input_tensor = input_tensor_pad.to(input_tensor.device) input_tensor = unfold_tensor(input_tensor, max_seq_len) if masks is not None: - # revise hs_mask here because the previous calculated hs_mask did not consider extra pad - subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] - extra_padded_subsamlped_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() - masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len) # unfold the pad mask like we did to the input tensor - masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor + # revise hs_mask here because the previous calculated hs_mask + # did not consider extra pad + subsampled_pad_mask = masks.squeeze( + 1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", + False) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = \ + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze( + -1).bool() # unfold op does not support bool tensor else: masks_unfold = None - hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) # calculate hs_mask based on the unfolded pad mask + hs_mask = self.calculate_hs_mask( + input_tensor, input_tensor.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask - layer_emb = None + # layer_emb = None relative_attention_bias = self.init_relative_attention_bias( - input_tensor - ) + input_tensor) - _simplified_path = ( - self.extra_layer_output_idx == -1 - and relative_attention_bias is None - ) + _simplified_path = (self.extra_layer_output_idx == -1 + and relative_attention_bias is None) if _simplified_path: - input_tensor, *_ = self.encoders( - input_tensor, pos_k, pos_v, hs_mask - ) + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, + hs_mask) else: for i, layer in enumerate(self.encoders): input_tensor, _, _, _ = layer( @@ -1029,8 +1029,8 @@ def forward(self, xs_pad, masks): relative_attention_bias=relative_attention_bias, ) - if i == self.extra_layer_output_idx: - layer_emb = input_tensor + # if i == self.extra_layer_output_idx: + # layer_emb = input_tensor if unfolded: embed_dim = input_tensor.shape[-1] @@ -1044,6 +1044,7 @@ def forward(self, xs_pad, masks): def gradient_checkpointing_enable(self): pass + class WindowQformer(nn.Module): """Window-level Qformer""" @@ -1060,25 +1061,21 @@ def __init__( ): super().__init__() - self.decoders = nn.ModuleList( - [ - nn.TransformerDecoderLayer( - d_model=attention_dim, - nhead=attention_heads, - dim_feedforward=linear_units, - dropout=dropout_rate, - activation="relu", - batch_first=True, - norm_first=normalize_before, # TODO need to verify - ) - for _ in range(num_blocks) - ] - ) + self.decoders = nn.ModuleList([ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) for _ in range(num_blocks) + ]) self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) - self.after_norm = ( - nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None - ) + self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12) + if normalize_before else None) self.window_size = window_size self.gradient_checkpointing_enable = False @@ -1096,9 +1093,8 @@ def forward(self, audio_embed, mask, embed_len=None): # audio_embed: N x D x 1 x T => N x DK x T' padding = audio_embed.shape[-1] % self.window_size if padding > 0: - audio_embed = F.pad( - audio_embed, (0, self.window_size - padding), "constant", 0 - ) + audio_embed = F.pad(audio_embed, (0, self.window_size - padding), + "constant", 0) embed_chunk = F.unfold( audio_embed[..., None, :], @@ -1125,9 +1121,10 @@ def forward(self, audio_embed, mask, embed_len=None): use_reentrant=True, ) else: - q = layer( - tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask - ) + q = layer(tgt=q, + memory=embed_chunk, + tgt_mask=None, + memory_mask=mask) if self.after_norm is not None: q = self.after_norm(q) @@ -1139,6 +1136,7 @@ def forward(self, audio_embed, mask, embed_len=None): return out, embed_len + class AudioEmbedding(nn.Module): """Image embedding.""" @@ -1146,16 +1144,12 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: super().__init__() self.config = config # n_embed or hidden_size for text LM - hidden_size = ( - config.n_embd if hasattr(config, "n_embd") else config.hidden_size - ) + hidden_size = (config.n_embd + if hasattr(config, "n_embd") else config.hidden_size) if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): - embd_drop = ( - config.embd_pdrop - if hasattr(config, "embd_pdrop") - else config.embed_pdrop - ) + embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop") + else config.embed_pdrop) self.drop = nn.Dropout(embd_drop) else: self.drop = None @@ -1167,49 +1161,30 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: ) self.layer_idx = -2 - # if isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == 'whisper': - # model_path = config.audio_processor.get('pretrained_model_path', None) - # whisper_model = WhisperModel.from_pretrained(model_path) - - # self.encoder = whisper_model.encoder - # n_mels = self.encoder.num_mel_bins - # audio_dim_out = self.encoder.layers[0].embed_dim - # elif isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == "w2vbert2": - # audio_processor_path = config.audio_processor.get("model_path", "facebook/w2v-bert-2.0") - # self.encoder = Wav2Vec2BertModel.from_pretrained(audio_processor_path) - # audio_dim_out = self.encoder.config.hidden_size - # self.layer_idx = config.audio_processor.get("layer", 18) - # self.encoder.config.apply_spec_augment = False - # self.encoder.config.mask_time_prob = 0 - # self.encoder.config.output_hidden_states = True - # n_mels = 160 - if ( - isinstance(config.audio_processor, dict) - and config.audio_processor.get("name", None) == "cascades" - ): + if (isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades"): encoder_config = config.audio_processor.get("config", None) assert encoder_config is not None self.encoder = ConformerEncoder(**encoder_config) # fake initialization, create encoder_embedding layer only so that - # in decoding, all parameters can be loaded in from_pretrained_function - # in training, we do post init after from_pretrained function to make sure the correct initialization + # in decoding, all parameters can be loaded in + # from_pretrained_function in training, we do post init after + # from_pretrained function to make sure the correct initialization self.encoder.post_init({}) audio_dim_out = encoder_config["attention_dim"] n_mels = encoder_config["input_size"] else: - raise NotImplementedError(f"") + raise NotImplementedError("") - assert ( - audio_dim_out is not None - ), "Remember to set values for audio_dim_out" + assert (audio_dim_out + is not None), "Remember to set values for audio_dim_out" self.audio_dim_out = audio_dim_out self.audio_dim_in = n_mels - self.freeze_audio_processor = kwargs.get( - "freeze_audio_processor", False - ) + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", + False) self.downsample_rate = kwargs.get("downsample_rate", 1) @@ -1221,9 +1196,8 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: self.qformer = None if kwargs.get("use_conv_downsample", False): - assert ( - self.qformer is None - ), "don't support use qformer and conv downsample together" + assert (self.qformer is None + ), "don't support use qformer and conv downsample together" nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) default_nemo_conv_settings = { "subsampling": "dw_striding", @@ -1243,15 +1217,12 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: i not in nemo_conv_settings ), "{i} should be specified outside of the NeMo dictionary" - self.conv_ds = NemoConvSubsampling( - **default_nemo_conv_settings, - ) + self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, ) else: self.conv_ds = None enable_gradient_checkpointing = kwargs.get( - "enable_gradient_checkpointing", False - ) + "enable_gradient_checkpointing", False) if enable_gradient_checkpointing: self.encoder.gradient_checkpointing_enable() @@ -1266,34 +1237,30 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: # (do not use image_projection and image_proj_norm) dim_projection = hidden_size depth = 2 - self.linear_downsample_rate = ( - 1 if (self.qformer or self.conv_ds) else self.downsample_rate - ) + self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds) + else self.downsample_rate) layers = [ - nn.Linear( - audio_dim_out * self.linear_downsample_rate, dim_projection - ) + nn.Linear(audio_dim_out * self.linear_downsample_rate, + dim_projection) ] for _ in range(1, depth): layers.extend( - [nn.GELU(), nn.Linear(dim_projection, dim_projection)] - ) + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) self.audio_projection = nn.Sequential(*layers) - # NOTE vision-speech tasks use a seperate projection layer + # NOTE vision-speech tasks use a separate projection layer layers = [ - nn.Linear( - audio_dim_out * self.linear_downsample_rate, dim_projection - ) + nn.Linear(audio_dim_out * self.linear_downsample_rate, + dim_projection) ] for _ in range(1, depth): layers.extend( - [nn.GELU(), nn.Linear(dim_projection, dim_projection)] - ) + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) self.audio_projection_for_vision = nn.Sequential(*layers) else: raise NotImplementedError( - f"projection_cls = {projection_cls}, not implemented" - ) + f"projection_cls = {projection_cls}, not implemented") # TODO: audio sequence compression - Qformer self.vocab_size = config.vocab_size @@ -1303,9 +1270,8 @@ def __init__(self, config: PretrainedConfig, **kwargs) -> None: def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: self.input_embeds = input_embeds - def set_audio_embed_sizes( - self, audio_embed_sizes: torch.LongTensor - ) -> None: + def set_audio_embed_sizes(self, + audio_embed_sizes: torch.LongTensor) -> None: self.audio_embed_sizes = audio_embed_sizes def get_audio_features( @@ -1317,13 +1283,11 @@ def get_audio_features( if self.freeze_audio_processor: with torch.no_grad(): - audio_features, masks = self.encoder( - input_embeds, audio_attention_mask - ) + audio_features, masks = self.encoder(input_embeds, + audio_attention_mask) else: - audio_features, masks = self.encoder( - input_embeds, audio_attention_mask - ) + audio_features, masks = self.encoder(input_embeds, + audio_attention_mask) if self.qformer is not None: audio_features, _ = self.qformer(audio_features, mask=None) @@ -1357,7 +1321,10 @@ def get_audio_features( elif audio_projection_mode == 'vision': audio_set_tensor = self.audio_projection_for_vision(audio_features) else: - raise ValueError(f"audio_projection_mode = {audio_projection_mode} not implemented") + raise ValueError( + f"audio_projection_mode = {audio_projection_mode} not "\ + "implemented" + ) return audio_set_tensor @@ -1374,21 +1341,22 @@ def forward( input_embeds: audio features (B, T, D) B: num audios in a sequence """ assert input_embeds is not None and len(input_embeds) == len( - audio_embed_sizes - ) + audio_embed_sizes) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) with torch.no_grad(): - positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(as_tuple=False) + positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero( + as_tuple=False) if not isinstance(input_embeds, list): input_embeds = [input_embeds] audio_projection_mode = kwargs.get("audio_projection_mode", "speech") audio_set_tensor = [ - self.get_audio_features(input_embed, audio_projection_mode=audio_projection_mode) + self.get_audio_features( + input_embed, audio_projection_mode=audio_projection_mode) for input_embed in input_embeds ] @@ -1396,38 +1364,39 @@ def forward( input_ids.clamp_min_(0).clamp_max_(self.vocab_size) if "wte" in kwargs: - # we use the token embedding layer from the huggingface model, this is REQUIRED to make sure we are using the loaded weights. + # we use the token embedding layer from the huggingface model, this + # is REQUIRED to make sure we are using the loaded weights. hidden_states = kwargs["wte"](input_ids) else: - # otherwise, we use token embedding in pretrained mixformer from phi team + # otherwise, we use token embedding in pretrained mixformer from + # phi team hidden_states = self.wte(input_ids) if len(positions.tolist()) > 0: assert sum(audio_embed_sizes) == len( positions - ), "please ensure the encoder outputs have the same length as defined in input_ids!" + ), "please ensure the encoder outputs have the same length as"\ + " defined in input_ids!" idx = 0 for i in range(len(audio_embed_sizes)): cnt = audio_embed_sizes[i] assert audio_set_tensor[i].shape[0] == 1 hidden_states[ positions[idx, 0], - positions[idx, 1] : positions[idx, 1] + cnt, - ] = ( - audio_set_tensor[i][0, : audio_embed_sizes[i], :] - .to(hidden_states.dtype) - .to(hidden_states.device) - ) + positions[idx, 1]:positions[idx, 1] + cnt, + ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to( + hidden_states.dtype).to(hidden_states.device)) idx += cnt else: if self.training: - # hidden_states[:, 0:img_set_tensor.shape[0]] = hidden_states[:, 0:img_set_tensor.shape[0]] + 0 * img_set_tensor.to(hidden_states.dtype).to(hidden_states.device) - hidden_states[:, 0:1] = hidden_states[ - :, 0:1 - ] + 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype).to( - hidden_states.device - ) + # hidden_states[:, 0:img_set_tensor.shape[0]] = + # hidden_states[:, 0:img_set_tensor.shape[0]] + + # 0 * img_set_tensor.to(hidden_states.dtype) + # .to(hidden_states.device) + hidden_states[:, 0:1] = hidden_states[:, 0:1] + \ + 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\ + .to(hidden_states.device) if self.drop is not None: hidden_states = self.drop(hidden_states) diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index 787e4508419d..16b62c60836e 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -1,20 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. # Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) # but implemented by the Phi-Speech team #!/usr/bin/env python3 -from functools import partial import math -from typing import Optional, Tuple, Union, Union, Dict, Callable +from functools import partial +from typing import Callable, Dict, Optional, Tuple, Union import torch import torch.nn.functional as F -from torch import nn, Tensor +from torch import Tensor, nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - offload_wrapper, - CheckpointImpl, -) + CheckpointImpl, checkpoint_wrapper, offload_wrapper) + class Block(nn.Module): """Block abstract module""" @@ -51,9 +50,11 @@ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): The function is very important for Transformer Transducer Streaming mode Args: xs_len (int): sequence length - chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + It also supports adaptive chunk size [0,10,15,45] left_window (int): how many left chunks can be seen - right_window (int): how many right chunks can be seen. It is used for chunk overlap model. + right_window (int): how many right chunks can be seen. It is used for + chunk overlap model. Returns: mask (torch.Tensor): a mask tensor for streaming model Torch 1.0.1 @@ -65,25 +66,21 @@ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): [False., True., True., False.], [False., False., True., True.]]) """ - chunk_start_idx = torch.Tensor( - chunk_start_idx - ).long() # first idx of each chunk, such as [0,18,36,48]. + chunk_start_idx = torch.Tensor(chunk_start_idx).long( + ) # first idx of each chunk, such as [0,18,36,48]. start_pad = torch.nn.functional.pad( - chunk_start_idx, (1, 0) - ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + chunk_start_idx, + (1, 0)) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] end_pad = torch.nn.functional.pad( chunk_start_idx, (0, 1), value=x_len ) # append x_len to the end, so it becomes [0,18,36,48, x_len] - seq_range = torch.arange(0, x_len).unsqueeze( - -1 - ) # seq_range size: [x_len, 1] - idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[ - :, 1 - ] # idx size: [x_len] - boundary = end_pad[idx] # boundary size: [x_len] - seq_range_expand = ( - torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) - ) # seq_range_expand size [x_len, x_len] + seq_range = torch.arange(0, + x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & + (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + # boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] @@ -227,27 +224,25 @@ def forward(self, x): x: torch.Tensor input tensor """ - # to be consistent with GLULinear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + # to be consistent with GLULinear, we assume the input always has the + # #channel (#dim) in the last dimension of the tensor, so need to + # switch the dimension first for 1D-Conv case x = x.permute([0, 2, 1]) x = self.ext_pw_conv_1d(x) if self.glu_type == "bilinear": if self.bias_in_glu: - x = (x[:, 0 : self.output_dim, :] + self.b1) * ( - x[:, self.output_dim : self.output_dim * 2, :] + self.b2 - ) + x = (x[:, 0:self.output_dim, :] + self.b1) * ( + x[:, self.output_dim:self.output_dim * 2, :] + self.b2) else: - x = (x[:, 0 : self.output_dim, :]) * ( - x[:, self.output_dim : self.output_dim * 2, :] - ) + x = (x[:, 0:self.output_dim, :]) * ( + x[:, self.output_dim:self.output_dim * 2, :]) else: if self.bias_in_glu: - x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( - x[:, self.output_dim : self.output_dim * 2, :] + self.b2 - ) + x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim:self.output_dim * 2, :] + self.b2) else: - x = (x[:, 0 : self.output_dim, :]) * self.glu_act( - x[:, self.output_dim : self.output_dim * 2, :] - ) + x = (x[:, 0:self.output_dim, :]) * self.glu_act( + x[:, self.output_dim:self.output_dim * 2, :]) x = x.permute([0, 2, 1]) return x @@ -262,8 +257,9 @@ class DepthWiseSeperableConv1d(nn.Module): input_dim: int input channel size. depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel - will be used as a channel_out of the second conv1d layer. + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a channel_out + of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. kernel_size: int kernel_size @@ -332,7 +328,8 @@ class ConvModule(nn.Module): if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. ext_pw_kernel_size: int @@ -421,12 +418,7 @@ def __init__( self.export = export if causal: - if export: # Inference only. - padding = 0 # A cache is concatenated to the left. No padding in the kernel. - else: - # Training only. Padding will be added symmetrically on both sides. - # After convolution, clip off kernel_size-1 points on the right. - padding = kernel_size - 1 + padding = 0 if export else kernel_size - 1 else: padding = (kernel_size - 1) // 2 @@ -440,12 +432,12 @@ def __init__( if depthwise_seperable_out_channel != 0: if input_dim != depthwise_seperable_out_channel: - self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) + self.ln2 = nn.Linear(depthwise_seperable_out_channel, + input_dim) else: if depthwise_multiplier != 1: - self.ln2 = nn.Linear( - input_dim * depthwise_multiplier, input_dim - ) + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, + input_dim) def _add_ext_pw_layer(self): """ @@ -454,8 +446,7 @@ def _add_ext_pw_layer(self): of the conformer. """ self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( - nn.Identity() - ) # jit hacks. + nn.Identity()) # jit hacks. self.squeeze_excitation = nn.Identity() # jit. self.apply_ln1 = self.fix_len1 = False # jit. @@ -520,7 +511,7 @@ def forward(self, x): if self.ext_pw_out_channel != 0: x = self.glu(x) if self.causal and self.ext_pw_kernel_size > 1: - x = x[:, : -(self.ext_pw_kernel_size - 1), :] + x = x[:, :-(self.ext_pw_kernel_size - 1), :] if self.apply_ln1: x = self.ln1(x) else: @@ -532,7 +523,7 @@ def forward(self, x): x = self.dw_sep_conv_1d(x) if self.causal and self.kernel_size > 1: - x = x[:, :, : -(self.kernel_size - 1)] + x = x[:, :, :-(self.kernel_size - 1)] if hasattr(self, "ln2"): x = x.permute([0, 2, 1]) x = self.ln2(x) @@ -544,7 +535,7 @@ def forward(self, x): if self.ext_pw_out_channel != 0: x = self.ext_pw_conv_1d(x) if self.fix_len1: - x = x[:, :, : -(self.ext_pw_kernel_size - 1)] + x = x[:, :, :-(self.ext_pw_kernel_size - 1)] if self.apply_ln1: x = x.permute([0, 2, 1]) @@ -665,7 +656,8 @@ def _pre_hook( Note: We saved self.pe until v.0.5.2 but we have omitted it later. - Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + Therefore, we remove the item "pe" from `state_dict` for backward + compatibility. """ k = prefix + "pe" @@ -675,43 +667,53 @@ def _pre_hook( class T5RelativeAttentionLogitBias(nn.Module): """ - This module implements the relative position bias described in Section 2.1 of - the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + This module implements the relative position bias described in Section + 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf The Huggingface implementation is used as a reference - https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435 + https://github.com/huggingface/transformers/blob/v4.30.0/src/ + transformers/models/t5/modeling_t5.py#L435 - Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position - of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. + Modifies attention as Q*K^T + B, where B is a learned scalar bias based + on relative position of the query and key. It is HxNxN, where H is the + number of heads, N is the sequence length. I've made these modifications to the original T5 bias: - - Skipping of the bucketing step. Original T5 bias converted rel position distances into - logarithmically increasing buckets. This is supposed to help with length generalization. - - I just directly use rel position index as bias values, as we don't need length - generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. - - I've also extended it so that biases can be asymmetric, the default implementation treats - L->R and R->L the same. Asymmetric was found to yield better results in my experiments. + - Skipping of the bucketing step. Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is + supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), + and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to + yield better results in my experiments. Args: num_heads: int Number of attention heads num_buckets: int - Number of buckets to use for relative attention bias. This is the size of the learnable - bias parameter. Bucketing is not yet supported, so this defaults to -1 which means - no bucketing is used (max_distance determines size of bias param). + Number of buckets to use for relative attention bias. This is the + size of the learnable bias parameter. Bucketing is not yet + supported, so this defaults to -1 which means no bucketing is + used (max_distance determines size of bias param). max_distance: int - Maximum distance to use for relative attention bias. With num_buckets=-1, this directly - controls the max size of the bias parameter. When num_buckets > 0 is supported, this - will control the maximum distance for logarithmic bucketing after which all positions - are in the same bucket. + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all + positions are in the same bucket. symmetric: bool - Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias - params to distinguish L->R from R->L. This was found to be better for the encoder. + Whether to use symmetric or asymmetric biases. symmetric=False uses + 2x number of bias params to distinguish L->R from R->L. This was + found to be better for the encoder. """ - def __init__( - self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False - ): + def __init__(self, + num_heads, + num_buckets=-1, + max_distance=1000, + symmetric=False): super().__init__() self.num_heads = num_heads self.num_buckets = num_buckets @@ -722,8 +724,7 @@ def __init__( self.num_buckets = max_distance else: raise NotImplementedError( - "T5 attention bias with bucketed positions is not yet tested" - ) + "T5 attention bias with bucketed positions is not yet tested") if not self.symmetric: self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) @@ -731,20 +732,19 @@ def __init__( def forward(self, x): # instantiate bias compatible with shape of x maxpos = x.size(1) - context_position = torch.arange( - maxpos, device=x.device, dtype=torch.long - )[:, None] - memory_position = torch.arange( - maxpos, device=x.device, dtype=torch.long - )[None, :] + context_position = torch.arange(maxpos, + device=x.device, + dtype=torch.long)[:, None] + memory_position = torch.arange(maxpos, + device=x.device, + dtype=torch.long)[None, :] relative_position = memory_position - context_position - # clipping to a maximum distance using ops that play well with ONNX export + # clipping to a maximum distance using ops that play well with ONNX + # export relative_position = relative_position.masked_fill( - relative_position < -self.max_distance, -self.max_distance - ) + relative_position < -self.max_distance, -self.max_distance) relative_position = relative_position.masked_fill( - relative_position > self.max_distance - 1, self.max_distance - 1 - ) + relative_position > self.max_distance - 1, self.max_distance - 1) # mapping from relative position to index in the bias parameter if self._skip_bucketing: @@ -758,45 +758,42 @@ def forward(self, x): t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze( - 0 - ) # [1, H, L, L] + 0) # [1, H, L, L] return t5_rel_att_bias def _bucket_relative_position(self, relative_position): - # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference - # this also needs to be extended to support asymmetric +/- ve positions + # this is a placeholder (isn't tested, likely buggy) using HuggingFace + # implem as a reference this also needs to be extended to support + # asymmetric +/- ve positions relative_buckets = 0 if not self.causal: - num_buckets //= 2 + self.num_buckets //= 2 relative_buckets += (relative_position > 0).to( - torch.long - ) * num_buckets + torch.long) * self.num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 + max_exact = self.num_buckets // 2 is_small = relative_position < max_exact - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(self.max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) + torch.log(relative_position.float() / max_exact) / + math.log(self.max_distance / max_exact) * + (self.num_buckets - max_exact)).to(torch.long) relative_position_if_large = torch.min( relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), + torch.full_like(relative_position_if_large, self.num_buckets - 1), ) - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) return relative_buckets @@ -831,17 +828,15 @@ def extend_pe(self, x): Args: x: torch.Tensor """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return + if self.pe is not None and self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return pe = torch.zeros(x.size(1), self.d_model) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) @@ -859,7 +854,7 @@ def forward(self, x: torch.Tensor): """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1)] + x = x * self.xscale + self.pe[:, :x.size(1)] return self.dropout(x) @@ -867,7 +862,7 @@ def forward(self, x: torch.Tensor): class MeanVarianceNormLayer(nn.Module): """Mean/variance normalization layer. - Will substract mean and multiply input by inverted standard deviation. + Will subtract mean and multiply input by inverted standard deviation. Typically used as a very first layer in a model. Args: @@ -892,16 +887,22 @@ def forward(self, input_: Tensor) -> Tensor: """ return (input_ - self.global_mean) * self.global_invstd + class CausalConv1D(nn.Conv1d): """ - A causal version of nn.Conv1d where each step would have limited access to locations on its right or left + A causal version of nn.Conv1d where each step would have limited access to + locations on its right or left All arguments are the same as nn.Conv1d except padding. - If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. + If padding is set None, then paddings are set automatically to make it a + causal convolution where each location would not see any steps on its right. - If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. - It would make it possible to control the number of steps to be accessible on the right and left. - This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). + If padding is set as a list (size of 2), then padding[0] would be used as + left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible + on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should + be equal to (kernel_size - 1). """ def __init__( @@ -925,16 +926,12 @@ def __init__( else: if stride != 1 and padding != kernel_size - 1: raise ValueError( - "No striding allowed for non-symmetric convolutions!" - ) + "No striding allowed for non-symmetric convolutions!") if isinstance(padding, int): self._left_padding = padding self._right_padding = padding - elif ( - isinstance(padding, list) - and len(padding) == 2 - and padding[0] + padding[1] == kernel_size - 1 - ): + elif (isinstance(padding, list) and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1): self._left_padding = padding[0] self._right_padding = padding[1] else: @@ -964,10 +961,10 @@ def update_cache(self, x, cache=None): new_x = F.pad(x, pad=(0, self._right_padding)) new_x = torch.cat([cache, new_x], dim=-1) if self.cache_drop_size > 0: - next_cache = new_x[:, :, : -self.cache_drop_size] + next_cache = new_x[:, :, :-self.cache_drop_size] else: next_cache = new_x - next_cache = next_cache[:, :, -cache.size(-1) :] + next_cache = next_cache[:, :, -cache.size(-1):] return new_x, next_cache def forward(self, x, cache=None): @@ -981,8 +978,10 @@ def forward(self, x, cache=None): class CausalConv2D(nn.Conv2d): """ - A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down - All arguments are the same as nn.Conv2d except padding which should be set as None + A causal version of nn.Conv2d where each location in the 2D matrix would + have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be + set as None """ def __init__( @@ -1001,8 +1000,7 @@ def __init__( ) -> None: if padding is not None: raise ValueError( - "Argument padding should be set to None for CausalConv2D." - ) + "Argument padding should be set to None for CausalConv2D.") self._left_padding = kernel_size - 1 self._right_padding = stride - 1 @@ -1046,43 +1044,48 @@ def forward( class NemoConvSubsampling(torch.nn.Module): """Convlutional subsampling module, taken from NeMo ASR - (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a + 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) - Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for - Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + et al. (https://ieeexplore.ieee.org/document/8462506) - Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, - and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce - FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. + Compared with the EncoderConv2D (`input_layer: custom`), this is a + much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. + Moreover, depthwise convolutions are used to reduce FLOPs, but the first + layer is kept as a regular convolution so as not to degrade accuracy. - `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions - after the first layer, whereas the former does not. + `Striding` and `dw_striding` are the same except that the latter uses + depthwise convolutions after the first layer, whereas the former does not. Args: subsampling_factor (int): Time reduction factor feat_in (int): size of the input features feat_out (int): size of the output features subsampling (str): The subsampling technique, choose from - {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} - conv_channels (int): Number of channels for the convolution layers, default is 256. - subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) - 1 (auto) or a power of 2. Default is 1 + {"striding", "dw-striding", "striding_conv1d", + "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, + default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which + can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 activation (Module): activation function, default is nn.ReLU() - is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access - to locations on its right or left + is_causal (bool): whether to use causal Conv1/2D, where each step will + have limited access to locations on its right or left """ def __init__( - self, - feat_in, - feat_out, - subsampling_factor=4, - subsampling="dw_striding", - conv_channels=256, - subsampling_conv_chunking_factor=1, - activation=nn.ReLU(), - is_causal=False, + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), # noqa: B008 + is_causal=False, ): super().__init__() self._subsampling = subsampling @@ -1101,15 +1104,15 @@ def __init__( "striding_conv1d", ) - if ( - subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0 - ): + if (subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0): raise ValueError( - "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" + "subsampling_conv_chunking_factor should be -1, 1, or a "\ + "power of 2" ) - self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + self.subsampling_conv_chunking_factor = \ + subsampling_conv_chunking_factor in_channels = 1 layers = [] @@ -1137,8 +1140,7 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=None, - ) - ) + )) else: layers.append( torch.nn.Conv2d( @@ -1147,8 +1149,7 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - ) - ) + )) in_channels = conv_channels layers.append(activation) @@ -1162,8 +1163,7 @@ def __init__( stride=self._stride, padding=None, groups=in_channels, - ) - ) + )) else: layers.append( torch.nn.Conv2d( @@ -1173,8 +1173,7 @@ def __init__( stride=self._stride, padding=self._left_padding, groups=in_channels, - ) - ) + )) layers.append( torch.nn.Conv2d( @@ -1184,8 +1183,7 @@ def __init__( stride=1, padding=0, groups=1, - ) - ) + )) layers.append(activation) in_channels = conv_channels @@ -1212,8 +1210,7 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=None, - ) - ) + )) else: layers.append( torch.nn.Conv2d( @@ -1222,8 +1219,7 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - ) - ) + )) layers.append(activation) in_channels = conv_channels @@ -1248,30 +1244,22 @@ def __init__( layers.append( CausalConv1D( in_channels=in_channels, - out_channels=( - feat_out - if self._sampling_num == i + 1 - else conv_channels - ), + out_channels=(feat_out if self._sampling_num == i + + 1 else conv_channels), kernel_size=self._kernel_size, stride=self._stride, padding=None, - ) - ) + )) else: layers.append( torch.nn.Conv1d( in_channels=in_channels, - out_channels=( - feat_out - if self._sampling_num == i + 1 - else conv_channels - ), + out_channels=(feat_out if self._sampling_num == i + + 1 else conv_channels), kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - ) - ) + )) layers.append(activation) in_channels = conv_channels @@ -1286,8 +1274,30 @@ def __init__( self._right_padding = (self._kernel_size - 1) // 2 # Layer 1 - layers.extend( - [ + layers.extend([ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == 1 else + conv_channels), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ]) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend([ torch.nn.Conv1d( in_channels=in_channels, out_channels=in_channels, @@ -1298,46 +1308,14 @@ def __init__( ), torch.nn.Conv1d( in_channels=in_channels, - out_channels=( - feat_out - if self._sampling_num == 1 - else conv_channels - ), + out_channels=(feat_out if self._sampling_num == i + + 2 else conv_channels), kernel_size=1, stride=1, padding=0, groups=1, ), - ] - ) - in_channels = conv_channels - layers.append(activation) - - for i in range(self._sampling_num - 1): - layers.extend( - [ - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - groups=in_channels, - ), - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=( - feat_out - if self._sampling_num == i + 2 - else conv_channels - ), - kernel_size=1, - stride=1, - padding=0, - groups=1, - ), - ] - ) + ]) layers.append(activation) in_channels = conv_channels @@ -1354,9 +1332,8 @@ def __init__( ceil_mode=self._ceil_mode, repeat_num=self._sampling_num, ) - self.out = torch.nn.Linear( - conv_channels * int(out_length), feat_out - ) + self.out = torch.nn.Linear(conv_channels * int(out_length), + feat_out) self.conv2d_subsampling = True elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: self.out = None @@ -1384,33 +1361,26 @@ def forward(self, x, mask): Returns: x: torch.Tensor - Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) + Resulting tensor from subsampling (B, T // + time_reduction_factor, feat_out) pad_mask: torch.Tensor - tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) + tensor of padded hidden state sequences (B, 1, T // + time_reduction_factor) """ - # Unsqueeze Channel Axis - if self.conv2d_subsampling: - x = x.unsqueeze(1) - # Transpose to Channel First mode - else: - x = x.transpose(1, 2) + x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) # split inputs if chunking_factor is set - if ( - self.subsampling_conv_chunking_factor != -1 - and self.conv2d_subsampling - ): + if (self.subsampling_conv_chunking_factor != -1 + and self.conv2d_subsampling): if self.subsampling_conv_chunking_factor == 1: - # if subsampling_conv_chunking_factor is 1, we split only if needed - # avoiding a bug / feature limiting indexing of tensors to 2**31 + # if subsampling_conv_chunking_factor is 1, we split only + # if needed. + # avoiding a bug / feature limiting indexing of tensors + # to 2**31. # see https://github.com/pytorch/pytorch/issues/80020 - x_ceil = ( - 2**31 / self._conv_channels * self._stride * self._stride - ) - if torch.numel(x) > x_ceil: - need_to_split = True - else: - need_to_split = False + x_ceil = (2**31 / self._conv_channels * self._stride * + self._stride) + need_to_split = torch.numel(x) > x_ceil else: # if subsampling_conv_chunking_factor > 1 we always split need_to_split = True @@ -1445,8 +1415,7 @@ def forward(self, x, mask): feature_lens_remainder = feature_lens % self.subsampling_factor padding_length[feature_lens_remainder != 1] += 1 pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( - padding_length.size(0), -1 - ) < padding_length.unsqueeze(1) + padding_length.size(0), -1) < padding_length.unsqueeze(1) return x, pad_mask.unsqueeze(1) def reset_parameters(self): @@ -1455,28 +1424,27 @@ def reset_parameters(self): with torch.no_grad(): # init conv scale = 1.0 / self._kernel_size - dw_max = (self._kernel_size**2) ** -0.5 + dw_max = (self._kernel_size**2)**-0.5 pw_max = self._conv_channels**-0.5 torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) for idx in range(2, len(self.conv), 3): - torch.nn.init.uniform_( - self.conv[idx].weight, -dw_max, dw_max - ) - torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) - torch.nn.init.uniform_( - self.conv[idx + 1].weight, -pw_max, pw_max - ) - torch.nn.init.uniform_( - self.conv[idx + 1].bias, -pw_max, pw_max - ) - - # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487 - fc_scale = ( - self._feat_out * self._feat_in / self._sampling_num - ) ** -0.5 + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, + dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, + dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, + pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, + pw_max) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ + # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ + # src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / + self._sampling_num)**-0.5 torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) @@ -1500,17 +1468,16 @@ def conv_split_by_batch(self, x): return x, False return ( - torch.cat( - [ - self.conv(chunk) - for chunk in torch.split(x, new_batch_size, 0) - ] - ), + torch.cat([ + self.conv(chunk) + for chunk in torch.split(x, new_batch_size, 0) + ]), True, ) def conv_split_by_channel(self, x): - """For dw convs, tries to split input by time, run conv and concat results""" + """For dw convs, tries to split input by time, run conv and concat + results""" x = self.conv[0](x) # full conv2D x = self.conv[1](x) # activation @@ -1520,7 +1487,8 @@ def conv_split_by_channel(self, x): if self.subsampling_conv_chunking_factor > 1: cf = self.subsampling_conv_chunking_factor else: - # avoiding a bug / feature limiting indexing of tensors to 2**31 + # avoiding a bug / feature limiting indexing of tensors + # to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) cf = 2**p @@ -1533,9 +1501,8 @@ def conv_split_by_channel(self, x): if new_t == 0: new_t = 1 - x = self.channel_chunked_conv( - self.conv[i * 3 + 2], new_c, x - ) # conv2D, depthwise + x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, + x) # conv2D, depthwise # splitting pointwise convs by time x = torch.cat( @@ -1568,8 +1535,8 @@ def channel_chunked_conv(self, conv, chunk_size, x): ) ch_out = nn.functional.conv2d( chunk, - conv.weight[ind : ind + step, :, :, :], - bias=conv.bias[ind : ind + step], + conv.weight[ind:ind + step, :, :, :], + bias=conv.bias[ind:ind + step], stride=self._stride, padding=0, groups=step, @@ -1577,8 +1544,8 @@ def channel_chunked_conv(self, conv, chunk_size, x): else: ch_out = nn.functional.conv2d( chunk, - conv.weight[ind : ind + step, :, :, :], - bias=conv.bias[ind : ind + step], + conv.weight[ind:ind + step, :, :, :], + bias=conv.bias[ind:ind + step], stride=self._stride, padding=self._left_padding, groups=step, @@ -1589,33 +1556,31 @@ def channel_chunked_conv(self, conv, chunk_size, x): return torch.cat(out_chunks, 1) def change_subsampling_conv_chunking_factor( - self, subsampling_conv_chunking_factor: int - ): - if ( - subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0 - ): + self, subsampling_conv_chunking_factor: int): + if (subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0): raise ValueError( - "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" + "subsampling_conv_chunking_factor should be -1, 1, or a "\ + "power of 2" ) self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor -def calc_length( - lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1 -): - """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" +def calc_length(lengths, + all_paddings, + kernel_size, + stride, + ceil_mode, + repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or + max pooling layer""" add_pad: float = all_paddings - kernel_size one: float = 1.0 for i in range(repeat_num): - lengths = ( - torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one - ) - if ceil_mode: - lengths = torch.ceil(lengths) - else: - lengths = torch.floor(lengths) + lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + + one) + lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) return lengths.to(dtype=torch.int) @@ -1669,15 +1634,15 @@ def masked_softmax( mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) scores = scores.masked_fill(mask, -torch.inf) attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0 - ) # (batch, head, time1, time2) + mask, 0.0) # (batch, head, time1, time2) else: attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) return attn class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer with optional relative position embedding and GLU. + """Multi-Head Attention layer with optional relative position embedding + and GLU. Args: n_head: int @@ -1695,12 +1660,14 @@ class MultiHeadedAttention(nn.Module): it can be different from the input dimension n_feat. default: -1 (equal to n_feat). use_pt_scaled_dot_product_attention: bool, optional - if set True, use pytorch scaled dot product attention in training. NOTE: this will NOT - be used in ONNX decoding due to a lack of support. In that case, we use the original - attention implementation, which shows no regression. + if set True, use pytorch scaled dot product attention in training. + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention + implementation, which shows no regression. default: False. n_value: int, optional - if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. + if set to values other than -1, use a different dimension for + value. With the default value (i.e. -1), it is backward compatible. group_size: int, optional. must divide `n_head` if group_size > 1: GQA if group_size = 1: MHA @@ -1748,15 +1715,14 @@ def __init__( self.dropout = nn.Dropout(p=dropout_rate) self.dropout_rate = dropout_rate self.use_pt_scaled_dot_product_attention = ( - use_pt_scaled_dot_product_attention - ) + use_pt_scaled_dot_product_attention) if use_pt_scaled_dot_product_attention and group_size > 1: raise ValueError("Cannot use PT Scaled Attention with GQA") # Torchscript eager quantization. Note that these functions below are - # NOOPs and have very little impact on performance unless quantization is - # enabled. + # NOOPs and have very little impact on performance unless quantization + # is enabled. self.quant_q = torch.ao.quantization.QuantStub() self.quant_x = torch.ao.quantization.QuantStub() self.dequant = torch.ao.quantization.DeQuantStub() @@ -1788,30 +1754,24 @@ def forward( mask: torch.Tensor mask tensor (batch, time1, time2) relative_attention_bias: torch.Tensor - bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) """ n_batch = query.size(0) - q = self.linear_q(query).view( - n_batch, -1, self.h, self.d_k - ) # (b, t, d) - k = self.linear_k(key).view( - n_batch, -1, self.h_k, self.d_k - ) # (b, t, d) + q = self.linear_q(query).view(n_batch, -1, self.h, + self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, + self.d_k) # (b, t, d) v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) - q = ( - q.transpose(1, 2) - if self.use_pt_scaled_dot_product_attention - and not torch.jit.is_scripting() - else q.transpose(1, 2) * self.inv_sqrt_d_k - ) + q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting() else q.transpose(1, 2) * + self.inv_sqrt_d_k) k = k.transpose(1, 2) # (batch, head_k, time2, d_k) v = v.transpose(1, 2) # (batch, head_k, time2, d_k) - if ( - self.use_pt_scaled_dot_product_attention - and not torch.jit.is_scripting() - ): + if (self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting()): attn_mask = None if mask is not None: mask = mask.unsqueeze(1) @@ -1822,9 +1782,9 @@ def forward( if mask.dtype != q.dtype: attn_mask = attn_mask.to(q.dtype) - with torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=True, enable_mem_efficient=True - ): + with torch.backends.cuda.sdp_kernel(enable_flash=True, + enable_math=True, + enable_mem_efficient=True): x = torch.nn.functional.scaled_dot_product_attention( q, k, @@ -1842,17 +1802,14 @@ def forward( if self.h != self.h_k: B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) else: - reshape_q = ( - q.contiguous() - .view(n_batch * self.h, -1, self.d_k) - .transpose(0, 1) - ) # (t1,nh,dk) - B = torch.matmul( - reshape_q, pos_k.transpose(-2, -1) - ) # pos_k: (t1,dk,t2) - B = B.transpose(0, 1).view( - n_batch, self.h, pos_k.size(0), pos_k.size(1) - ) + reshape_q = (q.contiguous().view(n_batch * self.h, -1, + self.d_k).transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul(reshape_q, + pos_k.transpose(-2, + -1)) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), + pos_k.size(1)) scores = A + B else: scores = A @@ -1865,26 +1822,20 @@ def forward( self.attn = attn p_attn = self.dropout(attn) - x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) + x = torch.matmul(p_attn.to(v.dtype), + v) # (batch, head, time1, d_k) if pos_v is not None: - reshape_attn = ( - p_attn.contiguous() - .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) - .transpose(0, 1) - ) # (t1, bh, t2) - - attn_v = ( - torch.matmul(reshape_attn, pos_v) - .transpose(0, 1) - .contiguous() - .view(n_batch, self.h, pos_v.size(0), self.d_k) - ) + reshape_attn = (p_attn.contiguous().view( + n_batch * self.h, pos_v.size(0), + pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2) + + attn_v = (torch.matmul(reshape_attn, pos_v).transpose( + 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0), + self.d_k)) x = x + attn_v - x = ( - x.transpose(1, 2) - .contiguous() - .view(n_batch, -1, self.h_k * self.d_k) - ) # (batch, time1, d_model) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h_k * self.d_k) + ) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model) @@ -1896,19 +1847,21 @@ def validate_checkpointing_config(activation_checkpointing): "", "checkpoint", "offload", - ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')." + ), "activation_checkpointing has to be a dict or a str in "\ + "('', 'checkpoint', 'offload')." elif isinstance(activation_checkpointing, dict): assert activation_checkpointing.get("module", "transformer") in ( "transformer", "attention", - ), "module in activation_checkpointing has to be in ('transformer', 'attention')." + ), "module in activation_checkpointing has to be in "\ + "('transformer', 'attention')." else: - raise ValueError("activation_checkpointing has to be a str or dict.") + raise ValueError("activation_checkpointing has to be a str"\ + " or dict.") def embedding_checkpoint_wrapper( - activation_checkpointing: Union[str, Dict], -) -> Callable: + activation_checkpointing: Union[str, Dict], ) -> Callable: """return encoder embedding activation checkpoint wrapper""" validate_checkpointing_config(activation_checkpointing) @@ -1925,25 +1878,22 @@ def embedding_checkpoint_wrapper( offloading = activation_checkpointing.get("offload", False) if offloading: return offload_wrapper - impl = ( - CheckpointImpl.REENTRANT - if activation_checkpointing.get("reentrant", False) - else CheckpointImpl.NO_REENTRANT - ) + impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get( + "reentrant", False) else CheckpointImpl.NO_REENTRANT) return partial(checkpoint_wrapper, checkpoint_impl=impl) return lambda x: x raise ValueError("Invalid activation_checkpointing config") -def attn_checkpointing( - activation_checkpointing: Union[str, Dict], i -) -> Union[str, Dict]: +def attn_checkpointing(activation_checkpointing: Union[str, Dict], + i) -> Union[str, Dict]: """return activation checkpointing config for attention layer""" if isinstance(activation_checkpointing, str): return "" if isinstance(activation_checkpointing, dict): - target_layer_cls = activation_checkpointing.get("module", "transformer") + target_layer_cls = activation_checkpointing.get( + "module", "transformer") checkpointing_interval = activation_checkpointing.get("interval", 1) if target_layer_cls == "attention" and i % checkpointing_interval == 0: return activation_checkpointing @@ -1973,8 +1923,10 @@ def repeat(repeat_num, module_gen_fn): """ return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) + def get_offset(input_layer: str, time_reduction: int): - """Get an offset. We will use the offset for determining #frames of a subsampled feature. + """Get an offset. We will use the offset for determining #frames of a + subsampled feature. Args: input_layer (str): Type of an input layer @@ -1984,21 +1936,23 @@ def get_offset(input_layer: str, time_reduction: int): """ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: return 3 - if input_layer in ("conv2d",) and time_reduction == 6: + if input_layer in ("conv2d", ) and time_reduction == 6: return 1 if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: return 7 return 0 + def unfold_tensor(xs_pad, max_seq_len): """ - For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, - this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a + (NT', max_seq_len, D) where T' is T // max_seq_len. Args: xs_pad: N, T, D """ _, _, D = xs_pad.shape - xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T # N x D x 1 x T => N x (D x max_seq_len) x T' xs_pad = F.unfold( xs_pad[..., None, :], @@ -2013,4 +1967,3 @@ def unfold_tensor(xs_pad, max_seq_len): # NT' x max_seq_len x D xs_pad = xs_pad.view(-1, max_seq_len, D) return xs_pad - diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py index 924836eee239..e9a0943a75b0 100644 --- a/vllm/model_executor/models/vision_siglip_navit.py +++ b/vllm/model_executor/models/vision_siglip_navit.py @@ -1,4 +1,4 @@ -# coding=utf-8 +# SPDX-License-Identifier: Apache-2.0 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,45 +14,73 @@ # limitations under the License. """ Siglip model configuration""" +import math import os -from typing import Union +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, + replace_return_docstrings) logger = logging.get_logger(__name__) SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", + "google/siglip-base-patch16-224": + "https://huggingface.co/google/siglip-base-patch16-224/"\ + "resolve/main/config.json", } class SiglipTextConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a - Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip - [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. + This is the configuration class to store the configuration of a + [`SiglipTextModel`]. It is used to instantiate a Siglip text encoder + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the text encoder of the Siglip [google/ + siglip-base-patch16-224](https://huggingface.co/google/siglip-base + -patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by - the `inputs_ids` passed when calling [`SiglipModel`]. + Vocabulary size of the Siglip text model. Defines the number of + different tokens that can be represented by the `inputs_ids` + passed when calling [`SiglipModel`]. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. intermediate_size (`int`, *optional*, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + Dimensionality of the "intermediate" (i.e., feed-forward) layer + in the Transformer encoder. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. + Number of attention heads for each attention layer in the + Transformer encoder. max_position_embeddings (`int`, *optional*, defaults to 64): - The maximum sequence length that this model might ever be used with. Typically set this to something large + The maximum sequence length that this model might ever be used + with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + hidden_act (`str` or `function`, *optional*, defaults to + `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. @@ -67,9 +95,11 @@ class SiglipTextConfig(PretrainedConfig): Example: ```python >>> from transformers import SiglipTextConfig, SiglipTextModel - >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 + style configuration >>> configuration = SiglipTextConfig() - >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> # Initializing a SiglipTextModel (with random weights) from the + google/siglip-base-patch16-224 style configuration >>> model = SiglipTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -89,14 +119,18 @@ def __init__( layer_norm_eps=1e-6, attention_dropout=0.0, # This differs from `CLIPTokenizer`'s default and from openai/siglip - # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + # See https://github.com/huggingface/transformers/pull/24773# + # issuecomment-1632287538 pad_token_id=1, bos_token_id=49406, eos_token_id=49407, _flash_attn_2_enabled=True, **kwargs, ): - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + super().__init__(pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -110,50 +144,64 @@ def __init__( self._flash_attn_2_enabled = _flash_attn_2_enabled @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) # get the text config dict if we are loading from SiglipConfig if config_dict.get("model_type") == "siglip": config_dict = config_dict["text_config"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", config_dict['model_type'], + cls.model_type) return cls.from_dict(config_dict, **kwargs) class SiglipVisionConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a - Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip - [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + This is the configuration class to store the configuration of a + [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the + model architecture. Instantiating a configuration with the defaults will + yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/ + siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. intermediate_size (`int`, *optional*, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + Dimensionality of the "intermediate" (i.e., feed-forward) layer + in the Transformer encoder. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. + Number of attention heads for each attention layer in the + Transformer encoder. num_channels (`int`, *optional*, defaults to 3): Number of channels in the input images. image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + hidden_act (`str` or `function`, *optional*, defaults to + `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and + `"gelu_new"` ``"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -161,9 +209,11 @@ class SiglipVisionConfig(PretrainedConfig): Example: ```python >>> from transformers import SiglipVisionConfig, SiglipVisionModel - >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 + style configuration >>> configuration = SiglipVisionConfig() - >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> # Initializing a SiglipVisionModel (with random weights) from the + google/siglip-base-patch16-224 style configuration >>> model = SiglipVisionModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -201,54 +251,69 @@ def __init__( self._flash_attn_2_enabled = _flash_attn_2_enabled @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) # get the vision config dict if we are loading from SiglipConfig if config_dict.get("model_type") == "siglip": config_dict = config_dict["vision_config"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to " + "instantiate a model of type %s. This is not" + " supported for all configurations of models and can yield" + " errors.", config_dict['model_type'], cls.model_type) return cls.from_dict(config_dict, **kwargs) class SiglipConfig(PretrainedConfig): r""" - [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to - instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. - Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip - [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. + [`SiglipConfig`] is the configuration class to store the configuration of a + [`SiglipModel`]. It is used to instantiate a Siglip model according to the + specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Siglip [google/siglip-base-patch16-224]( + https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. Args: text_config (`dict`, *optional*): - Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + Dictionary of configuration options used to initialize + [`SiglipTextConfig`]. vision_config (`dict`, *optional*): - Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + Dictionary of configuration options used to initialize + [`SiglipVisionConfig`]. kwargs (*optional*): Dictionary of keyword arguments. Example: ```python >>> from transformers import SiglipConfig, SiglipModel - >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 + style configuration >>> configuration = SiglipConfig() - >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> # Initializing a SiglipModel (with random weights) from the + google/siglip-base-patch16-224 style configuration >>> model = SiglipModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig + and a SiglipVisionConfig >>> from transformers import SiglipTextConfig, SiglipVisionConfig >>> # Initializing a SiglipText and SiglipVision configuration >>> config_text = SiglipTextConfig() >>> config_vision = SiglipVisionConfig() - >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + >>> config = SiglipConfig.from_text_vision_configs(config_text, + config_vision) ```""" model_type = "siglip" @@ -258,11 +323,14 @@ def __init__(self, text_config=None, vision_config=None, **kwargs): if text_config is None: text_config = {} - logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + logger.info( + "`text_config` is `None`. Initializing the `SiglipTextConfig`" + " with default values.") if vision_config is None: vision_config = {} - logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + logger.info("`vision_config` is `None`. initializing the " + "`SiglipVisionConfig` with default values.") self.text_config = SiglipTextConfig(**text_config) self.vision_config = SiglipVisionConfig(**vision_config) @@ -270,15 +338,20 @@ def __init__(self, text_config=None, vision_config=None, **kwargs): self.initializer_factor = 1.0 @classmethod - def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): + def from_text_vision_configs(cls, text_config: SiglipTextConfig, + vision_config: SiglipVisionConfig, **kwargs): r""" - Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text + model configuration and siglip vision model configuration. Returns: [`SiglipConfig`]: An instance of a configuration object """ - return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + return cls(text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + **kwargs) + # coding=utf-8 # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. @@ -296,35 +369,6 @@ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: # limitations under the License. """ PyTorch Siglip model.""" - -import math -import warnings -from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union -from safetensors.torch import load_model, save_model - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn.init import _calculate_fan_in_and_fan_out - -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, - replace_return_docstrings, -) - -logger = logging.get_logger(__name__) - _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ @@ -334,7 +378,8 @@ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.bert_padding import pad_input # noqa + from flash_attn.bert_padding import index_first_axis, unpad_input # Copied from transformers.models.llama.modeling_llama._get_unpad_data @@ -342,7 +387,8 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -351,8 +397,10 @@ def _get_unpad_data(attention_mask): def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + # Cut & paste from PyTorch official master until it's in a few official + # releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/ + # truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 @@ -367,8 +415,8 @@ def norm_cdf(x): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + l = norm_cdf((a - mean) / std) # noqa + u = norm_cdf((b - mean) / std) # noqa # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. @@ -399,18 +447,21 @@ def norm_cdf(x): tensor.clamp_(min=a, max=b) -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: +def trunc_normal_tf_(tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsquently scaled and shifted by the mean and std args. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where + the bounds [a, b] are applied when sampling the normal distribution with + mean=0, std=1.0 and the result is subsequently scaled and shifted by the + mean and std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution @@ -457,24 +508,37 @@ def default_flax_embed_init(tensor): @dataclass -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with +# CLIP->Siglip class SiglipVisionModelOutput(ModelOutput): """ - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + Base class for vision model's outputs that also contains image embeddings + of the pooling of the last hidden states. Args: - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` + *optional* returned when model is initialized with + `with_projection=True`): + The image embeddings obtained by applying the projection layer to + the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. """ image_embeds: Optional[torch.FloatTensor] = None @@ -484,24 +548,38 @@ class SiglipVisionModelOutput(ModelOutput): @dataclass -# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with +# CLIP->Siglip class SiglipTextModelOutput(ModelOutput): """ - Base class for text model's outputs that also contains a pooling of the last hidden states. + Base class for text model's outputs that also contains a pooling of the + last hidden states. Args: - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The text embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` + *optional* returned when model is initialized with + `with_projection=True`): + The text embeddings obtained by applying the projection layer to + model. + the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the + embeddings, if the model has an embedding layer, + one for the + output of each layer) of shape `(batch_size, sequence_length, + hidden_size)`. + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute + the weighted average in the self-attention heads. """ text_embeds: Optional[torch.FloatTensor] = None @@ -511,22 +589,28 @@ class SiglipTextModelOutput(ModelOutput): @dataclass -# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with +# CLIP->Siglip class SiglipOutput(ModelOutput): """ Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `return_loss` is `True`): Contrastive loss for image-text similarity. - logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, + text_batch_size)`): + The scaled dot product scores between `image_embeds` and + `text_embeds`. This represents the image-text similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, + image_batch_size)`): + The scaled dot product scores between `text_embeds` and + `image_embeds`. This represents the text-image similarity scores. text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + The text embeddings obtained by applying the projection layer to + the pooled output of [`SiglipTextModel`]. image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + The image embeddings obtained by applying the projection layer to + the pooled output of [`SiglipVisionModel`]. text_model_output(`BaseModelOutputWithPooling`): The output of the [`SiglipTextModel`]. vision_model_output(`BaseModelOutputWithPooling`): @@ -543,12 +627,13 @@ class SiglipOutput(ModelOutput): def to_tuple(self) -> Tuple[Any]: return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) + self[k] if k not in ["text_model_output", "vision_model_output" + ] else getattr(self, k).to_tuple() + for k in self.keys()) class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config @@ -567,17 +652,21 @@ def __init__(self, config: SiglipVisionConfig): self.num_patches_per_side = self.image_size // self.patch_size self.num_patches = self.num_patches_per_side**2 self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) - def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + def forward(self, pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor) -> torch.Tensor: batch_size = pixel_values.size(0) patch_embeds = self.patch_embedding(pixel_values) embeddings = patch_embeds.flatten(2).transpose(1, 2) max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) - max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, \ + max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) position_ids = torch.full( size=( batch_size, @@ -590,13 +679,20 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - fractional_coords_h = torch.linspace(0, 1 - 1 / nb_patches_h, nb_patches_h) - fractional_coords_w = torch.linspace(0, 1 - 1 / nb_patches_w, nb_patches_w) + fractional_coords_h = torch.linspace(0, 1 - 1 / nb_patches_h, + nb_patches_h) + fractional_coords_w = torch.linspace(0, 1 - 1 / nb_patches_w, + nb_patches_w) - bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) @@ -605,19 +701,24 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B return embeddings -# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with +# CLIP->Siglip class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): super().__init__() embed_dim = config.hidden_size self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) - self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, + embed_dim) - # position_ids (1, len position emb) is contiguous in memory and exported when serialized + # position_ids (1, len position emb) is contiguous in memory and + # exported when serialized self.register_buffer( - "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False - ) + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False) def forward( self, @@ -625,7 +726,8 @@ def forward( position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + seq_length = input_ids.shape[ + -1] if input_ids is not None else inputs_embeds.shape[-2] if position_ids is None: position_ids = self.position_ids[:, :seq_length] @@ -651,9 +753,8 @@ def __init__(self, config): self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) + f"embed_dim must be divisible by num_heads (got `embed_dim`:" + f" {self.embed_dim} and `num_heads`: {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -667,7 +768,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() @@ -676,36 +778,47 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + if attn_weights.size() != (batch_size, self.num_heads, q_len, + k_v_seq_len): raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) + f"Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" - ) + raise ValueError(f"Attention mask should be of size " + f"{(batch_size, 1, q_len, k_v_seq_len)}, " + f"but is {attention_mask.size()}") attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + if attn_output.size() != (batch_size, self.num_heads, q_len, + self.head_dim): raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f"`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) @@ -717,9 +830,11 @@ def forward( class SiglipFlashAttention2(SiglipAttention): """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. + Llama flash attention module. This module inherits from `LlamaAttention` as + the weights of the module stays untouched. The only required change would + be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any + of them. """ def __init__(self, *args, **kwargs): @@ -735,7 +850,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -747,21 +863,21 @@ def forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) - # if past_key_value is not None: - # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # TODO: These transpose are quite inefficient but Flash Attention + # requires the layout [batch_size, sequence_length, num_heads, + # head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -769,11 +885,12 @@ def forward( dropout_rate = self.dropout if self.training else 0.0 - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently + # casted in float32. Hence, we need cast them back in the correct + # dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to + # not cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: @@ -786,20 +903,24 @@ def forward( target_dtype = self.q_proj.weight.dtype logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to the fact" - " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + "The input hidden states seems to be silently casted in " + "float32, this might be related to the fact you have upcasted " + "embedding or layer norm layers in float32. We will cast " + f"back the input in {target_dtype}.") query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate - ) + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) - attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(bsz, q_len, + self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -807,12 +928,19 @@ def forward( return attn_output, attn_weights - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. + Calls the forward method of Flash Attention - if the input hidden + states contain at least one padding token first unpad the input, + then computes the attention scores and pad the final attention + scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API @@ -821,23 +949,29 @@ def _flash_attention_forward( value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. + The padding mask - corresponds to a tensor of size + `(batch_size, seq_len)` where 0 stands for the position + of padding tokens and 1 for the position of non-padding + tokens. dropout (`int`, *optional*): Attention dropout softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + The scaling of QK^T before applying softmax. Default to 1 / + sqrt(head_dim) """ - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + # TODO: Remove the `query_length != 1` check once Flash Attention for + # RoCm is bumped to 2.1. For details, please see the comment in + # LlamaFlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) + query_states, key_states, value_states, indices_q, cu_seq_lens, \ + max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, + query_length) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -855,28 +989,34 @@ def _flash_attention_forward( causal=causal, ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k @@ -890,7 +1030,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \ + unpad_input(query_layer, attention_mask) return ( query_layer, @@ -904,6 +1045,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): + def __init__(self, config): super().__init__() self.config = config @@ -918,19 +1060,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with +# CLIP->Siglip class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = ( - SiglipAttention(config) - if not getattr(config, "_flash_attn_2_enabled", False) - else SiglipFlashAttention2(config) - ) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = (SiglipAttention(config) if + not getattr(config, "_flash_attn_2_enabled", False) + else SiglipFlashAttention2(config)) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) def forward( self, @@ -943,10 +1087,12 @@ def forward( hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where + padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under returned tensors for + more detail. """ residual = hidden_states @@ -963,18 +1109,18 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (attn_weights,) + outputs += (attn_weights, ) return outputs class SiglipPreTrainedModel(PreTrainedModel): """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. + An abstract class to handle weights initialization and a simple interface + for downloading and loading pretrained models. """ config_class = SiglipConfig @@ -985,12 +1131,10 @@ def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SiglipVisionEmbeddings): - width = ( - self.config.vision_config.hidden_size - if isinstance(self.config, SiglipConfig) - else self.config.hidden_size - ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + width = (self.config.vision_config.hidden_size if isinstance( + self.config, SiglipConfig) else self.config.hidden_size) + nn.init.normal_(module.position_embedding.weight, + std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, SiglipAttention): @@ -1025,98 +1169,125 @@ def _init_weights(self, module): SIGLIP_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. + This model inherits from [`PreTrainedModel`]. Check the superclass + documentation for the generic methods the library implements for all + its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/ + stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation + for all matter related to general usage and behavior. Parameters: - config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + config ([`SiglipConfig`]): Model configuration class with all the + parameters of the model. + Initializing with a config file does not load the weights + associated with the model, only the configuration. Check out + the [`~PreTrainedModel.from_pretrained`] method to load the + model weights. """ SIGLIP_TEXT_INPUTS_DOCSTRING = r""" Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length) + `): + Indices of input sequence tokens in the vocabulary. Padding will + be ignored by default should you provide it. + Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + position_ids (`torch.LongTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. """ SIGLIP_VISION_INPUTS_DOCSTRING = r""" Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + pixel_values (`torch.FloatTensor` of shape `(batch_size, + num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you + provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] + for details. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. """ SIGLIP_INPUTS_DOCSTRING = r""" Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + input_ids (`torch.LongTensor` of shape `(batch_size, + sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding + will be ignored by default should you provide it. + Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` + , *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + position_ids (`torch.LongTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + pixel_values (`torch.FloatTensor` of shape `(batch_size, + num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you + provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] + for details. return_loss (`bool`, *optional*): Whether or not to return the contrastive loss. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. """ -# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with +# CLIP->Siglip class SiglipEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SiglipEncoderLayer`]. + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`SiglipEncoderLayer`]. Args: config: SiglipConfig """ @@ -1124,7 +1295,9 @@ class SiglipEncoder(nn.Module): def __init__(self, config: SiglipConfig): super().__init__() self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) + ]) self.gradient_checkpointing = False # Ignore copy @@ -1138,29 +1311,38 @@ def forward( ) -> Union[Tuple, BaseModelOutput]: r""" Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + attention_mask (`torch.Tensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. + Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under returned tensors for + more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1168,7 +1350,7 @@ def forward( hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, @@ -1186,31 +1368,36 @@ def forward( hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = SiglipTextEmbeddings(config) self.encoder = SiglipEncoder(config) - self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.final_layer_norm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) self.head = nn.Linear(embed_dim, embed_dim) @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipTextConfig) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -1223,11 +1410,14 @@ def forward( r""" Returns: """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states \ + is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict if input_ids is None: raise ValueError("You have to specify input_ids") @@ -1235,13 +1425,17 @@ def forward( input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids) - # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # note: SigLIP's text model does not use a causal mask, unlike the + # original CLIP model. # expand attention_mask if attention_mask is not None: - # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + # [batch_size, seq_len] -> + # [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1291,7 +1485,8 @@ def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipTextConfig) def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -1306,15 +1501,21 @@ def forward( Examples: ```python >>> from transformers import AutoTokenizer, SiglipTextModel - >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> model = SiglipTextModel. + from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer. + from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" + as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], + padding="max_length", return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + >>> pooled_output = outputs.pooler_output # pooled (EOS token) + states ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict return self.text_model( input_ids=input_ids, @@ -1327,6 +1528,7 @@ def forward( class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config @@ -1334,11 +1536,13 @@ def __init__(self, config: SiglipVisionConfig): self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) self.head = SiglipMultiheadAttentionPoolingHead(config) @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) def forward( self, pixel_values, @@ -1350,11 +1554,13 @@ def forward( r""" Returns: """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None\ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: @@ -1368,20 +1574,22 @@ def forward( device=pixel_values.device, ) - hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask) patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive - # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), - # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + # So when the `patch_attention_mask` is full of 1s (i.e. attending + # to the whole sequence), avoiding passing the attention_mask, which + # is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): - attention_mask=None + attention_mask = None else: - attention_mask = ( - _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - if not self.config._flash_attn_2_enabled - else patch_attention_mask - ) + attention_mask = (_prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype) + if not self.config._flash_attn_2_enabled else + patch_attention_mask) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1417,17 +1625,20 @@ def __init__(self, config: SiglipVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) def forward(self, hidden_state, attention_mask): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) - hidden_state = self.attention( - query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask - )[0] + hidden_state = self.attention(query=probe, + key=hidden_state, + value=hidden_state, + key_padding_mask=~attention_mask)[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) @@ -1456,7 +1667,8 @@ def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) def forward( self, pixel_values, @@ -1472,16 +1684,20 @@ def forward( >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, SiglipVisionModel - >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> model = SiglipVisionModel.from_pretrained( + "google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + "google/siglip-base-patch16-224") + >>> url = + "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict return self.vision_model( pixel_values=pixel_values, @@ -1500,16 +1716,14 @@ def __init__(self, config: SiglipConfig): super().__init__(config) if not isinstance(config.text_config, SiglipTextConfig): - raise ValueError( - "config.text_config is expected to be of type SiglipTextConfig but is of type" - f" {type(config.text_config)}." - ) + raise ValueError("config.text_config is expected to be of type " + f"SiglipTextConfig but is of type" + f" {type(config.text_config)}.") if not isinstance(config.vision_config, SiglipVisionConfig): - raise ValueError( - "config.vision_config is expected to be of type SiglipVisionConfig but is of type" - f" {type(config.vision_config)}." - ) + raise ValueError("config.vision_config is expected to be of type " + "SiglipVisionConfig but is of type" + f" {type(config.vision_config)}.") text_config = config.text_config vision_config = config.vision_config @@ -1535,25 +1749,34 @@ def get_text_features( ) -> torch.FloatTensor: r""" Returns: - text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by - applying the projection layer to the pooled output of [`SiglipTextModel`]. + text_features (`torch.FloatTensor` of shape `(batch_size, + output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output + of [`SiglipTextModel`]. Examples: ```python >>> from transformers import AutoTokenizer, AutoModel >>> import torch - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> model = AutoModel.from_pretrained( + "google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained( + "google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's + how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], + padding="max_length", return_tensors="pt") >>> with torch.no_grad(): ... text_features = model.get_text_features(**inputs) ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Use SigLIP model's config for some fields (if specified) instead + # of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None\ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict text_outputs = self.text_model( input_ids=input_ids, @@ -1578,8 +1801,9 @@ def get_image_features( ) -> torch.FloatTensor: r""" Returns: - image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of [`SiglipVisionModel`]. + image_features (`torch.FloatTensor` of shape `(batch_size, + output_dim`): The image embeddings obtained by applying the + projection layer to the pooled output of [`SiglipVisionModel`]. Examples: ```python >>> from PIL import Image @@ -1587,19 +1811,23 @@ def get_image_features( >>> from transformers import AutoProcessor, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + "google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> with torch.no_grad(): ... image_features = model.get_image_features(**inputs) ```""" - # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Use SiglipModel's config for some fields (if specified) instead + # of those of vision & text components. + output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, @@ -1613,7 +1841,8 @@ def get_image_features( return pooled_output @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + @replace_return_docstrings(output_type=SiglipOutput, + config_class=SiglipConfig) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1634,25 +1863,32 @@ def forward( >>> from transformers import AutoProcessor, AutoModel >>> import torch >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + "google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] - >>> # important: we pass `padding=max_length` since the model was trained with this - >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + >>> # important: we pass `padding=max_length` since the model was + trained with this + >>> inputs = processor(text=texts, images=image, + padding="max_length", return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> logits_per_image = outputs.logits_per_image - >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> probs = torch.sigmoid(logits_per_image) # these are the + probabilities >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") 31.9% that image 0 is 'a photo of 2 cats' ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Use SigLIP model's config for some fields (if specified) instead of + # those of vision & text components. + output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, @@ -1674,11 +1910,13 @@ def forward( text_embeds = text_outputs[1] # normalized features - image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + image_embeds = image_embeds / image_embeds.norm( + p=2, dim=-1, keepdim=True) text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits - logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias + logits_per_text = torch.matmul(text_embeds, image_embeds.t( + )) * self.logit_scale.exp() + self.logit_bias logits_per_image = logits_per_text.t() loss = None @@ -1686,8 +1924,9 @@ def forward( raise NotImplementedError("SigLIP loss to be implemented") if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output + output = (logits_per_image, logits_per_text, text_embeds, + image_embeds, text_outputs, vision_outputs) + return ((loss, ) + output) if loss is not None else output return SiglipOutput( loss=loss, @@ -1714,8 +1953,7 @@ def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): model_config = SiglipVisionConfig( **siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, - **kwargs - ) + **kwargs) vision_model = SiglipVisionModel(model_config).vision_model From 86295b3377920aa9594336d863672c123d70ea08 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 15:47:02 -0800 Subject: [PATCH 17/27] restore requirements-test.txt Signed-off-by: Congcong Chen --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index e5bf67e099e4..d6e21452191a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -696,4 +696,4 @@ xxhash==3.5.0 yarl==1.17.1 # via aiohttp zstandard==0.23.0 - # via lm-eval + # via lm-eval \ No newline at end of file From bbdcfb7508900d2dbe521490d48c7c3b7a14ca69 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 16:49:39 -0800 Subject: [PATCH 18/27] remove hard dependency on flash-attn Signed-off-by: Congcong Chen --- .../models/vision_siglip_navit.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py index e9a0943a75b0..cd70d660f209 100644 --- a/vllm/model_executor/models/vision_siglip_navit.py +++ b/vllm/model_executor/models/vision_siglip_navit.py @@ -33,10 +33,13 @@ BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import (ModelOutput, add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, logging, + add_start_docstrings_to_model_forward, logging, replace_return_docstrings) +from vllm.platforms import _Backend + +from .vision import get_vit_attn_backend + logger = logging.get_logger(__name__) SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { @@ -376,11 +379,6 @@ def from_text_vision_configs(cls, text_config: SiglipTextConfig, # See all SigLIP models at https://huggingface.co/models?filter=siglip ] -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import pad_input # noqa - from flash_attn.bert_padding import index_first_axis, unpad_input - # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): @@ -841,6 +839,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_causal = False # Hack to make sure we don't use a causal mask + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend != _Backend.FLASH_ATTN: + raise RuntimeError( + "Phi-4-multimodal-instruct model does not support"\ + " {self.attn_backend} backend now." + ) + def forward( self, hidden_states: torch.Tensor, @@ -959,6 +965,8 @@ def _flash_attention_forward(self, The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa # TODO: Remove the `query_length != 1` check once Flash Attention for # RoCm is bumped to 2.1. For details, please see the comment in @@ -1003,6 +1011,7 @@ def _flash_attention_forward(self, def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + from flash_attn.bert_padding import index_first_axis, unpad_input indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape From 28de545611618ebc4f1bc41baf823a65f28cd103 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 17:28:26 -0800 Subject: [PATCH 19/27] Register test and add model info to supported_model.md Signed-off-by: Congcong Chen --- docs/source/models/supported_models.md | 11 +++++++++++ tests/models/registry.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 0e93a15b84fc..bbc265928233 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -821,6 +821,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Phi4MMForCausalLM` + * Phi-4-multimodal + * T + I+ + A+ + * `microsoft/Phi-4-multimodal-instruct`, etc. + * ✅︎ + * + * ✅︎ - * `PixtralForConditionalGeneration` * Pixtral * T + I+ @@ -896,6 +903,10 @@ Currently the PaliGemma model series is implemented without PrefixLM attention m To use Qwen2.5-VL series models, you have to install Huggingface `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`. ::: +:::{note} +Phi4MMForCausalLM doesn't support text + image + audio input. +::: + ### Pooling Models See [this page](pooling-models) for more information on how to use pooling models. diff --git a/tests/models/registry.py b/tests/models/registry.py index b5ded20c5af5..3fe5f4318662 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -270,6 +270,8 @@ def check_available_online( extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True), + "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", + trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 tokenizer_mode="mistral"), "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", From 5327e89b932550d5feb716d51ae6b77182020215 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 17:33:24 -0800 Subject: [PATCH 20/27] restore requirements-test.txt Signed-off-by: Congcong Chen --- requirements-test.txt | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index d6e21452191a..cfc000a1b09f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -23,10 +23,6 @@ anyio==4.6.2.post1 # via httpx argcomplete==3.5.1 # via datamodel-code-generator -async-timeout==4.0.3 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -120,10 +116,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -552,7 +544,9 @@ sentence-transformers==3.2.1 sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 - # via pytablewriter + # via + # pytablewriter + # torch six==1.16.0 # via # python-dateutil @@ -597,12 +591,6 @@ timm==1.0.11 # via -r requirements-test.in tokenizers==0.21.0 # via transformers -toml==0.10.2 - # via datamodel-code-generator -tomli==2.2.1 - # via - # black - # pytest torch==2.5.1 # via # -r requirements-test.in @@ -663,17 +651,13 @@ typepy==1.3.2 # tabledata typing-extensions==4.12.2 # via - # anyio # bitsandbytes - # black # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch tzdata==2024.2 # via pandas From 454d4a7bba64705d2040ab1eaa967bb828c13e64 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 3 Mar 2025 17:46:41 -0800 Subject: [PATCH 21/27] restore requirements-test.txt Signed-off-by: Congcong Chen --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index cfc000a1b09f..f5722c82e201 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -680,4 +680,4 @@ xxhash==3.5.0 yarl==1.17.1 # via aiohttp zstandard==0.23.0 - # via lm-eval \ No newline at end of file + # via lm-eval From 1bb5750b11bc8b0e86a53ba21375b840999ad0be Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 4 Mar 2025 11:53:22 -0800 Subject: [PATCH 22/27] Add text-only version of Phi-4-mini to the supported_models page per request Signed-off-by: Congcong Chen --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index bbc265928233..9480daaae321 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -375,7 +375,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `Phi3ForCausalLM` * Phi-4, Phi-3 - * `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. + * `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. * ✅︎ * ✅︎ - * `Phi3SmallForCausalLM` From a2cc774d800ef8947cbf163845167759c7558e28 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 4 Mar 2025 12:02:49 -0800 Subject: [PATCH 23/27] Minor update to supported_models.md Signed-off-by: Congcong Chen --- docs/source/models/supported_models.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 9480daaae321..afd850d0ce42 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -823,7 +823,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `Phi4MMForCausalLM` * Phi-4-multimodal - * T + I+ + A+ + * T + I+ / T + A+ / I+ + A+ * `microsoft/Phi-4-multimodal-instruct`, etc. * ✅︎ * @@ -903,10 +903,6 @@ Currently the PaliGemma model series is implemented without PrefixLM attention m To use Qwen2.5-VL series models, you have to install Huggingface `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`. ::: -:::{note} -Phi4MMForCausalLM doesn't support text + image + audio input. -::: - ### Pooling Models See [this page](pooling-models) for more information on how to use pooling models. From 04609123f62d85acd470acd89510114b0d8b1d1f Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 4 Mar 2025 13:25:42 -0800 Subject: [PATCH 24/27] Update supported_models.md Signed-off-by: Congcong Chen --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index afd850d0ce42..dfdebf42b19f 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -827,7 +827,7 @@ See [this page](#generative-models) for more information on how to use generativ * `microsoft/Phi-4-multimodal-instruct`, etc. * ✅︎ * - * ✅︎ + * - * `PixtralForConditionalGeneration` * Pixtral * T + I+ From 08c845a91e8d15595f40354b4d05271b0451ba76 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 4 Mar 2025 13:26:58 -0800 Subject: [PATCH 25/27] delete the testing script Signed-off-by: Congcong Chen --- examples/offline_inference_phi3o.py | 549 ---------------------------- 1 file changed, 549 deletions(-) delete mode 100644 examples/offline_inference_phi3o.py diff --git a/examples/offline_inference_phi3o.py b/examples/offline_inference_phi3o.py deleted file mode 100644 index 6166254db213..000000000000 --- a/examples/offline_inference_phi3o.py +++ /dev/null @@ -1,549 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Implements a simple offline inference script for the Phi 3.5 Speech model. -# Code implemented by Jacob Platin (jacobplatin@microsoft.com) - -import soundfile - -from vllm import LLM, SamplingParams -from vllm.lora.request import LoRARequest -from vllm.multimodal.utils import fetch_image -from vllm.utils import FlexibleArgumentParser - - -def main_pure_text(args: dict) -> None: - """ - Main function for the offline inference script. - """ - llm = LLM(model=args.model_path, - trust_remote_code=True, - enforce_eager=True) - user_prompt = '<|user|>\n' - assistant_prompt = '<|assistant|>\n' - prompt_suffix = '<|end|>\n' - prompt = f'{user_prompt}what is the answer for 1+1? Explain'\ - f' it.{prompt_suffix}{assistant_prompt}' - print(f'>>> Prompt\n{prompt}') - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = {"prompt": prompt} - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=1200, - ) - - outputs = llm.generate(generate_args, sampling_params=sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -def main_with_lora_speech(args: dict, activate_lora_request=True) -> None: - """ - Main function for the offline inference script. - """ - wav_paths = [args.wav_path] - llm = LLM(model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - limit_mm_per_prompt={"audio": len(wav_paths)}, - max_loras=5) - - # assert len(wav_paths) == 1, "Only support single audio files for now!" - - prompt = "Generate a comprehensive text transcription of the "\ - "spoken content." - placeholders = "\n".join(f"<|audio_{i}|>" - for i in range(1, - len(wav_paths) + 1)) - prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" - - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = { - "prompt": prompt, - "multi_modal_data": { - "audio": [soundfile.read(wav_path) for wav_path in wav_paths] - } - } - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=200, - ) - - outputs = llm.generate( - generate_args, - sampling_params=sampling_params, - lora_request=[LoRARequest("speech_adapter", 3, args.speech_lora_path)] - if activate_lora_request else None) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -def main_with_lora_speech_batch(args: dict, - activate_lora_request=True) -> None: - """ - Main function for the offline inference script. - """ - wav_paths = [args.wav_path, args.wav_path] - - llm = LLM(model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - limit_mm_per_prompt={"audio": len(wav_paths)}, - max_loras=5) - - # assert len(wav_paths) == 1, "Only support single audio files for now!" - - prompt = "Based on the attached audio, generate a comprehensive text "\ - "transcription of the spoken content." - placeholders = "\n".join(f"<|audio_{i}|>" - for i in range(1, - len(wav_paths) + 1)) - prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" - - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = [ - { - "prompt": prompt, - "multi_modal_data": { - "audio": [soundfile.read(wav_path) for wav_path in wav_paths] - } - }, - { - "prompt": prompt, - "multi_modal_data": { - "audio": [soundfile.read(wav_path) for wav_path in wav_paths] - } - }, - ] - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=1200, - ) - - outputs = llm.generate( - generate_args, - sampling_params=sampling_params, - lora_request=LoRARequest("speech_adapter", 3, args.speech_lora_path) - if activate_lora_request else None) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -def main_with_lora_vision(args: dict, activate_lora_request=True) -> None: - """ - Main function for the offline inference script. - """ - image_urls = [args.image_url] - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - max_loras=5, - # max_model_len=4096, - # max_num_seqs=2, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - # prompt = "what's the traffic sign in the image" - prompt = "What is shown in this image?" - - placeholders = "\n".join(f"<|image_{i}|>" - for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" - - image_data = [fetch_image(url) for url in image_urls] - - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = { - "prompt": prompt, - "multi_modal_data": { - "image": image_data, - }, - } - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=1200, - ) - - outputs = llm.generate( - generate_args, - sampling_params=sampling_params, - lora_request=[LoRARequest("vision_adapter", 3, args.vision_lora_path)] - if activate_lora_request else None) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -def main_with_lora_vision_batch(args: dict, - activate_lora_request=True) -> None: - """ - Main function for the offline inference script. - """ - image_urls = [ - args.image_url, - "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg" - ] - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - max_loras=5, - - # max_model_len=4096, - # max_num_seqs=2, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - # prompt = "what's the traffic sign in the image" - prompt = "What is shown in this image?" - - placeholders = "\n".join(f"<|image_{i}|>" - for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|user|>\n{placeholders}\n{prompt}<|end|>\n<|assistant|>\n" - - # image_data=[fetch_image(url) for url in image_urls] - - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = [ - { - "prompt": prompt, - "multi_modal_data": { - "image": [ - fetch_image(url) for url in [ - "https://www.ilankelman.org/stopsigns/australia.jpg", - "https://alinasayre.com/wp-content/uploads/2013/10/"\ - "d67cd-dsc01646.jpg" - ] - ], - }, - }, - { - "prompt": prompt, - "multi_modal_data": { - "image": [fetch_image(url) for url in image_urls], - }, - }, - ] - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=1200, - ) - - outputs = llm.generate( - generate_args, - sampling_params=sampling_params, - lora_request=LoRARequest("vision_adapter", 3, args.vision_lora_path) - if activate_lora_request else None) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -def main_with_lora_vision_speech(args: dict, - activate_lora_request=True) -> None: - """ - Main function for the offline inference script. - """ - image_urls = [args.image_url] - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - max_loras=5, - - # max_model_len=4096, - # max_num_seqs=5, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - - prompt = "" - - placeholders = "\n".join(f"<|image_{i}|>" - for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}<|end|>"\ - "\n<|assistant|>\n" - - image_data = [fetch_image(url) for url in image_urls] - - wav_paths = [ - "/scratch/turing_westus3_prm_data/users/congcongchen/MoE_2/hf-models"\ - "/phio/examples/what_is_the_traffic_sign_in_the_image.wav" - ] - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = { - "prompt": prompt, - "multi_modal_data": { - "image": image_data, - "audio": [soundfile.read(wav_path) for wav_path in wav_paths], - }, - } - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=1200, - ) - - outputs = llm.generate( - generate_args, - sampling_params=sampling_params, - lora_request=[LoRARequest("vision_adapter", 3, args.vision_lora_path)] - if activate_lora_request else None) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -def main_with_lora_vision_speech_batch(args: dict, - activate_lora_request=True) -> None: - """ - Main function for the offline inference script. - """ - image_urls = [ - args.image_url, - "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg" - ] - wav_paths = [args.wav_path] - llm = LLM( - model=args.model_path, - trust_remote_code=True, - enable_lora=activate_lora_request, - enforce_eager=True, - max_lora_rank=512, - lora_extra_vocab_size=0, - max_loras=5, - - # max_model_len=40960, - # max_num_seqs=5, - limit_mm_per_prompt={ - "image": len(image_urls), - "audio": len(wav_paths) - }, - ) - - prompt = "try your best to answer the question" - - placeholders = "\n".join(f"<|image_{i}|>" - for i, _ in enumerate(image_urls, start=1)) - prompt = f"<|user|>\n{placeholders}\n<|audio_1|>\n{prompt}"\ - "<|end|>\n<|assistant|>\n" - - # image_data=[fetch_image(url) for url in image_urls] - - # NOTE: soundfile.read will return the audio feature and the sampling rate - generate_args = [ - { - "prompt": prompt, - "multi_modal_data": { - "image": [fetch_image(url) for url in image_urls], - "audio": [soundfile.read(wav_path) for wav_path in wav_paths], - }, - }, - { - "prompt": prompt, - "multi_modal_data": { - "image": [ - fetch_image(url) for url in [ - "https://alinasayre.com/wp-content/uploads/"\ - "2013/10/d67cd-dsc01646.jpg", - "https://alinasayre.com/wp-content/uploads/"\ - "2012/01/c3a7c-dsc01668.jpg" - ] - ], - "audio": [soundfile.read(wav_path) for wav_path in wav_paths], - }, - }, - ] - # NOTE: you should use the following settings to ensure parity in HF - # generate_ids = model.generate( - # **inputs, - # top_p=1, - # max_new_tokens=1200, - # temperature=0, - # use_cache=False, - # min_p=0, - # top_k=-1, - # ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=1200, - ) - - outputs = llm.generate( - generate_args, - sampling_params=sampling_params, - lora_request=LoRARequest("vision_adapter", 3, args.vision_lora_path) - if activate_lora_request else None) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Generated text: {generated_text!r}\n\n") - - -if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="Demo on using vLLM for offline inference with " - "vision language models that support multi-image input") - parser.add_argument( - "--model-path", - "-p", - type=str, - default= - "/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm", - help="Path to the (HuggingFace) model checkpoint.", - ) - - parser.add_argument( - "--vision-lora-path", - "-v", - type=str, - default= - "/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/vision-lora", - help="Path to the (HuggingFace) vision lora model checkpoint.", - ) - - parser.add_argument( - "--speech-lora-path", - "-s", - type=str, - default= - "/scratch/turing_westus3_prm_data/users/congcongchen/phi4-mini-mm/speech-lora", - help="Path to the (HuggingFace) speech lora model checkpoint.", - ) - - parser.add_argument( - "--wav-path", - "-w", - type=str, - default= - "/scratch/turing_westus3_prm_data/users/congcongchen/30s_test_6.wav", - help="Path to the audio file.", - ) - - parser.add_argument( - "--image-url", - "-i", - type=str, - default= - "https://alinasayre.com/wp-content/uploads/2013/10/d67cd-dsc01646.jpg", - ) - - parser.add_argument( - "--test-type", - "-t", - type=str, - default="speech_language_with_lora", - ) - - args = parser.parse_args() - ##### Language Only ##### - test_type = args.test_type - if test_type == "language_only": - main_pure_text(args) - ##### Speech + Language ##### - elif test_type == "speech_language_with_lora": - main_with_lora_speech(args) - elif test_type == "speech_language_with_lora_batch": - main_with_lora_speech_batch(args) - elif test_type == "speech_language_without_lora": - main_with_lora_speech(args, activate_lora_request=False) - ##### Vision + Language ##### - elif test_type == "vision_language_with_lora": - main_with_lora_vision(args) - elif test_type == "vision_language_with_lora_batch": - main_with_lora_vision_batch(args) - elif test_type == "vision_language_without_lora": - main_with_lora_vision(args, activate_lora_request=False) - ##### Vision + Speech + Language ##### - elif test_type == "vision_speech_language_with_lora": - main_with_lora_vision_speech(args) - elif test_type == "vision_speech_language_with_lora_batch": - main_with_lora_vision_speech_batch(args) - elif test_type == "vision_speech_language_without_lora": - main_with_lora_vision_speech(args, activate_lora_request=False) From 72ddbb4d839fcda082e0a67cb7c7ecb68f766f80 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 4 Mar 2025 15:51:58 -0800 Subject: [PATCH 26/27] Print errors instead of throwing the RuntimeError since CI environment does not have flash_attn installed Signed-off-by: Congcong Chen --- vllm/model_executor/models/vision_siglip_navit.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py index cd70d660f209..599f1c26917d 100644 --- a/vllm/model_executor/models/vision_siglip_navit.py +++ b/vllm/model_executor/models/vision_siglip_navit.py @@ -842,10 +842,11 @@ def __init__(self, *args, **kwargs): # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend != _Backend.FLASH_ATTN: - raise RuntimeError( - "Phi-4-multimodal-instruct model does not support"\ - " {self.attn_backend} backend now." - ) + # Only print out errors for now to make ci/pr/basic-models-test + # happy since the testing environment does not have flash_attn + # installed. + logger.error("Phi-4-multimodal-instruct model does not support "\ + "%s backend now.", self.attn_backend) def forward( self, From 77f4edcd9aad59c32360a3546b943f2f92a9b217 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 4 Mar 2025 15:58:43 -0800 Subject: [PATCH 27/27] update attn_backend detection Signed-off-by: Roger Wang --- vllm/model_executor/models/vision_siglip_navit.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py index 599f1c26917d..3a9597a845ff 100644 --- a/vllm/model_executor/models/vision_siglip_navit.py +++ b/vllm/model_executor/models/vision_siglip_navit.py @@ -839,15 +839,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_causal = False # Hack to make sure we don't use a causal mask - # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) - if self.attn_backend != _Backend.FLASH_ATTN: - # Only print out errors for now to make ci/pr/basic-models-test - # happy since the testing environment does not have flash_attn - # installed. - logger.error("Phi-4-multimodal-instruct model does not support "\ - "%s backend now.", self.attn_backend) - def forward( self, hidden_states: torch.Tensor, @@ -1960,6 +1951,11 @@ def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): "patch_size": 14, } + # Detect attention implementation. + attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if attn_backend != _Backend.FLASH_ATTN: + _flash_attn_2_enabled = False + model_config = SiglipVisionConfig( **siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled,