Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 2 additions & 63 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,85 +461,24 @@ def forward(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

legacy_processing = False
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

# if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
# In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
legacy_processing = (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)

image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)

if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
# prefill stage vs decoding stage (legacy behavior copied)
if input_ids.shape[1] != 1:
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]

if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

Expand Down
32 changes: 11 additions & 21 deletions src/transformers/models/llava/processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,27 +154,17 @@ def __call__(
# try to expand inputs in processing if we have the necessary parts
prompt_strings = text
if image_inputs.get("pixel_values") is not None:
if self.patch_size is not None and self.vision_feature_select_strategy is not None:
# Replace the image token with the expanded image token sequence
pixel_values = image_inputs["pixel_values"]
height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (
width // self.patch_size
) + self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= self.num_additional_image_tokens

prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
else:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
# Replace the image token with the expanded image token sequence
pixel_values = image_inputs["pixel_values"]
height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1

prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)

text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs})
Expand Down
73 changes: 5 additions & 68 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
Expand Down Expand Up @@ -835,18 +837,9 @@ def forward(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

legacy_processing = False
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

# if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
# In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
legacy_processing = (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)

image_features = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
Expand All @@ -863,70 +856,14 @@ def forward(
image_newline=self.image_newline,
)

if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
if input_ids.shape[1] != 1:
inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

# TODO: @raushan retain only the new behavior after v4.47
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

Expand Down
37 changes: 13 additions & 24 deletions src/transformers/models/llava_next/processing_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,30 +149,19 @@ def __call__(

prompt_strings = text
if image_inputs:
if self.patch_size is None or self.vision_feature_select_strategy is None:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
else:
image_sizes = iter(image_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
prompt_strings = []
for sample in text:
while self.image_token in sample:
image_size = next(image_sizes)
if not isinstance(image_size, (list, tuple)):
# cast to list to avoid numerical precision errors when calculating unpadding
image_size = image_size.tolist()
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= self.num_additional_image_tokens
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
image_sizes = iter(image_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
prompt_strings = []
for sample in text:
while self.image_token in sample:
image_size = next(image_sizes)
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]

text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

Expand Down
Loading