-
Notifications
You must be signed in to change notification settings - Fork 31.3k
Closed
Labels
Description
System Info
4.47.1
Who can help?
vision models: @amyeroberts, @qubvel
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
transformers/src/transformers/models/llava_next/modeling_llava_next.py
Lines 866 to 916 in 5cabc75
| if legacy_processing: | |
| logger.warning_once( | |
| "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " | |
| "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " | |
| "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " | |
| "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." | |
| ) | |
| if input_ids.shape[1] != 1: | |
| inputs_embeds = inputs_embeds.to(image_features.dtype) | |
| inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( | |
| image_features, | |
| feature_lens, | |
| inputs_embeds, | |
| input_ids, | |
| attention_mask, | |
| position_ids, | |
| labels=labels, | |
| ) | |
| cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) | |
| else: | |
| # Retrieve the first layer to inspect the logits and mask out the hidden states | |
| # that are set to 0 | |
| first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] | |
| # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 | |
| batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) | |
| # Get the target length | |
| target_length = input_ids.shape[1] | |
| past_length = first_layer_past_key_value.shape[-1] | |
| extended_attention_mask = torch.ones( | |
| (attention_mask.shape[0], past_length), | |
| dtype=attention_mask.dtype, | |
| device=attention_mask.device, | |
| ) | |
| # Filter out only the tokens that can be un-attended, this can happen | |
| # if one uses Llava + Fused modules where the cache on the | |
| # first iteration is already big enough, or if one passes custom cache | |
| valid_indices = non_attended_tokens < extended_attention_mask.size(-1) | |
| new_batch_index = batch_index[valid_indices] | |
| new_non_attended_tokens = non_attended_tokens[valid_indices] | |
| # Zero-out the places where we don't need to attend | |
| extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 | |
| attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) | |
| position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 | |
| cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] | |
| # TODO: @raushan retain only the new behavior after v4.47 |
Sample script
def main():
    """Run one image+text generation with LLaVA-NeXT and print the decoded output.

    Loads the ``llava-hf/llava-v1.6-mistral-7b-hf`` processor and model (fp16,
    FlashAttention-2) onto ``cuda:0``, downloads a sample image, builds a
    single-turn chat prompt, generates up to 100 new tokens and prints the
    decoded sequence.

    Raises:
        requests.HTTPError: if the sample image cannot be downloaded.
    """
    # NOTE(review): `parse_args` and `setup_model_with_compression` are defined
    # outside this snippet; presumably they parse CLI options and patch the
    # model in place -- confirm against the full script.
    args = parse_args()
    processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2",
    ).to("cuda:0")
    setup_model_with_compression(model, args)

    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    # Bound the download time and fail on HTTP errors instead of handing an
    # error page to PIL, which would surface as an opaque decode failure.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw)

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Pass images/text by keyword so the call does not depend on the
    # processor's positional argument order.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")
    output = model.generate(**inputs, max_new_tokens=100)
    print(processor.decode(output[0], skip_special_tokens=True))
Expected behavior
How does the legacy processing work? Can I disable it?