diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 77cfc57b42bc..cd015f83ae97 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -461,18 +461,9 @@ def forward(
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
-        legacy_processing = False
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
-            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-            legacy_processing = (
-                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
@@ -480,66 +471,14 @@ def forward(
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
 
-        if legacy_processing:
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
-            )
-            # prefill stage vs decoding stage (legacy behavior copied)
-            if input_ids.shape[1] != 1:
-                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
-                )
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-            else:
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                # Get the target length
-                target_length = input_ids.shape[1]
-                past_length = first_layer_past_key_value.shape[-1]
-
-                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], past_length),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-
-                # Filter out only the tokens that can be un-attended, this can happen
-                # if one uses Llava + Fused modules where the cache on the
-                # first iteration is already big enough, or if one passes custom cache
-                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                new_batch_index = batch_index[valid_indices]
-                new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                # Zero-out the places where we don't need to attend
-                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
-
-        # TODO: @raushan retain only the new behavior after v4.47
-        elif image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0] * image_features.shape[1]
-
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-            special_image_mask = (
-                (input_ids == self.config.image_token_index)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index 8dca2e16dff6..098b6fb379b6 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -154,27 +154,17 @@ def __call__(
         # try to expand inputs in processing if we have the necessary parts
         prompt_strings = text
         if image_inputs.get("pixel_values") is not None:
-            if self.patch_size is not None and self.vision_feature_select_strategy is not None:
-                # Replace the image token with the expanded image token sequence
-                pixel_values = image_inputs["pixel_values"]
-                height, width = get_image_size(to_numpy_array(pixel_values[0]))
-                num_image_tokens = (height // self.patch_size) * (
-                    width // self.patch_size
-                ) + self.num_additional_image_tokens
-                if self.vision_feature_select_strategy == "default":
-                    num_image_tokens -= self.num_additional_image_tokens
-
-                prompt_strings = []
-                for sample in text:
-                    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
-                    prompt_strings.append(sample)
-            else:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
-                )
+            # Replace the image token with the expanded image token sequence
+            pixel_values = image_inputs["pixel_values"]
+            height, width = get_image_size(to_numpy_array(pixel_values[0]))
+            num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
+            if self.vision_feature_select_strategy == "default":
+                num_image_tokens -= 1
+
+            prompt_strings = []
+            for sample in text:
+                sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
+                prompt_strings.append(sample)
 
         text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
         return BatchFeature(data={**text_inputs, **image_inputs})
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index de5a030b61c1..1a1223e9c2cf 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -689,7 +689,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
                 image_feature = torch.cat(
                     (
                         image_feature,
-                        image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
+                        image_newline[:, None, None]
+                        .expand(*image_feature.shape[:-1], 1)
+                        .to(image_feature.device, image_feature.dtype),
                     ),
                     dim=-1,
                 )
@@ -835,18 +837,9 @@ def forward(
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
-        legacy_processing = False
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
-            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-            legacy_processing = (
-                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-        image_features = None
         if pixel_values is not None and pixel_values.size(0) > 0:
             image_features = self.get_image_features(
                 pixel_values,
@@ -863,70 +856,14 @@ def forward(
                 image_newline=self.image_newline,
             )
 
-        if legacy_processing:
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
- ) - if input_ids.shape[1] != 1: - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 462006bc99ea..cc293a416b38 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -149,30 +149,19 @@ def __call__( prompt_strings = text if image_inputs: - if self.patch_size is None or self.vision_feature_select_strategy is None: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. 
" - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." - ) - else: - image_sizes = iter(image_inputs["image_sizes"]) - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for sample in text: - while self.image_token in sample: - image_size = next(image_sizes) - if not isinstance(image_size, (list, tuple)): - # cast to list to avoid numerical precision errors when calculating unpadding - image_size = image_size.tolist() - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= self.num_additional_image_tokens - sample = sample.replace(self.image_token, "" * num_image_tokens, 1) - prompt_strings.append(sample) - prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "" * num_image_tokens, 1) + prompt_strings.append(sample) + prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 7cd7e18abaf3..f6a66a7a9b11 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -722,7 +722,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select image_feature = torch.cat( ( image_feature, - image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype), + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device, image_feature.dtype), ), dim=-1, ) @@ -909,25 +911,9 @@ def forward( "and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) - 
legacy_processing = inputs_not_expanded or pixels_present - - image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -942,7 +928,17 @@ def forward( image_newline=self.image_newline, ) - video_features = video_feature_lens = None + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: video_features = self.get_video_features( pixel_values_videos, @@ -954,95 +950,16 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
- ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if image_features is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_features is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features 
{n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 94c1432a41b1..5c04c96b8877 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -431,25 +431,9 @@ def forward( "and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -464,7 +448,17 @@ def forward( image_newline=self.image_newline, ) - video_features = video_feature_lens = None + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: video_features = self.get_video_features( pixel_values_videos, @@ -476,95 +470,16 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. 
" - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if image_features is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_features is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + 
raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index 857ee28a0800..f3b2b78f7aa6 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -173,48 +173,33 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - if self.patch_size is None or self.vision_feature_select_strategy is None: - logger.warning_once( - "Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size`, `num_additional_image_tokens` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` " - "and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
- ) - else: - # images expand taking into account num_of_patches in each image - if image_inputs: - image_sizes = iter(image_inputs["image_sizes"]) - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for sample in text: - while self.image_token in sample: - image_size = next(image_sizes) - if not isinstance(image_size, (list, tuple)): - # cast to list to avoid numerical precision errors when calculating unpadding - image_size = image_size.tolist() - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= self.num_additional_image_tokens - sample = sample.replace(self.image_token, "" * num_image_tokens, 1) - prompt_strings.append(sample) - text = [sample.replace("", self.image_token) for sample in prompt_strings] - - # videos are easier, simply get frames and multiply - if videos_inputs: - one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(one_video[0]) - num_frames = one_video.shape[0] # frame dim is always after batch dim - - # no `self.num_additional_image_tokens` added because video always has a default feature selection strategy - num_image_tokens = (height // self.patch_size) * (width // self.patch_size) - num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer - prompt_strings = [] - for sample in text: - sample = sample.replace(self.video_token, self.video_token * num_video_tokens) - prompt_strings.append(sample) - text = prompt_strings + if image_inputs: + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "" * num_image_tokens, 1) + prompt_strings.append(sample) + text = [sample.replace("", self.image_token) for sample in prompt_strings] + + # videos are easier, simply get frames and multiply + if videos_inputs: + one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer + prompt_strings = [] + for sample in text: + sample = sample.replace(self.video_token, self.video_token * num_video_tokens) + prompt_strings.append(sample) + text = prompt_strings text_inputs = self.tokenizer( text, diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 802304e6b191..80dfaa2f0e74 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -538,127 +538,41 @@ def forward( "time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, 
then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and ( - pixel_values_images is not None or pixel_values_videos is not None - ) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = None if pixel_values_images is not None: image_features = self.get_image_features( pixel_values_images, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - video_features = None - num_frames = 0 if pixel_values_videos is not None: video_features, num_frames = self.get_video_features( pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - if input_ids.shape[1] != 1: - for features, frames in ((image_features, 1), (video_features, num_frames)): - if features is not None: - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - input_ids, - ) = self._merge_input_ids_with_visual_features( - features, - inputs_embeds, - input_ids, - attention_mask, - labels, - num_frames=frames, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if pixel_values_images is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if pixel_values_videos is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] * video_features.shape[1] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] * video_features.shape[1] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - 
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index 01a1a950346c..3f58675d047a 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -158,16 +158,8 @@ def __call__( raise ValueError("Invalid input text. Please provide a string, or a list of strings") prompt_strings = text - if encoded_images is not None and (self.patch_size is None or self.vision_feature_select_strategy is None): - logger.warning_once( - "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set " - "directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = " - "{{vision_feature_select_strategy}}`. Using processors without these attributes in the config is " - "deprecated and will throw an error in v4.50." - ) - # Replace the image/video tokens with the expanded token sequence - elif encoded_images is not None: + + if encoded_images is not None: if "pixel_values_images" in encoded_images.keys(): height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0])) num_frames = 1 diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index b4a6c6ae9bb9..0eb65b0fc722 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -455,80 +455,22 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - legacy_processing = ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - - image_features = None if pixel_values is not None: image_features = self.get_image_features( pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in VipLLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
- ) - # prefill stage vs decoding stage (legacy behavior copied) - if input_ids.shape[1] != 1: - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # in the case one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] * image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index b4a959a00d2a..e91b76f7d9f5 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -327,10 +327,7 @@ def test_small_model_integration_test(self): prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" image_file = "https://llava-vl.github.io/static/images/view.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") - - EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip - 
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device) output = model.generate(**inputs, max_new_tokens=20) EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip @@ -378,7 +375,7 @@ def test_small_model_integration_test_llama_batched(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to(torch_device) output = model.generate(**inputs, max_new_tokens=20) @@ -402,7 +399,9 @@ def test_small_model_integration_test_batch(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to( + torch_device + ) output = model.generate(**inputs, max_new_tokens=20) @@ -434,7 +433,9 @@ def test_small_model_integration_test_llama_batched_regression(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True).to( + torch_device + ) output = model.generate(**inputs, max_new_tokens=20) @@ -508,32 +509,18 @@ def test_llava_merge_inputs_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore model_id = "llava-hf/llava-1.5-7b-hf" model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) - # Simulate some user inputs - pixel_values = torch.randn( - (1, 3, 336, 336), - dtype=torch.float, - device=torch_device, - ) - input_ids = torch.tensor( - [ - [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], - ], - dtype=torch.long, - device=torch_device, - ) - attention_mask = torch.tensor( - [[0, 0, 1, 1, 1, 1, 1, 1, 1]], - dtype=torch.long, - device=torch_device, - ) + prompt = "USER: \nDescribe the imageASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) # Make sure that the loss is properly computed loss = model( - pixel_values=pixel_values, - input_ids=input_ids, - attention_mask=attention_mask, - labels=input_ids, + **inputs, + labels=inputs.input_ids.clone(), ).loss loss.backward() @@ -593,38 +580,6 @@ def test_generation_siglip_backbone(self): EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image 
shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. The cat" self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT) - @slow - @require_bitsandbytes - def test_expansion_in_processing(self): - model_id = "llava-hf/llava-1.5-7b-hf" - model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) - processor = AutoProcessor.from_pretrained(model_id) - - prompt = "USER: \nDescribe the image:\nASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - - # check processing with expansion of inputs - processor.vision_feature_select_strategy = "default" - processor.num_additional_image_tokens = 1 - processor.patch_size = 14 - inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) - self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593) - - # check processing without expansion of inputs (legacy behavior) - processor.vision_feature_select_strategy = None - processor.patch_size = None - processor.num_additional_image_tokens = None - inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) - self.assertTrue(inputs.input_ids.shape[-1] == 18) - - # generate exactly 20 tokens - output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) - output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) - - # check that both inputs are handled correctly and generate the same output - self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) - @slow @require_bitsandbytes def test_pixtral(self): diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index d3a66a16df9a..3e6c1a9a969f 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -50,7 +50,7 @@ def tearDown(self): shutil.rmtree(self.tmpdirname) def prepare_processor_dict(self): - return {"chat_template": "dummy_template"} + return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"} @unittest.skip( "Skip because the model has no processor kwargs except for chat template and" diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 14b0fb8cc07d..c90dbe056e51 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -396,8 +396,10 @@ def test_small_model_integration_test(self): ) original_input_ids = torch.load(filepath, map_location="cpu") # replace -200 by image_token_index (since we use token ID = 32000 for the image token) - original_input_ids[original_input_ids == -200] = model.config.image_token_index - assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist() + # remove image token indices because HF impl expands image tokens `image_seq_length` times + original_input_ids = original_input_ids[original_input_ids != -200] + observed_input_ids = inputs.input_ids[inputs.input_ids != model.config.image_token_index] + assert original_input_ids[0].tolist() == observed_input_ids[0].tolist() filepath = hf_hub_download( repo_id="nielsr/test-image", @@ -414,7 +416,7 @@ def test_small_model_integration_test(self): expected_slice = torch.tensor( [[-4.7695, -4.5664, -0.2788], [-10.6172, 
-10.8828, -2.5273], [-6.7383, -7.2422, -0.6694]], - dtype=torch.float32, + dtype=torch.float16, device=torch_device, ) assert torch.allclose(output.logits[0, :3, :3], expected_slice, atol=1e-3) @@ -518,11 +520,11 @@ def test_small_model_integration_test_batch_different_resolutions(self): expected_slice = torch.tensor( [[-0.1287, -0.1294, -0.1284], [-0.2744, -0.2698, -0.2671], [-0.1071, -0.1091, -0.1056]], - dtype=torch.float32, + dtype=torch.float16, device=torch_device, ) assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3) - assert torch.allclose(output.loss, torch.tensor(7.0206, device=torch_device), atol=1e-3) + assert torch.allclose(output.loss, torch.tensor(7.0206, dtype=torch.float16, device=torch_device), atol=1e-3) # verify generation output = model.generate(**inputs, max_new_tokens=50) @@ -601,80 +603,6 @@ def test_padding_side_when_merging_inputs(self): self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs) - @slow - @require_bitsandbytes - def test_expansion_in_processing_multiimage(self): - model_id = "llava-hf/llava-v1.6-mistral-7b-hf" - model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) - processor = AutoProcessor.from_pretrained(model_id) - - prompt = "USER: \nDescribe the similarity between the two images:\nASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - deer_image = Image.open( - requests.get( - "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e", - stream=True, - ).raw - ) - - # check processing with expansion of inputs - processor.vision_feature_select_strategy = "default" - processor.patch_size = 14 - processor.num_additional_image_tokens = 1 - inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( - torch_device, torch.float16 - ) - self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3969) - - # check processing without expansion of inputs (legacy behavior) - processor.vision_feature_select_strategy = None - processor.patch_size = None - processor.num_additional_image_tokens = None - inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( - torch_device, torch.float16 - ) - self.assertTrue(inputs.input_ids.shape[-1] == 23) - - # generate exactly 20 tokens - output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) - output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) - - # check that both inputs are handled correctly and generate the same output - self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) - - @slow - @require_bitsandbytes - def test_expansion_in_processing(self): - model_id = "llava-hf/llava-v1.6-mistral-7b-hf" - model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) - processor = AutoProcessor.from_pretrained(model_id) - - prompt = "USER: \nDescribe the image:\nASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - - # check processing with expansion of inputs - processor.vision_feature_select_strategy = "default" - processor.patch_size = 14 - processor.num_additional_image_tokens = 1 - inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) - 
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356) - - # check processing without expansion of inputs (legacy behavior) - processor.vision_feature_select_strategy = None - processor.patch_size = None - processor.num_additional_image_tokens = None - inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) - self.assertTrue(inputs.input_ids.shape[-1] == 17) - - # generate exactly 20 tokens - output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) - output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) - - # check that both inputs are handled correctly and generate the same output - self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) - @slow @require_bitsandbytes def test_small_model_integration_test_full_vision_state_selection(self): @@ -685,7 +613,7 @@ def test_small_model_integration_test_full_vision_state_selection(self): # test that changing `strategy` won't error out model.vision_feature_select_strategy = "full" - inputs = self.processor(self.prompt, self.image, return_tensors="pt") + inputs = self.processor(self.prompt, self.image, return_tensors="pt").to(model.device) # verify generation output = model.generate(**inputs, max_new_tokens=30) diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py index 45faa2452630..234e47911000 100644 --- a/tests/models/llava_next/test_processor_llava_next.py +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -27,7 +27,7 @@ if is_vision_available(): - from transformers import CLIPImageProcessor + from transformers import LlavaNextImageProcessor @require_vision @@ -37,7 +37,7 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - image_processor = CLIPImageProcessor() + image_processor = LlavaNextImageProcessor() tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b") processor_kwargs = self.prepare_processor_dict() processor = LlavaNextProcessor(image_processor, tokenizer, **processor_kwargs) @@ -50,7 +50,7 @@ def get_image_processor(self, **kwargs): return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor def prepare_processor_dict(self): - return {"chat_template": "dummy_template"} + return {"chat_template": "dummy_template", "patch_size": 3, "vision_feature_select_strategy": "default"} @unittest.skip( "Skip because the model has no processor kwargs except for chat template and" diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index c431f91bf510..b0234fef34e8 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -17,7 +17,6 @@ import unittest import numpy as np -import requests from huggingface_hub import hf_hub_download from transformers import ( @@ -543,107 +542,3 @@ def test_padding_side_when_merging_inputs(self): model(**inputs_batched, output_hidden_states=True) self.assertIn("Padding side is set to 'right' but the model is in inference mode. 
For correct", logs) - - @slow - @require_bitsandbytes - def test_expansion_in_processing(self): - model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" - model = LlavaNextVideoForConditionalGeneration.from_pretrained( - "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True - ) - processor = AutoProcessor.from_pretrained(model_id) - - # check processing with expansion of inputs - processor.vision_feature_select_strategy = "default" - processor.patch_size = 14 - processor.num_additional_image_tokens = 1 - inputs_expanded = processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device) - self.assertTrue(inputs_expanded.input_ids.shape[-1] == 1170) - - # check processing without expansion of inputs (legacy behavior) - processor.vision_feature_select_strategy = None - processor.patch_size = None - processor.num_additional_image_tokens = None - inputs = processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device) - self.assertTrue(inputs.input_ids.shape[-1] == 19) - - # generate exactly 20 tokens - output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) - output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) - - # check that both inputs are handled correctly and generate the same output - self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) - - @slow - @require_bitsandbytes - def test_expansion_in_processing_images(self): - model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" - model = LlavaNextVideoForConditionalGeneration.from_pretrained( - "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True - ) - processor = AutoProcessor.from_pretrained(model_id) - - # check processing with expansion of inputs - processor.vision_feature_select_strategy = "default" - processor.patch_size = 14 - processor.num_additional_image_tokens = 1 - inputs_expanded = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device) - self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2652) - - # check processing without expansion of inputs (legacy behavior) - processor.vision_feature_select_strategy = None - processor.patch_size = None - processor.num_additional_image_tokens = None - inputs = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device) - self.assertTrue(inputs.input_ids.shape[-1] == 19) - - # generate exactly 20 tokens - output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) - output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) - - # check that both inputs are handled correctly and generate the same output - self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) - - @slow - @require_bitsandbytes - def test_expansion_in_processing_multiimage(self): - model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" - model = LlavaNextVideoForConditionalGeneration.from_pretrained( - "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True - ) - processor = AutoProcessor.from_pretrained(model_id) - - prompt = "USER: \nDescribe the similarity between the two images:\nASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - deer_image = Image.open( - requests.get( - "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e", - stream=True, - ).raw - ) - - # check processing with expansion of inputs - processor.vision_feature_select_strategy = "default" - 
processor.patch_size = 14 - processor.num_additional_image_tokens = 1 - inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( - torch_device, torch.float16 - ) - self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3968) - - # check processing without expansion of inputs (legacy behavior) - processor.vision_feature_select_strategy = None - processor.patch_size = None - processor.num_additional_image_tokens = None - inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( - torch_device, torch.float16 - ) - self.assertTrue(inputs.input_ids.shape[-1] == 22) - - # generate exactly 20 tokens - output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) - output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) - - # check that both inputs are handled correctly and generate the same output - self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 14b079665ab6..4a7bcb45b010 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -127,7 +127,6 @@ def __init__( self.num_image_tokens = (vision_config["image_size"] // vision_config["patch_size"]) ** 2 self.num_video_tokens = (self.num_image_tokens + 1) * self.num_frames self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens - self.encoder_seq_length = self.seq_length def get_config(self): return VideoLlavaConfig( @@ -185,22 +184,6 @@ def prepare_config_and_inputs_for_common(self): } return config, inputs_dict - def prepare_config_and_inputs_for_batched_test(self): - config_and_inputs = self.prepare_config_and_inputs() - config, _, pixel_values_videos = config_and_inputs - input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 - attention_mask = input_ids.ne(1).to(torch_device) - - # make sure no other special tokens are set - input_ids[(input_ids == 0) | (input_ids == 1)] = 3 - input_ids[:, 0] = config.video_token_index - inputs_dict = { - "pixel_values_videos": pixel_values_videos, - "input_ids": input_ids, - "attention_mask": attention_mask, - } - return config, inputs_dict - @require_torch class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -339,7 +322,7 @@ def recursive_check(batched_object, single_row_object, model_name, key): ), ) - config, batched_input = self.model_tester.prepare_config_and_inputs_for_batched_test() + config, batched_input = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: config.output_hidden_states = True @@ -457,11 +440,11 @@ def test_small_model_integration_test(self): repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" ) video_file = np.load(video_file) - inputs = self.processor(prompt, videos=video_file, return_tensors="pt") - - EXPECTED_INPUT_IDS = torch.tensor([[1, 3148, 1001, 29901, 29871, 32001, 13, 11008, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901]]) # fmt: skip + inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device) - self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) + EXPECTED_INPUT_IDS = torch.tensor([1, 3148, 1001, 29901, 29871, 13, 11008, 338, 445, 4863, 2090, 1460, 29973, 319, 
1799, 9047, 13566, 29901], device=torch_device) # fmt: skip + non_video_inputs = inputs["input_ids"][inputs["input_ids"] != 32001] + self.assertTrue(torch.equal(non_video_inputs, EXPECTED_INPUT_IDS)) output = model.generate(**inputs, do_sample=False, max_new_tokens=20) EXPECTED_DECODED_TEXT = "USER: \nWhy is this video funny? ASSISTANT: The video is funny because it shows a baby sitting on a bed and reading a book, which" # fmt: skip @@ -487,7 +470,9 @@ def test_small_model_integration_test_mixed_inputs(self): url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) - inputs = self.processor(prompts, images=[image], videos=[video_file], padding=True, return_tensors="pt") + inputs = self.processor(prompts, images=[image], videos=[video_file], padding=True, return_tensors="pt").to( + torch_device + ) output = model.generate(**inputs, do_sample=False, max_new_tokens=20) EXPECTED_DECODED_TEXT = [ @@ -543,7 +528,7 @@ def test_small_model_integration_test_llama_batched(self): hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset") ) - inputs = processor(prompts, videos=[video_1, video_2], return_tensors="pt", padding=True) + inputs = processor(prompts, videos=[video_1, video_2], return_tensors="pt", padding=True).to(torch_device) output = model.generate(**inputs, max_new_tokens=20) @@ -583,96 +568,16 @@ def test_video_llava_merge_inputs_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True) - # Simulate some user inputs - pixel_values_videos = torch.randn( - (1, 8, 3, 224, 224), - dtype=torch.float, - device=torch_device, - ) - # fmt: off - input_ids = torch.tensor( - [[32002, 32002, 1, 15043, 7084, 32001, 29871, 13, 7900]], - dtype=torch.long, - device=torch_device, - ) - # fmt: on - attention_mask = torch.tensor( - [[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], - dtype=torch.long, - device=torch_device, + prompt = "USER:
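# Sketch, not from the diff: the idea behind the non_video_inputs check above in
# isolation. Once the processor expands the video placeholder, the number of video
# tokens in input_ids depends on the processing config, so the test only pins down
# the non-video tokens. 32001 is the video token id used above; the short id list
# below is a made-up stand-in for the real expected ids.
import torch

video_token_id = 32001
input_ids = torch.tensor([[1, 3148, video_token_id, video_token_id, video_token_id, 13, 11008, 338]])

non_video_inputs = input_ids[input_ids != video_token_id]  # boolean indexing drops the video tokens and flattens
assert torch.equal(non_video_inputs, torch.tensor([1, 3148, 13, 11008, 338]))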