Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 63 additions & 64 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,73 +862,72 @@ def forward(
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)

if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
if input_ids.shape[1] != 1:
inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
if image_features is not None:
if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
if input_ids.shape[1] != 1:
inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
# TODO: @raushan retain only the new behavior after v4.47
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down