Skip to content

Conversation

@zucchini-nlp
Copy link
Member

@zucchini-nlp zucchini-nlp commented Jun 21, 2024

What does this PR do?

Fixes #31505.

Adds Chameleon, a vision language model from Meta AI.

from transformers import ChameleonForCausalLM, ChameleonProcessor
from PIL import Image
import requests
import torch

model_path = "MODEL_PATH"
model = ChameleonForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
processor = ChameleonProcessor.from_pretrained(model_path)

prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
image = Image.open(requests.get("https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True).raw)

inputs = processor(prompt, images=[image], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
generated_text = processor.batch_decode(out, skip_special_tokens=False)[0]
print(f"Generated text: {generated_text}")

# Multi-image example
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)

inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
generated_text = processor.batch_decode(out, skip_special_tokens=True)[0]
print(f"Generated text: {generated_text}")
>>> 

Project repo: https://github.com/facebookresearch/chameleon
Paper: https://arxiv.org/abs/2405.09818v1

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jsm69
Copy link

jsm69 commented Jun 22, 2024

@amyeroberts @ArthurZucker

Comment on lines 308 to 310
# we need to expand on num_heads because there was not sharding done in 7B model
# and we need to calculate mean/var over each head_dim
# for sharded model we don't do expansion and simply do norm
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should be able to bake that by updating the alpha and beta for model parallelisme

Comment on lines 319 to 322
# permute key/value to use transformers RoPE implementation (see for more: https://github.com/huggingface/transformers/issues/25199)
# NOTE: permutation is done same way as in llama conversion script
key_states = key_states.view(-1, self.num_key_value_heads, self.head_dim // 2, 2).transpose(3, 2)
query_states = query_states.view(-1, self.num_heads, self.head_dim // 2, 2).transpose(3, 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should permute everything in the weights

Comment on lines +206 to +208
def __init__(self, hidden_size, *args, **kwargs):
super().__init__(hidden_size, *args, **kwargs)
self.normalized_shape = (hidden_size[-1],)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean we are computing over say "head_dim" ?
How different is this from normal nn.layer_norm((head_dim,))
?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah the weights are different (hidden_size) but the applied dim is hidden_size is that it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's only the weights that are different for 30B model. The 7B has simple repeated weights over all heads


for token in tokenizer_config["added_tokens"]:
if token["content"] == "<reserved08707>":
token["content"] = "<image>"
Copy link
Contributor

@leloykun leloykun Jul 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp @ArthurZucker We should also set token["special"] = False so that we can decode this token.

What do you guys think?

(It's what I'm currently doing in my PR btw and I haven't encountered any errors yet)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Special tokens should be decodable with skip_special_tokens=False

For my understanding, why do we need to decode the image token? Afaik it shouldn't affect image generation because it's a token we added manually to keep track of where to add an image in the text

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥 great work

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Meta FAIR Chameleon 7b and 30b