
Conversation

@jp1924
Contributor

@jp1924 jp1924 commented Oct 31, 2024

What does this PR do?

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/workspace/others/img_mismatch_error.py", line 203, in <module>
    main()
  File "/root/workspace/others/img_mismatch_error.py", line 191, in main
    print(model(**find_inputs))
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llava_next/modeling_llava_next.py", line 921, in forward
    raise ValueError(
ValueError: Image features and image tokens do not match: tokens: 2403, features 2349
    # excerpt from unpad_image() in modeling_llava_next.py
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

When calculating new_height and new_width, applying int() to a value like 59.999999 (which mathematically should be 60) truncates it to 59. That off-by-one produces an incorrect unpadding slice and, in turn, the image-size mismatch error above.

Fixed the issue by rounding before truncating: int(original_width * scale_factor) becomes int(round(original_width * scale_factor, 7)).
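A quick standalone illustration of the failure mode (the numbers here are chosen to demonstrate the truncation, not taken from the model):

```python
# Mathematically (1 - 0.9) * 10 == 1.0, but in binary floating point the
# intermediate result lands just below 1.0, so int() truncates it by one.
x = (1 - 0.9) * 10

print(x)                 # 0.9999999999999998
print(int(x))            # 0  <- truncated, off by one
print(int(round(x, 7)))  # 1  <- rounding to 7 decimals first recovers the intended value
```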

It works fine in my environment for now, but I think we should look into this more carefully.
What do you all think?

Bug reproduction code

import torch
from PIL import Image

from transformers import (
    AddedToken,
    AutoConfig,
    AutoTokenizer,
    LlavaNextConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextImageProcessor,
    LlavaNextProcessor,
)


device = "cpu"
dtype = torch.bfloat16
IMG_TOKEN = "<|image|>"

language_name, vision_name = "google/gemma-2-9b", "google/siglip-so400m-patch14-384"
language_config = AutoConfig.from_pretrained(language_name)
vision_config = AutoConfig.from_pretrained(vision_name).vision_config

vision_config.num_hidden_layers, language_config.num_hidden_layers = 2, 2

tokenizer = AutoTokenizer.from_pretrained(language_name)
tokenizer.add_tokens(AddedToken(IMG_TOKEN, special=True, normalized=False), special_tokens=True)
language_config.vocab_size = len(tokenizer.get_vocab())
image_token_index = tokenizer.convert_tokens_to_ids(IMG_TOKEN)
image_grid_pinpoints = [[768, 768], [384, 768], [384, 1152], [768, 384], [1152, 384]]
vision_feature_select_strategy = "full"

config = LlavaNextConfig(
    vision_config=vision_config,
    text_config=language_config,
    image_seq_length=vision_config.image_size,
    image_token_index=image_token_index,
    image_grid_pinpoints=image_grid_pinpoints,
    vision_feature_select_strategy=vision_feature_select_strategy,
    _attn_implementation="eager",
)
image_processor = LlavaNextImageProcessor.from_pretrained(
    vision_name,
    image_grid_pinpoints=image_grid_pinpoints,
    crop_size={"height": vision_config.image_size, "width": vision_config.image_size},
)

processor = LlavaNextProcessor(
    image_processor=image_processor,
    tokenizer=tokenizer,
    image_token=IMG_TOKEN,
    patch_size=vision_config.patch_size,
    vision_feature_select_strategy=vision_feature_select_strategy,
    image_seq_length=vision_config.image_size,
)
model = LlavaNextForConditionalGeneration(config)
model = model.to(device).to(dtype)

# (width, height)
inputs = (
    processor(
        images=Image.new("RGB", (940, 423)),
        text=IMG_TOKEN,
        return_tensors="pt",
    )
    .to(device)
    .to(dtype)
)

output = model(**inputs)

This script requires issue #34447 to be resolved before it can run.
Until that fix lands, it is recommended to install the following branch before executing it:
pip install git+https://github.com/jp1924/transformers.git@Fix--Llava_img_mismatch

env

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.15.0-124-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.26.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): 2.15.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100 80GB PCIe

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks!

Makes sense since we use the math.ceil op when padding. Let's use math.ceil for unpadding also to make sure the calculations are done identically

For core maintainers: I don't think this needs a test as it is not something that can be broken again. We have had some issues before with unpadding (that is why image_sizes are cast to list) but still some numerical precision errors seem to be left
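For illustration, the ceil-based variant suggested above could look roughly like this (a sketch with a hypothetical helper name, not the actual patch):

```python
import math

def compute_unpad_slice(original_size, current_size):
    """Sketch: compute the slice bounds for stripping padding along the
    padded axis, using math.ceil so it mirrors the ceil used when padding.
    Sizes are (height, width) tuples of plain Python ints."""
    original_height, original_width = original_size
    current_height, current_width = current_size
    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # image is wider than the canvas: padding was added top/bottom
        scale_factor = current_width / original_width
        new_height = math.ceil(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        return ("height", padding, current_height - padding)
    else:
        # image is taller than the canvas: padding was added left/right
        scale_factor = current_height / original_height
        new_width = math.ceil(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        return ("width", padding, current_width - padding)
```

For example, the 940x423 image from the reproduction script placed on a 384x1152 canvas falls into the width branch. Using ceil in both padding and unpadding keeps the two computations consistent even when floating-point error nudges the product slightly below an integer.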

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1
Member

Pinging actual core maintainer @LysandreJik!

Member

@LysandreJik LysandreJik left a comment


Ok, thank you!

@LysandreJik LysandreJik merged commit 1b86772 into huggingface:main Oct 31, 2024
17 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
