diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 66a7554da846..14a59953ef48 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -262,6 +262,255 @@ def get_mm_max_tokens_per_item( Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP. ::: +:::: + +::::{tab-item} Non-consecutive feature tokens: Fuyu +:sync: fuyu + +Looking at the code of HF's `FuyuForCausalLM`: + +```python +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322 +if image_patches is not None and past_key_values is None: + patch_embeddings = [ + self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)) + .squeeze(0) + .to(inputs_embeds.device) + for patch in image_patches + ] + inputs_embeds = self.gather_continuous_embeddings( + word_embeddings=inputs_embeds, + continuous_embeddings=patch_embeddings, + image_patch_input_indices=image_patches_indices, + ) +``` + +The number of placeholder feature tokens for the `i`th item in the batch is `patch_embeddings[i].shape[0]`, +which is the same as `image_patches[i].shape[0]`, i.e. `num_total_patches`. + +Unlike LLaVA, Fuyu does not define the number of patches inside the modeling file. Where can we get more information? +Considering that the model input comes from the output of `FuyuProcessor`, let's **look at the preprocessing files**. + +The image outputs are obtained by calling `FuyuImageProcessor.preprocess` and then +`FuyuImageProcessor.preprocess_with_tokenizer_info` inside `FuyuProcessor`. + +In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`, +returning the dimensions after resizing (but before padding) as metadata. 
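+For example, assuming the default target `size` of height 1080 and width 1920: an image of height 480 and
+width 640 is smaller than the target, so it is left unresized and only padded, with `(480, 640)` recorded as
+its unpadded size, whereas an image of height 1440 and width 2560 is first scaled down by a factor of 0.75
+to 1080 by 1920 and needs no padding.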
+ +```python +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544 +image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"]) +batch_images = image_encoding["images"] +image_unpadded_heights = image_encoding["image_unpadded_heights"] +image_unpadded_widths = image_encoding["image_unpadded_widths"] + +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L +if do_resize: + batch_images = [ + [self.resize(image, size=size, input_data_format=input_data_format) for image in images] + for images in batch_images + ] + +image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] +image_unpadded_heights = [[image_size[0]] for image_size in image_sizes] +image_unpadded_widths = [[image_size[1]] for image_size in image_sizes] + +if do_pad: + batch_images = [ + [ + self.pad_image( + image, + size=size, + mode=padding_mode, + constant_values=padding_value, + input_data_format=input_data_format, + ) + for image in images + ] + for images in batch_images + ] +``` + +In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata: + +```python +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425 +model_image_input = self.image_processor.preprocess_with_tokenizer_info( + image_input=tensor_batch_images, + image_present=image_present, + image_unpadded_h=image_unpadded_heights, + image_unpadded_w=image_unpadded_widths, + image_placeholder_id=image_placeholder_id, + image_newline_id=image_newline_id, + variable_sized=True, +) + +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658 +image_height, image_width = image.shape[1], image.shape[2] +if variable_sized: # variable_sized=True + new_h = min( + image_height, + math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height, + ) + new_w = min( + image_width, + math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width, + ) + image = image[:, :new_h, :new_w] + image_height, image_width = new_h, new_w + +num_patches = self.get_num_patches(image_height=image_height, image_width=image_width) +tensor_of_image_ids = torch.full( + [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device +) +patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0) +assert num_patches == patches.shape[0] +``` + +The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`: + +```python +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562 +patch_size = patch_size if patch_size is not None else self.patch_size +patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] + +if image_height % patch_height != 0: + raise ValueError(f"{image_height=} must be divisible by {patch_height}") +if image_width % patch_width != 0: + raise ValueError(f"{image_width=} must be divisible by {patch_width}") + +num_patches_per_dim_h = image_height // patch_height +num_patches_per_dim_w = image_width // patch_width +num_patches = num_patches_per_dim_h * num_patches_per_dim_w +``` + +We can calculate this in vLLM using this code: + +```python +def get_num_image_patches( + self, + *, + image_width: int, + 
image_height: int,
+) -> int:
+    image_processor = self.get_image_processor()
+    target_width = image_processor.size["width"]
+    target_height = image_processor.size["height"]
+    patch_width = image_processor.patch_size["width"]
+    patch_height = image_processor.patch_size["height"]
+
+    if not (image_width <= target_width and image_height <= target_height):
+        height_scale_factor = target_height / image_height
+        width_scale_factor = target_width / image_width
+        optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+
+        image_height = int(image_height * optimal_scale_factor)
+        image_width = int(image_width * optimal_scale_factor)
+
+    ncols = math.ceil(image_width / patch_width)
+    nrows = math.ceil(image_height / patch_height)
+    return ncols * nrows
+```
+
+These image patches correspond to placeholder tokens (`|SPEAKER|`). However, the processor also
+inserts newline tokens (`|NEWLINE|`) as shown here:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L654-L670
+tensor_of_image_ids = torch.full(
+    [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
+)
+patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
+assert num_patches == patches.shape[0]
+
+if variable_sized:
+    # Now terminate each line with |NEWLINE|.
+    tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
+    newline_ids = torch.full(
+        [tensor_of_image_ids.shape[0], 1],
+        image_newline_id,
+        dtype=torch.int32,
+        device=image_input.device,
+    )
+    tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
+    tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
+```
+
+So, the layout of tokens for an image is:
+
+```
+|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
+|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
+...
+|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
+```
+
+This makes the placeholder tokens non-consecutive in the prompt.
+Since vLLM requires the feature tokens to be consecutive, **we also treat the newline tokens as feature tokens**.
+
+So overall, the total number of feature tokens is:
+
+```python
+def get_num_image_tokens(
+    self,
+    *,
+    image_width: int,
+    image_height: int,
+) -> int:
+    image_processor = self.get_image_processor()
+    target_width = image_processor.size["width"]
+    target_height = image_processor.size["height"]
+    patch_width = image_processor.patch_size["width"]
+    patch_height = image_processor.patch_size["height"]
+
+    if not (image_width <= target_width and image_height <= target_height):
+        height_scale_factor = target_height / image_height
+        width_scale_factor = target_width / image_width
+        optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+
+        image_height = int(image_height * optimal_scale_factor)
+        image_width = int(image_width * optimal_scale_factor)
+
+    ncols = math.ceil(image_width / patch_width)
+    nrows = math.ceil(image_height / patch_height)
+    return (ncols + 1) * nrows
+```
+
+To calculate the maximum number of image tokens, recall that input images are first resized
+to fit within `image_processor.size`. The maximum possible dimensions of the image before
+being converted into patches are therefore equal to `image_processor.size`.
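+
+As a quick sanity check, here is a rough sketch of that maximum, assuming the default Fuyu processor
+configuration (a 1080 by 1920 target size and 30 by 30 patches); these hard-coded values are assumptions
+rather than values read from the model:
+
+```python
+import math
+
+# Assumed defaults of the Fuyu image processor (not queried from a checkpoint)
+target_width, target_height = 1920, 1080  # image_processor.size
+patch_width = patch_height = 30           # image_processor.patch_size
+
+ncols = math.ceil(target_width / patch_width)    # 64 patches per row
+nrows = math.ceil(target_height / patch_height)  # 36 rows
+
+# One |NEWLINE| token is appended per row, hence the extra column
+print((ncols + 1) * nrows)  # 2340 feature tokens for the largest possible image
+```
+
+The corresponding methods in vLLM are then: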
+ +```python +def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + return ImageSize(width=image_processor.size["width"], + height=image_processor.size["height"]) + +def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) +``` + +And thus, we can override the method as: + +```python +def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], +) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} +``` + +:::{note} +Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) returns `ncols` and `nrows` directly instead of the total token count. +This is because `ncols` and `nrows` are used to specify the layout of the feature tokens (as shown in Step 4 of this guide). +::: + :::: ::::: @@ -282,7 +531,8 @@ on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max ::::{tab-set} :::{tab-item} Basic example: LLaVA :sync: llava -Making use of the `get_image_size_with_most_features` method implemented in the previous section: + +Making use of the `get_image_size_with_most_features` method implemented in Step 2: ```python def get_dummy_processor_inputs( @@ -312,6 +562,39 @@ def get_dummy_processor_inputs( ``` ::: + +:::{tab-item} No input placeholders: Fuyu +:sync: fuyu + +Fuyu does not expect image placeholders in the inputs to HF processor, so +the dummy prompt text is empty regardless of the number of images. +Otherwise, the logic of this method is very similar to LLaVA: + +```python +def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], +) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) +``` + +::: + :::: ## 4. Specify processing details @@ -325,40 +608,28 @@ to fill in the missing details about HF processing. ### Multi-modal fields -Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to +Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items. 
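+
+For example, {func}`MultiModalFieldConfig.batched` (used in the examples below) declares that the first
+dimension of the tensor indexes the multi-modal items of the given modality, i.e. the data belonging to
+the `i`th image is `tensor[i]`.
+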
:::::{tab-set} ::::{tab-item} Basic example: LLaVA :sync: llava -Looking at the model's `forward` method: +The output of `CLIPImageProcessor` is a simple tensor with shape +`(num_images, num_channels, image_height, image_width)`: ```python -# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L387-L404 -def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, -) -> Union[Tuple, LlavaCausalLMOutputWithPast]: -``` - -The only related keyword argument is `pixel_values` which directly corresponds to input images. -The shape of `pixel_values` is `(N, C, H, W)` where `N` is the number of images. -So, we override the method as follows: +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345 +images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images +] + +data = {"pixel_values": images} +return BatchFeature(data=data, tensor_type=return_tensors) +``` + +So, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows: ```python def _get_mm_fields_config( @@ -377,11 +648,83 @@ pre-computed image embeddings, which can be passed to be model via the `image_em ::: :::: + +::::{tab-item} With postprocessing: Fuyu +:sync: fuyu + +The `image_patches` output of `FuyuImageProcessor.preprocess_with_tokenizer_info` concatenates +the patches from each image belonging to an item in the batch: + +```python +# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L673-L679 + image_input_ids.append(tensor_of_image_ids) + image_patches.append(patches) + else: + image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device)) + +batch_image_input_ids.append(image_input_ids) +batch_image_patches.append(image_patches) +``` + +The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore +`(1, num_images, num_patches, patch_width * patch_height * num_channels)`. 
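+With the default 30 by 30 patches and 3-channel RGB input, for instance, the last dimension works out to
+`30 * 30 * 3 = 2700` values per patch.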
+ +In order to support the use of {func}`MultiModalFieldConfig.batched` like in LLaVA, +we remove the extra batch dimension by overriding {meth}`BaseMultiModalProcessor._call_hf_processor`: + +```python +def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], +) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + image_patches = processed_outputs.get("image_patches") + if image_patches is not None: + images = mm_data["images"] + assert isinstance(images, list) + + # Original output: (1, num_images, Pn, Px * Py * C) + # New output: (num_images, Pn, Px * Py * C) + assert (isinstance(image_patches, list) + and len(image_patches) == 1) + assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + + processed_outputs["image_patches"] = image_patches[0] + + return processed_outputs +``` + +:::{note} +Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling +for text-only inputs to prevent unnecessary warnings from HF processor. +::: + +This lets us override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` as follows: + +```python +def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], +) -> Mapping[str, MultiModalFieldConfig]: + return dict(image_patches=MultiModalFieldConfig.batched("image")) +``` + +:::: + ::::: ### Prompt replacements -Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to +Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances. Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace @@ -402,7 +745,7 @@ for sample in text: ``` It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). -Based on this, we override the method as follows: +Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows: ```python def _get_prompt_replacements( @@ -435,6 +778,159 @@ def _get_prompt_replacements( ``` ::: + +:::{tab-item} Handling additional tokens: Fuyu +:sync: fuyu + +Recall the layout of feature tokens from Step 2: + +``` +|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| +|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| +... 
+|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
+```
+
+We define a helper function to return `ncols` and `nrows` directly:
+
+```python
+def get_image_feature_grid_size(
+    self,
+    *,
+    image_width: int,
+    image_height: int,
+) -> tuple[int, int]:
+    image_processor = self.get_image_processor()
+    target_width = image_processor.size["width"]
+    target_height = image_processor.size["height"]
+    patch_width = image_processor.patch_size["width"]
+    patch_height = image_processor.patch_size["height"]
+
+    if not (image_width <= target_width and image_height <= target_height):
+        height_scale_factor = target_height / image_height
+        width_scale_factor = target_width / image_width
+        optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+
+        image_height = int(image_height * optimal_scale_factor)
+        image_width = int(image_width * optimal_scale_factor)
+
+    ncols = math.ceil(image_width / patch_width)
+    nrows = math.ceil(image_height / patch_height)
+    return ncols, nrows
+```
+
+Based on this, we can initially define our replacement tokens as:
+
+```python
+def get_replacement(item_idx: int):
+    images = mm_items.get_items("image", ImageProcessorItems)
+    image_size = images.get_image_size(item_idx)
+
+    ncols, nrows = self.info.get_image_feature_grid_size(
+        image_width=image_size.width,
+        image_height=image_size.height,
+    )
+
+    # `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|`
+    # `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|`
+    return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
+```
+
+However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called,
+a BOS token (`<s>`) is also added to the prompt:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
+model_image_input = self.image_processor.preprocess_with_tokenizer_info(
+    image_input=tensor_batch_images,
+    image_present=image_present,
+    image_unpadded_h=image_unpadded_heights,
+    image_unpadded_w=image_unpadded_widths,
+    image_placeholder_id=image_placeholder_id,
+    image_newline_id=image_newline_id,
+    variable_sized=True,
+)
+prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
+    tokenizer=self.tokenizer,
+    prompts=prompts,
+    scale_factors=scale_factors,
+    max_tokens_to_generate=self.max_tokens_to_generate,
+    max_position_embeddings=self.max_position_embeddings,
+    add_BOS=True,
+    add_beginning_of_answer_token=True,
+)
+```
+
+To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails`
+with different `full` and `feature` attributes:
+
+```python
+hf_config = self.info.get_hf_config()
+bos_token_id = hf_config.bos_token_id  # `<s>`
+assert isinstance(bos_token_id, int)
+
+def get_replacement_fuyu(item_idx: int):
+    images = mm_items.get_items("image", ImageProcessorItems)
+    image_size = images.get_image_size(item_idx)
+
+    ncols, nrows = self.info.get_image_feature_grid_size(
+        image_width=image_size.width,
+        image_height=image_size.height,
+    )
+    image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
+                    [_NEWLINE_TOKEN_ID]) * nrows
+
+    return PromptReplacementDetails(
+        full=image_tokens + [bos_token_id],
+        features=image_tokens,
+    )
+```
+
+Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
+we can search for it to conduct the replacement at the start of the string:
+
+```python
+def _get_prompt_replacements(
+    self,
+    mm_items: MultiModalDataItems,
+    hf_processor_mm_kwargs: Mapping[str, object],
+    out_mm_kwargs: MultiModalKwargs,
+) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + bos_token_id = hf_config.bos_token_id + assert isinstance(bos_token_id, int) + + tokenizer = self.info.get_tokenizer() + eot_token_id = tokenizer.bos_token_id + assert isinstance(eot_token_id, int) + + def get_replacement_fuyu(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + + [_NEWLINE_TOKEN_ID]) * nrows + + return PromptReplacementDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ) + + return [ + PromptReplacement( + modality="image", + target=[eot_token_id], + replacement=get_replacement_fuyu, + ) + ] +``` + +::: + :::: ## 5. Register processor-related classes diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 50b5ef35d2cd..4e0ee6364f86 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -104,6 +104,8 @@ def get_image_feature_grid_size( image_processor = self.get_image_processor() target_width = image_processor.size["width"] target_height = image_processor.size["height"] + patch_width = image_processor.patch_size["width"] + patch_height = image_processor.patch_size["height"] if not (image_width <= target_width and image_height <= target_height): height_scale_factor = target_height / image_height @@ -113,8 +115,8 @@ def get_image_feature_grid_size( image_height = int(image_height * optimal_scale_factor) image_width = int(image_width * optimal_scale_factor) - ncols = math.ceil(image_width / 30) - nrows = math.ceil(image_height / 30) + ncols = math.ceil(image_width / patch_width) + nrows = math.ceil(image_height / patch_height) return ncols, nrows def get_image_size_with_most_features(self) -> ImageSize: