diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0053074bad8e..612219ad43aa 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -492,6 +492,7 @@ "QwenImageImg2ImgPipeline", "QwenImageInpaintPipeline", "QwenImagePipeline", + "QwenImageEditPipeline", "ReduxImageEncoder", "SanaControlNetPipeline", "SanaPAGPipeline", @@ -1123,6 +1124,7 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + QwenImageEditPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 3dfecb783745..049e69a4beb4 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import functools import math from typing import Any, Dict, List, Optional, Tuple, Union @@ -161,9 +160,9 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): super().__init__() self.theta = theta self.axes_dim = axes_dim - pos_index = torch.arange(1024) - neg_index = torch.arange(1024).flip(0) * -1 - 1 - pos_freqs = torch.cat( + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), self.rope_params(pos_index, self.axes_dim[1], self.theta), @@ -171,7 +170,7 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): ], dim=1, ) - neg_freqs = torch.cat( + self.neg_freqs = torch.cat( [ self.rope_params(neg_index, self.axes_dim[0], self.theta), self.rope_params(neg_index, self.axes_dim[1], self.theta), @@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): dim=1, ) self.rope_cache = {} - self.register_buffer("pos_freqs", pos_freqs, persistent=False) - self.register_buffer("neg_freqs", neg_freqs, persistent=False) - # 是否使用 scale rope + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): @@ -201,35 +198,47 @@ def forward(self, video_fhw, txt_seq_lens, device): Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: txt_length: [bs] a list of 1 integers representing the length of the text """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + if isinstance(video_fhw, list): video_fhw = video_fhw[0] - frame, height, width = video_fhw - rope_key = f"{frame}_{height}_{width}" - - if not torch.compiler.is_compiling(): - if rope_key not in self.rope_cache: - self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width) - vid_freqs = self.rope_cache[rope_key] - else: - vid_freqs = self._compute_video_freqs(frame, height, width) + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq = self.rope_cache[rope_key] + else: + video_freq = self._compute_video_freqs(frame, height, width, idx) + vid_freqs.append(video_freq) - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) max_len = max(txt_seq_lens) txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width): + def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 535b23dbb4ee..6b0394b4868a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -391,6 +391,7 @@ "QwenImagePipeline", "QwenImageImg2ImgPipeline", "QwenImageInpaintPipeline", + "QwenImageEditPipeline", ] try: if not is_onnx_available(): @@ -708,7 +709,12 @@ from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline + from .qwenimage import ( + QwenImageEditPipeline, + QwenImageImg2ImgPipeline, + QwenImageInpaintPipeline, + QwenImagePipeline, + ) from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py index 64265880e72f..3d0378511fe0 100644 --- a/src/diffusers/pipelines/qwenimage/__init__.py +++ b/src/diffusers/pipelines/qwenimage/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"] _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"] _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"] + _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -35,6 +36,7 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_qwenimage import QwenImagePipeline + from .pipeline_qwenimage_edit import QwenImageEditPipeline from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline else: diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 47549ab4af00..8f695f07dd68 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -253,6 +253,9 @@ def encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -316,20 +319,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) @@ -402,8 +391,7 @@ def prepare_latents( shape = (batch_size, 1, num_channels_latents, height, width) if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + return latents.to(device=device, dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -414,9 +402,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, latent_image_ids + return latents @property def guidance_scale(self): @@ -594,7 +580,7 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -604,7 +590,7 @@ def __call__( generator, latents, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py new file mode 100644 index 000000000000..942210c1fdb5 --- /dev/null +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -0,0 +1,870 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPipeline + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Change the cat to a dog" + >>> image = Image.open("cat.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimageedit.png") + ``` +""" +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height, None + + +class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.latent_channels = 16 + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image.width / image.height) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + prompt_image = image + image = self.image_processor.preprocess(image, image_height, image_width) + image = image.unsqueeze(2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + # negative image is the same size as the original image, but all pixels are white + # negative_image = Image.new("RGB", (image.width, image.height), (255, 255, 255)) + + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 4fc84a31cc6e..c9ee0aba1dc9 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -296,6 +296,9 @@ def encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -363,21 +366,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -465,8 +453,7 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + return latents.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype) if image.shape[1] != self.latent_channels: @@ -489,9 +476,7 @@ def prepare_latents( latents = self.scheduler.scale_noise(image_latents, timestep, noise) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, latent_image_ids + return latents @property def guidance_scale(self): @@ -713,7 +698,7 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + latents = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, @@ -725,7 +710,7 @@ def __call__( generator, latents, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 5ffec0c447ff..05da95f3cd5e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -307,6 +307,9 @@ def encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -390,21 +393,6 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") - @staticmethod - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - @staticmethod # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -492,8 +480,7 @@ def prepare_latents( raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + return latents.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype) if image.shape[1] != self.latent_channels: @@ -524,9 +511,7 @@ def prepare_latents( image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, noise, image_latents, latent_image_ids + return latents, noise, image_latents def prepare_mask_latents( self, @@ -859,7 +844,7 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, noise, image_latents, latent_image_ids = self.prepare_latents( + latents, noise, image_latents = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, @@ -894,7 +879,7 @@ def __call__( generator, ) - img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e02457bf8df9..181cbdbc66a0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1742,6 +1742,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class QwenImageEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]