Describe the bug
Chroma applies attention masks in both the T5 text encoder and the transformer. Currently only the T5 encoder computes with the proper mask. The attention mask produced for T5 should carry forward into the transformer and mask out the padding tokens, except for the last padding token, which stays unmasked. See the reference implementation:
https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py#L222-L257
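For illustration, a minimal sketch of that masking rule as I read it (not the reference code itself): all padding positions from the T5 attention mask are masked out for the transformer, but one padding token directly after the prompt is kept unmasked. The function name and tensor layout here are assumptions for the example.

import torch

def carry_forward_text_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len) from the T5 tokenizer,
    # 1 for real tokens, 0 for padding. Hypothetical helper for illustration.
    mask = attention_mask.bool().clone()
    lengths = attention_mask.sum(dim=1)  # real-token count per sample
    rows = torch.arange(mask.shape[0], device=mask.device)
    # leave one padding token unmasked (clamped for full-length prompts)
    mask[rows, lengths.clamp(max=mask.shape[1] - 1)] = True
    return mask

The transformer would then combine a mask like this for the text tokens with an all-ones mask over the image tokens, so joint attention never attends to the remaining padding.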
Reproduction
import torch
from diffusers import ChromaTransformer2DModel
from transformers import T5EncoderModel, T5Tokenizer
from optimum.quanto import freeze, qint8, quantize
from chroma_pipeline import ChromaPipeline

bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

# Load the Chroma transformer from a local checkpoint and quantize it,
# keeping the embedding/projection layers in full precision.
transformer = ChromaTransformer2DModel.from_single_file(
    "chroma-unlocked-v37-detail-calibrated.safetensors",
    torch_dtype=dtype,
).to("cuda")
quantize(
    transformer,
    weights=qint8,
    exclude=["proj_out", "x_embedder", "norm_out", "context_embedder", "distilled_guidance_layer"],
)
freeze(transformer)

# T5 text encoder and tokenizer from the FLUX.1-schnell repo.
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype).to("cuda")
quantize(text_encoder, weights=qint8)
freeze(text_encoder)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2")

pipe = ChromaPipeline.from_pretrained(
    bfl_repo,
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    torch_dtype=dtype,
).to("cuda")

generator = torch.Generator().manual_seed(0)
image = pipe(
    prompt='pikachu playing a violin on mars, sign in the background says, "welcome to mars!!',
    negative_prompt="",
    width=1024,
    height=1024,
    num_inference_steps=35,
    generator=generator,
    guidance_scale=5,
).images[0]
image.save("chroma/test_chroma0.png")
Logs
System Info
Ubuntu 24.04, NVIDIA RTX 3090, diffusers installed from main