From ec5a7779a7430cd8bde4ca620b3b2b7b85da60a2 Mon Sep 17 00:00:00 2001 From: Mert Atay <73117970+CatcherInThePy@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:43:58 +0300 Subject: [PATCH 1/3] Fix inference error in LlavaNextForConditionalGeneration when no image data is provided This fix addresses an issue where running inference without image data causes an error on line 874, as `image_features` is `None`. The update restructures the checks for `image_features` and `legacy_processing` to properly handle cases where no image data is present, while preserving the integrity of the remaining processing steps. --- .../models/llava_next/modeling_llava_next.py | 128 +++++++++--------- 1 file changed, 63 insertions(+), 65 deletions(-) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index de5a030b61c1..b5f74f58d3ab 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -862,73 +862,74 @@ def forward( vision_feature_select_strategy=vision_feature_select_strategy, image_newline=self.image_newline, ) - - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) - if input_ids.shape[1] != 1: - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, + + if image_features is not None: + if legacy_processing: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + if input_ids.shape[1] != 1: + inputs_embeds = inputs_embeds.to(image_features.dtype) + inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( + image_features, + feature_lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids, + labels=labels, + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + # TODO: @raushan retain only the new behavior after v4.47 + else: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( attention_mask=attention_mask, @@ -1007,6 +1008,3 @@ def prepare_inputs_for_generation( model_inputs["image_sizes"] = image_sizes return model_inputs - - -__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel"] From 196112d375586cc69ec443ca0aa2bed3c48c8387 Mon Sep 17 00:00:00 2001 From: Mert Atay <73117970+CatcherInThePy@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:55:38 +0300 Subject: [PATCH 2/3] Added missing __all__ variable. Removed whitespace --- src/transformers/models/llava_next/modeling_llava_next.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index b5f74f58d3ab..8974dc6173c9 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -862,7 +862,6 @@ def forward( vision_feature_select_strategy=vision_feature_select_strategy, image_newline=self.image_newline, ) - if image_features is not None: if legacy_processing: logger.warning_once( @@ -913,7 +912,7 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - + # TODO: @raushan retain only the new behavior after v4.47 else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() @@ -1008,3 +1007,6 @@ def prepare_inputs_for_generation( model_inputs["image_sizes"] = image_sizes return model_inputs + + +__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel"] From 320ead63b44dd3bfb1f4d58fa65c8c07de8bae60 Mon Sep 17 00:00:00 2001 From: Mert Atay <73117970+CatcherInThePy@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:03:46 +0300 Subject: [PATCH 3/3] Code quality fixes Removed empty lines and whitespaces --- src/transformers/models/llava_next/modeling_llava_next.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 8974dc6173c9..9b62e415676c 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -912,7 +912,6 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - # TODO: @raushan retain only the new behavior after v4.47 else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item()