6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/flux2.md
@@ -26,6 +26,12 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
>
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Caption upsampling

Flux.2 can potentially generate better outputs with more descriptive prompts. We can "upsample"
an input prompt by setting the `caption_upsample_temperature` argument in the pipeline call.
The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends setting this value to 0.15.
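
Enabling upsampling is a single extra call argument. A minimal sketch — the checkpoint id, device, and prompt below are illustrative:

```py
import torch
from diffusers import Flux2Pipeline

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photo of a cat playing a banjo",
    caption_upsample_temperature=0.15,  # upsamples the prompt before encoding
).images[0]
```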

## Flux2Pipeline

[[autodoc]] Flux2Pipeline
46 changes: 43 additions & 3 deletions src/diffusers/pipelines/flux2/image_processor.py
@@ -13,7 +13,7 @@
# limitations under the License.

import math
-from typing import Tuple
from typing import List

import PIL.Image

@@ -96,16 +96,24 @@ def check_image_input(
)

return image

@staticmethod
-def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
image_width, image_height = image.size

scale = math.sqrt(target_area / (image_width * image_height))
width = int(image_width * scale)
height = int(image_height * scale)

return image.resize((width, height), PIL.Image.Resampling.LANCZOS)

@staticmethod
def _resize_if_exceeds_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
image_width, image_height = image.size
pixel_count = image_width * image_height
if pixel_count <= target_area:
return image
return Flux2ImageProcessor._resize_to_target_area(image, target_area)
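
As a quick sanity check on the math (the sizes here are illustrative): `_resize_to_target_area` scales both sides by `sqrt(target_area / current_area)`, so the output area lands just under the target.

```python
import math

# Illustrative 2048x1024 input (2,097,152 px) against a 1024x1024 target area.
scale = math.sqrt((1024 * 1024) / (2048 * 1024))  # ~0.7071
print(int(2048 * scale), int(1024 * scale))  # 1448 724 -> ~1,048,352 px
```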

def _resize_and_crop(
self,
@@ -136,3 +144,35 @@ def _resize_and_crop
bottom = top + height

return image.crop((left, top, right, bottom))

# Taken from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19
@staticmethod
def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image:
"""
Concatenate a list of PIL images horizontally with center alignment and white background.
"""

# If only one image, return a copy of it
if len(images) == 1:
return images[0].copy()

# Convert all images to RGB if not already
images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

# Calculate dimensions for horizontal concatenation
total_width = sum(img.width for img in images)
max_height = max(img.height for img in images)

# Create new image with white background
background_color = (255, 255, 255)
new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)

# Paste images with center alignment
x_offset = 0
for img in images:
y_offset = (max_height - img.height) // 2
new_img.paste(img, (x_offset, y_offset))
x_offset += img.width

return new_img
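
A small usage sketch (the image sizes are hypothetical; the import path is taken from this PR): two images of different heights end up side by side on a white canvas, each vertically centered.

```python
import PIL.Image

from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor

a = PIL.Image.new("RGB", (512, 512), "red")
b = PIL.Image.new("RGB", (256, 128), "blue")

combined = Flux2ImageProcessor.concatenate_images([a, b])
print(combined.size)  # (768, 512); `b` is pasted at x=512, y=(512 - 128) // 2 = 192
```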
192 changes: 172 additions & 20 deletions src/diffusers/pipelines/flux2/pipeline_flux2.py
@@ -28,6 +28,7 @@
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I


if is_torch_xla_available():
@@ -56,25 +57,107 @@
```
"""

UPSAMPLING_MAX_IMAGE_SIZE = 768**2

-def format_text_input(prompts: List[str], system_message: str = None):
# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
def format_input(
prompts: List[str],
system_message: str = SYSTEM_MESSAGE,
images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
):
"""
Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
to the input.

Args:
prompts: List of text prompts
system_message: System message to use (default: `SYSTEM_MESSAGE`)
images (optional): List of images to add to the input.

Returns:
List of conversations, where each conversation is a list of message dicts
"""
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]

-return [
if images is None or len(images) == 0:
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
else:
assert len(images) == len(prompts), "Number of images must match number of prompts"
messages = [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
]
for _ in cleaned_txt
]

for i, (el, imgs) in enumerate(zip(messages, images)):
    # optionally add the images for this batch element.
    if imgs is not None:
        el.append(
            {
                "role": "user",
                "content": [{"type": "image", "image": image_obj} for image_obj in imgs],
            }
        )
    # add the text.
    el.append(
        {
            "role": "user",
            "content": [{"type": "text", "text": cleaned_txt[i]}],
        }
    )

return messages
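
For a single prompt with one reference image, the returned conversation looks roughly like this (a sketch of the structure built above; `img` is a placeholder PIL image):

```python
# format_input(["a cat playing a banjo"], images=[[img]]) ->
[
    [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
        {"role": "user", "content": [{"type": "image", "image": img}]},  # one entry per image
        {"role": "user", "content": [{"type": "text", "text": "a cat playing a banjo"}]},
    ]
]
```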


> **Collaborator:** Can we have a separate step to validate and process the images and then run `format_input`?
>
> **Member (Author):** Yes. We now first run `_validate_and_process_images()` and then pass the resultant images to `format_input()`.

# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19
def _validate_and_process_images(
images: List[List[PIL.Image.Image]] | List[PIL.Image.Image],
image_processor: Flux2ImageProcessor,
upsampling_max_image_size: int,
) -> List[List[PIL.Image.Image]]:
# Simple validation: ensure it's a list of PIL images or list of lists of PIL images
if not images:
return []

# Check if it's a list of lists or a list of images
if isinstance(images[0], PIL.Image.Image):
# It's a list of images, convert to list of lists
images = [[im] for im in images]

# potentially concatenate multiple images to reduce the size
images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images]

# cap the pixels
-{
-    "role": "system",
-    "content": [{"type": "text", "text": system_message}],
-},
-{"role": "user", "content": [{"type": "text", "text": prompt}]},
-for prompt in cleaned_txt
images = [
    [
        image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size)
        for img_i in img_i
    ]
    for img_i in images
]
return images
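
Putting the helpers together (a sketch; the images and the default-constructed processor are assumptions): a flat list is first normalized to a list of lists, multi-image entries are concatenated onto one canvas, and anything above the pixel cap is downscaled.

```python
processed = _validate_and_process_images(
    [[img_a, img_b], [img_c]],  # two prompts: two refs for the first, one for the second
    image_processor=Flux2ImageProcessor(),
    upsampling_max_image_size=768**2,
)
# -> [[concatenated(img_a, img_b), capped at 768**2 px], [img_c, capped at 768**2 px]]
```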


# Taken from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
@@ -214,9 +297,10 @@ def __init__(
self.tokenizer_max_length = 512
self.default_sample_size = 128

-# fmt: off
-self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
-# fmt: on
self.system_message = SYSTEM_MESSAGE
self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

@staticmethod
def _get_mistral_3_small_prompt_embeds(
@@ -226,9 +310,7 @@ def _get_mistral_3_small_prompt_embeds(
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
-# fmt: off
-system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
-# fmt: on
system_message: str = SYSTEM_MESSAGE,
hidden_states_layers: List[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
Expand All @@ -237,7 +319,7 @@ def _get_mistral_3_small_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt

# Format input messages
-messages_batch = format_text_input(prompts=prompt, system_message=system_message)
messages_batch = format_input(prompts=prompt, system_message=system_message)

# Process all messages at once
inputs = tokenizer.apply_chat_template(
@@ -426,6 +508,68 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch

return torch.stack(x_list, dim=0)

def upsample_prompt(
self,
prompt: Union[str, List[str]],
images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
temperature: float = 0.15,
device: Optional[torch.device] = None,
) -> List[str]:
prompt = [prompt] if isinstance(prompt, str) else prompt
device = self.text_encoder.device if device is None else device

# Set system message based on whether images are provided
if images is None or len(images) == 0 or images[0] is None:
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
else:
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I

# Validate and process the input images
if images:
images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)

# Format input messages
messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)

# Process all messages at once
# With image inputs, a max length that is too short can raise an error here.
inputs = self.tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=2048,
)

# Move to device
inputs["input_ids"] = inputs["input_ids"].to(device)
inputs["attention_mask"] = inputs["attention_mask"].to(device)

if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype)

# Generate text using the model's generate method
generated_ids = self.text_encoder.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=temperature,
use_cache=True,
)

# Decode only the newly generated tokens, skipping the input portion
input_length = inputs["input_ids"].shape[1]
generated_tokens = generated_ids[:, input_length:]

upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return upsampled_prompt
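
Assuming a loaded pipeline, the method can also be called on its own to inspect the rewritten prompt before generation (a minimal sketch):

```python
upsampled = pipe.upsample_prompt(
    "a cat playing a banjo",
    images=None,       # None/empty selects the T2I upsampling system message
    temperature=0.15,  # the recommended sampling temperature
)
print(upsampled[0])
```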

def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -620,6 +764,7 @@ def __call__(
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
caption_upsample_temperature: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -635,11 +780,11 @@
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
guidance_scale (`float`, *optional*, defaults to 1.0):
-Guidance scale as defined in [Classifier-Free Diffusion
-Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
-of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
-`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-the text `prompt`, usually at the expense of lower image quality.
Embedded guidance scale is enabled by setting `guidance_scale` > 1. A higher `guidance_scale` encourages
the model to generate images more closely aligned with `prompt`, at the expense of lower image quality.

Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
Expand Down Expand Up @@ -684,6 +829,9 @@ def __call__(
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
text_encoder_out_layers (`Tuple[int]`):
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
caption_upsample_temperature (`float`, *optional*):
When specified, caption upsampling is performed with this sampling temperature for potentially
improved outputs. We recommend a value of 0.15.

Examples:

@@ -718,6 +866,10 @@
device = self._execution_device

# 3. prepare text embeddings
if caption_upsample_temperature:
prompt = self.upsample_prompt(
prompt, images=image, temperature=caption_upsample_temperature, device=device
)
prompt_embeds, text_ids = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
29 changes: 29 additions & 0 deletions src/diffusers/pipelines/flux2/system_messages.py
@@ -0,0 +1,29 @@
"""
These system prompts come from:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed internally, this new-line character thingy messes up the quality a bit. Hence, I have decided to keep these system messages one-to-one same as the original implementation linked above.

If we run make style && make quality, this order will be completely destroyed. We can change the pyproject.toml to exclude this path from getting formatted. But before we do that, let's see if this is the best we have.

https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54
"""

SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""

SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
Output only the revised prompt and nothing else."""

SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X" → "keep X")
- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
- Keep content PG-13
Output only the final instruction in plain text and nothing else."""