From 0c49735c58e5abf581dcb6630d463d042845d60c Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Fri, 21 Nov 2025 20:27:53 -0800 Subject: [PATCH 1/6] Fix: Normalize batch inputs to 5D tensors for Qwen-Image-Edit --- .../qwenimage/pipeline_qwenimage_edit_plus.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..5f43b92de194 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -627,7 +627,26 @@ def __call__( [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - image_size = image[-1].size if isinstance(image, list) else image.size + # [Fix] Robustly determine image size (Handles Lists, Tensors, and PIL) + if isinstance(image, list): + # Grab the first valid image to determine dimensions + check_img = image[0] + # Handle potential nested lists (e.g. if batching logic gets complex) + while isinstance(check_img, (list, tuple)): + check_img = check_img[0] + + if isinstance(check_img, torch.Tensor): + # Tensor shape is usually (C, H, W) or (B, C, H, W) -> take last two dims + image_size = (check_img.shape[-1], check_img.shape[-2]) + else: + # PIL Image + image_size = check_img.size + elif isinstance(image, torch.Tensor): + image_size = (image.shape[-1], image.shape[-2]) + else: + # Single PIL Image + image_size = image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) height = height or calculated_height width = width or calculated_width From 1b85230a2cb1283901908020aaf83c5886745c89 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Fri, 21 Nov 2025 21:55:44 -0800 Subject: [PATCH 2/6] Fix: Handle variable sequence lengths in batch inference via padding --- .../qwenimage/pipeline_qwenimage_edit_plus.py | 68 +++++++++++++++---- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 5f43b92de194..932ad02d86f3 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -283,7 +283,6 @@ def _get_qwen_prompt_embeds( return prompt_embeds, encoder_attention_mask - # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -295,7 +294,6 @@ def encode_prompt( max_sequence_length: int = 1024, ): r""" - Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded @@ -311,17 +309,63 @@ def encode_prompt( """ device = device or self._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + # [Fix] Loop over prompts to avoid Qwen2VLProcessor batching bugs & IndexError + if isinstance(prompt, list) and len(prompt) > 1: + prompt_embeds_list = [] + mask_list = [] + + # Normalize images to a list matching the prompt length + if isinstance(image, list): + current_images = image + else: + current_images = [image] * len(prompt) + + for i, single_prompt in enumerate(prompt): + # Safety: Ensure we 
have an image for this prompt + single_image = current_images[i] if i < len(current_images) else current_images[0] + + pe, pem = self._get_qwen_prompt_embeds( + single_prompt, + image=single_image, + device=device + ) + prompt_embeds_list.append(pe) + mask_list.append(pem) + + # [Fix] Pad embeddings to the maximum length in the batch before stacking + max_len = max([p.shape[1] for p in prompt_embeds_list]) + + padded_embeds = [] + padded_masks = [] + + for pe, pem in zip(prompt_embeds_list, mask_list): + cur_len = pe.shape[1] + pad_len = max_len - cur_len + + if pad_len > 0: + # Pad sequence dim (2nd last dim for embeds, last dim for mask) + pe = torch.nn.functional.pad(pe, (0, 0, 0, pad_len)) + pem = torch.nn.functional.pad(pem, (0, pad_len)) + + padded_embeds.append(pe) + padded_masks.append(pem) + + prompt_embeds = torch.cat(padded_embeds, dim=0) + prompt_embeds_mask = torch.cat(padded_masks, dim=0) + + else: + # Standard path for single prompt + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) return prompt_embeds, prompt_embeds_mask @@ -639,12 +683,10 @@ def __call__( # Tensor shape is usually (C, H, W) or (B, C, H, W) -> take last two dims image_size = (check_img.shape[-1], check_img.shape[-2]) else: - # PIL Image image_size = check_img.size elif isinstance(image, torch.Tensor): image_size = (image.shape[-1], image.shape[-2]) else: - # Single PIL Image image_size = image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) From da78501aa6c2f8bb87e934e166379fb472707e11 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Sat, 22 Nov 2025 07:23:29 -0800 Subject: [PATCH 3/6] Fix: Add routing logic for 1-to-1 batch inference and tokenizer padding --- .../qwenimage/pipeline_qwenimage_edit_plus.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 932ad02d86f3..1530f63bc7a5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -729,10 +729,12 @@ def __call__( if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): if not isinstance(image, list): image = [image] + condition_image_sizes = [] condition_images = [] vae_image_sizes = [] vae_images = [] + for img in image: image_width, image_height = img.size 
condition_width, condition_height = calculate_dimensions( @@ -742,8 +744,16 @@ def __call__( condition_image_sizes.append((condition_width, condition_height)) vae_image_sizes.append((vae_width, vae_height)) condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + + # [5D Fix] Ensure (B, C, F, H, W) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + # [FIX] Handle Batch vs Multi-Condition Ambiguity + # If user provides N prompts and N images, stack them for 1-to-1 batching. + if isinstance(prompt, list) and len(prompt) > 1 and len(vae_images) == len(prompt): + batch_tensor = torch.cat(vae_images, dim=0) + vae_images = [batch_tensor] + has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) From 2b05da83413e78b34a150c98cbc2374c01177ab7 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Sat, 22 Nov 2025 08:53:45 -0800 Subject: [PATCH 4/6] Fix: Robust batch inference with padding and 1-to-1 stacking --- .../qwenimage/pipeline_qwenimage_edit_plus.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 1530f63bc7a5..94af063f9eab 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -748,12 +748,32 @@ def __call__( # [5D Fix] Ensure (B, C, F, H, W) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) - # [FIX] Handle Batch vs Multi-Condition Ambiguity - # If user provides N prompts and N images, stack them for 1-to-1 batching. + # [FIX] Handle Batch vs Multi-Condition Ambiguity + Variable Resolutions if isinstance(prompt, list) and len(prompt) > 1 and len(vae_images) == len(prompt): - batch_tensor = torch.cat(vae_images, dim=0) + # 1. Find max dims (Height=[-2], Width=[-1]) + max_h = max(img.shape[-2] for img in vae_images) + max_w = max(img.shape[-1] for img in vae_images) + + padded_images = [] + for img in vae_images: + h, w = img.shape[-2], img.shape[-1] + pad_h = max_h - h + pad_w = max_w - w + if pad_h > 0 or pad_w > 0: + # Pad (left, right, top, bottom) + img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h)) + padded_images.append(img) + + # 2. 1-to-1 Batching + batch_tensor = torch.cat(padded_images, dim=0) vae_images = [batch_tensor] + # 3. [FIX] Update metadata to match padded dims - Rotary Positional Embeddings + # We must tell the model that each batch item has exactly 1 condition image with the new padded dimensions. 
+ height = max_h + width = max_w + vae_image_sizes = [(max_w, max_h)] + has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) From f84b0793735faa2133742f82058ef506a620be55 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Sat, 22 Nov 2025 09:16:19 -0800 Subject: [PATCH 5/6] Qwen-Image-Edit robust batch inference with padding --- .../pipeline_controlnet_sd_xl_img2img.py | 3 +- .../cosmos/pipeline_cosmos_video2world.py | 4 +-- .../kandinsky/pipeline_kandinsky_combined.py | 2 +- .../kandinsky/pipeline_kandinsky_img2img.py | 3 +- .../kandinsky/pipeline_kandinsky_inpaint.py | 3 +- .../kandinsky/pipeline_kandinsky_prior.py | 3 +- .../pipeline_kandinsky2_2_combined.py | 2 +- .../pipeline_kandinsky2_2_controlnet.py | 3 +- ...ipeline_kandinsky2_2_controlnet_img2img.py | 3 +- .../pipeline_kandinsky2_2_img2img.py | 3 +- .../pipeline_kandinsky2_2_inpainting.py | 3 +- .../pipeline_kandinsky2_2_prior.py | 3 +- .../pipeline_kandinsky2_2_prior_emb2emb.py | 6 ++-- .../pipeline_pag_controlnet_sd_xl_img2img.py | 3 +- .../qwenimage/pipeline_qwenimage_edit_plus.py | 34 ++++++++----------- 15 files changed, 30 insertions(+), 48 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 4d4845c5a0a3..0311df1d7f7a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -131,8 +131,7 @@ >>> prompt = "A robot, 4k photo" >>> image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ).resize((1024, 1024)) >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization >>> depth_image = get_depth_map(image) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py index cd5a734cc311..763f41373a73 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py @@ -88,9 +88,7 @@ def __init__(self, *args, **kwargs): >>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." >>> video = load_video( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" - ... )[ - ... :21 - ... ] # This example uses only the first 21 frames + ... 
)[:21] # This example uses only the first 21 frames >>> video = pipe(video=video, prompt=prompt).frames[0] >>> export_to_video(video, "output.mp4", fps=30) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py index 7286bcbee17b..a5952daad420 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py @@ -98,7 +98,7 @@ negative_prompt = "low quality, bad quality" original_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) mask = np.zeros((768, 768), dtype=np.float32) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index f5e41d499dc3..6469b970e8b5 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -60,8 +60,7 @@ >>> pipe.to("cuda") >>> init_image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/frog.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png" ... ) >>> image = pipe( diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py index 731fce499859..521e6fd8f493 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -66,8 +66,7 @@ >>> pipe.to("cuda") >>> init_image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ) >>> mask = np.zeros((768, 768), dtype=np.float32) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py index 10ea8005c90d..eca3d2317392 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -88,8 +88,7 @@ >>> pipe_prior.to("cuda") >>> img1 = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... 
) >>> img2 = load_image( diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py index fc2083247bb0..2c44408f299d 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -92,7 +92,7 @@ negative_prompt = "low quality, bad quality" original_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ) mask = np.zeros((768, 768), dtype=np.float32) diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py index c5faae82796b..c14790d61d19 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py @@ -71,8 +71,7 @@ >>> img = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ).resize((768, 768)) >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py index 54154c6ec1f2..662b81c311d8 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py @@ -72,8 +72,7 @@ >>> pipe = pipe.to("cuda") >>> img = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ).resize((768, 768)) diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py index 3b2509098fd1..0d7e118cee23 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py @@ -56,8 +56,7 @@ >>> pipe.to("cuda") >>> init_image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/frog.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png" ... ) >>> image = pipe( diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py index a61673293e1f..741c5622cd1e 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py @@ -62,8 +62,7 @@ >>> pipe.to("cuda") >>> init_image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... 
) >>> mask = np.zeros((768, 768), dtype=np.float32) diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py index bc67847831a5..aae892f57136 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -60,8 +60,7 @@ ... ) >>> pipe_prior.to("cuda") >>> img1 = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ) >>> img2 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py index b586d166118b..7e3bee808d0c 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py @@ -35,8 +35,7 @@ >>> prompt = "red cat, 4k photo" >>> img = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ) >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() @@ -73,8 +72,7 @@ >>> pipe_prior.to("cuda") >>> img1 = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... ) >>> img2 = load_image( diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index a6df1b22c8b9..de956f2c0f2c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -132,8 +132,7 @@ >>> prompt = "A robot, 4k photo" >>> image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - ... "/kandinsky/cat.png" + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" ... 
).resize((1024, 1024)) >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization >>> depth_image = get_depth_map(image) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 94af063f9eab..247d4c5ce76f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -313,7 +313,7 @@ def encode_prompt( if isinstance(prompt, list) and len(prompt) > 1: prompt_embeds_list = [] mask_list = [] - + # Normalize images to a list matching the prompt length if isinstance(image, list): current_images = image @@ -323,36 +323,32 @@ def encode_prompt( for i, single_prompt in enumerate(prompt): # Safety: Ensure we have an image for this prompt single_image = current_images[i] if i < len(current_images) else current_images[0] - - pe, pem = self._get_qwen_prompt_embeds( - single_prompt, - image=single_image, - device=device - ) + + pe, pem = self._get_qwen_prompt_embeds(single_prompt, image=single_image, device=device) prompt_embeds_list.append(pe) mask_list.append(pem) - + # [Fix] Pad embeddings to the maximum length in the batch before stacking max_len = max([p.shape[1] for p in prompt_embeds_list]) - + padded_embeds = [] padded_masks = [] - + for pe, pem in zip(prompt_embeds_list, mask_list): cur_len = pe.shape[1] pad_len = max_len - cur_len - + if pad_len > 0: # Pad sequence dim (2nd last dim for embeds, last dim for mask) pe = torch.nn.functional.pad(pe, (0, 0, 0, pad_len)) pem = torch.nn.functional.pad(pem, (0, pad_len)) - + padded_embeds.append(pe) padded_masks.append(pem) prompt_embeds = torch.cat(padded_embeds, dim=0) prompt_embeds_mask = torch.cat(padded_masks, dim=0) - + else: # Standard path for single prompt prompt = [prompt] if isinstance(prompt, str) else prompt @@ -678,7 +674,7 @@ def __call__( # Handle potential nested lists (e.g. if batching logic gets complex) while isinstance(check_img, (list, tuple)): check_img = check_img[0] - + if isinstance(check_img, torch.Tensor): # Tensor shape is usually (C, H, W) or (B, C, H, W) -> take last two dims image_size = (check_img.shape[-1], check_img.shape[-2]) @@ -729,12 +725,12 @@ def __call__( if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): if not isinstance(image, list): image = [image] - + condition_image_sizes = [] condition_images = [] vae_image_sizes = [] vae_images = [] - + for img in image: image_width, image_height = img.size condition_width, condition_height = calculate_dimensions( @@ -744,7 +740,7 @@ def __call__( condition_image_sizes.append((condition_width, condition_height)) vae_image_sizes.append((vae_width, vae_height)) condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) - + # [5D Fix] Ensure (B, C, F, H, W) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) @@ -753,7 +749,7 @@ def __call__( # 1. Find max dims (Height=[-2], Width=[-1]) max_h = max(img.shape[-2] for img in vae_images) max_w = max(img.shape[-1] for img in vae_images) - + padded_images = [] for img in vae_images: h, w = img.shape[-2], img.shape[-1] @@ -763,7 +759,7 @@ def __call__( # Pad (left, right, top, bottom) img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h)) padded_images.append(img) - + # 2. 
1-to-1 Batching batch_tensor = torch.cat(padded_images, dim=0) vae_images = [batch_tensor] From 22dd1cb9e75191ada47638819d4e4278d466addd Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Sun, 23 Nov 2025 21:22:27 -0800 Subject: [PATCH 6/6] Fix Qwen batching: final logic for input resizing and robust inference --- .../qwenimage/pipeline_qwenimage_edit_plus.py | 177 +++++++++++------- 1 file changed, 107 insertions(+), 70 deletions(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 247d4c5ce76f..2347e1f76574 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -226,6 +226,25 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor return split_result + def _sanitize_images(self, image): + """ + Recursively unwraps tuples or lists to find valid PIL Images or Tensors. Solves the issue where + `load_image(...),` creates nested tuples like `((Image,),)`. + """ + if isinstance(image, (list, tuple)): + # If it's a list/tuple, check if it contains images or nested tuples + unwrapped_images = [] + for img in image: + while isinstance(img, (list, tuple)): + img = img[0] + unwrapped_images.append(img) + return unwrapped_images + + # Handle single input that might be wrapped + while isinstance(image, (list, tuple)): + image = image[0] + return [image] + def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -236,22 +255,32 @@ def _get_qwen_prompt_embeds( device = device or self._execution_device dtype = dtype or self.text_encoder.dtype + # Ensure prompt is a list prompt = [prompt] if isinstance(prompt, str) else prompt + + # Ensure image is a list matching the prompt length + # This is critical for the Processor to map index 0 -> index 0 + if image is not None: + if not isinstance(image, list): + image = [image] + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" - if isinstance(image, list): + + # Logic to handle multiple images per SINGLE prompt + if isinstance(image, list) and len(image) > len(prompt): base_img_prompt = "" for i, img in enumerate(image): base_img_prompt += img_prompt_template.format(i + 1) - elif image is not None: - base_img_prompt = img_prompt_template.format(1) else: - base_img_prompt = "" + base_img_prompt = img_prompt_template.format(1) template = self.prompt_template_encode - drop_idx = self.prompt_template_encode_start_idx + + # formatting txt = [template.format(base_img_prompt + e) for e in prompt] + # ensure processor gets lists model_inputs = self.processor( text=txt, images=image, @@ -293,23 +322,9 @@ def encode_prompt( prompt_embeds_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 1024, ): - r""" - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - image (`torch.Tensor`, *optional*): - image to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. 
- """ device = device or self._execution_device - # [Fix] Loop over prompts to avoid Qwen2VLProcessor batching bugs & IndexError + # [FIX]: Robust Batch Handling Loop if isinstance(prompt, list) and len(prompt) > 1: prompt_embeds_list = [] mask_list = [] @@ -321,14 +336,15 @@ def encode_prompt( current_images = [image] * len(prompt) for i, single_prompt in enumerate(prompt): - # Safety: Ensure we have an image for this prompt single_image = current_images[i] if i < len(current_images) else current_images[0] + # Pass single items, so the processor sees + # text=["..."] and images=[img] (1-to-1). pe, pem = self._get_qwen_prompt_embeds(single_prompt, image=single_image, device=device) prompt_embeds_list.append(pe) mask_list.append(pem) - # [Fix] Pad embeddings to the maximum length in the batch before stacking + # Pad embeddings to the maximum length in the batch max_len = max([p.shape[1] for p in prompt_embeds_list]) padded_embeds = [] @@ -339,7 +355,6 @@ def encode_prompt( pad_len = max_len - cur_len if pad_len > 0: - # Pad sequence dim (2nd last dim for embeds, last dim for mask) pe = torch.nn.functional.pad(pe, (0, 0, 0, pad_len)) pem = torch.nn.functional.pad(pem, (0, pad_len)) @@ -352,16 +367,14 @@ def encode_prompt( else: # Standard path for single prompt prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_images_per_prompt, dim=0) return prompt_embeds, prompt_embeds_mask @@ -668,6 +681,9 @@ def __call__( returning a tuple, the first element is a list with the generated images. """ # [Fix] Robustly determine image size (Handles Lists, Tensors, and PIL) + if image is not None: + image = self._sanitize_images(image) + if isinstance(image, list): # Grab the first valid image to determine dimensions check_img = image[0] @@ -723,52 +739,72 @@ def __call__( device = self._execution_device # 3. 
Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): - if not isinstance(image, list): - image = [image] + image = self._sanitize_images(image) condition_image_sizes = [] condition_images = [] vae_image_sizes = [] vae_images = [] + # We first calculate what size each image WANTS to be (preserving aspect ratio) + ideal_sizes = [] for img in image: - image_width, image_height = img.size - condition_width, condition_height = calculate_dimensions( - CONDITION_IMAGE_SIZE, image_width / image_height - ) - vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) - condition_image_sizes.append((condition_width, condition_height)) - vae_image_sizes.append((vae_width, vae_height)) - condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + w, h = img.size + # vw, vh = calculate_dimensions(VAE_IMAGE_SIZE, w / h) + vw = ( + round(w / 32) * 32 + ) # uncomment above line and change w -> vw and h -> vh if you want to upscale everything to 1024 + vh = round(h / 32) * 32 + ideal_sizes.append((vw, vh)) + + # If set(ideal_sizes) has length 1, they are all uniform! + all_same_size = len(set(ideal_sizes)) == 1 + + # Default target is 1024 (Standard Qwen) + force_tgt_w, force_tgt_h = 1024, 1024 + + for idx, img in enumerate(image): + w, h = img.size - # [5D Fix] Ensure (B, C, F, H, W) - vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + # Condition Image (Always keeps aspect ratio) + cw, ch = calculate_dimensions(CONDITION_IMAGE_SIZE, w / h) + condition_image_sizes.append((cw, ch)) + condition_images.append(self.image_processor.resize(img, ch, cw)) - # [FIX] Handle Batch vs Multi-Condition Ambiguity + Variable Resolutions + # VAE SIZE LOGIC + if height is not None and width is not None: + # Priority 1: User specified strict size; Force it + vw, vh = width, height + + elif all_same_size: + # Priority 2: Batch is uniform; Keep "Ideal" (Preserves Aspect Ratio) + vw, vh = ideal_sizes[idx] + + else: + # Priority 3: Batch is mixed (Portrait + Landscape); Force Square + vw, vh = force_tgt_w, force_tgt_h + + # Ensure divisible by 32 + vw = round(vw / 32) * 32 + vh = round(vh / 32) * 32 + + vae_image_sizes.append((vw, vh)) + + # If Mixed Batch: stretch them to 1024x1024 + # If Uniform Batch: preserve aspect ratio + vae_images.append(self.image_processor.preprocess(img, vh, vw).unsqueeze(2)) + + # Batching Logic if isinstance(prompt, list) and len(prompt) > 1 and len(vae_images) == len(prompt): - # 1. Find max dims (Height=[-2], Width=[-1]) - max_h = max(img.shape[-2] for img in vae_images) - max_w = max(img.shape[-1] for img in vae_images) - - padded_images = [] - for img in vae_images: - h, w = img.shape[-2], img.shape[-1] - pad_h = max_h - h - pad_w = max_w - w - if pad_h > 0 or pad_w > 0: - # Pad (left, right, top, bottom) - img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h)) - padded_images.append(img) - - # 2. 1-to-1 Batching - batch_tensor = torch.cat(padded_images, dim=0) + batch_tensor = torch.cat(vae_images, dim=0) vae_images = [batch_tensor] - # 3. [FIX] Update metadata to match padded dims - Rotary Positional Embeddings - # We must tell the model that each batch item has exactly 1 condition image with the new padded dimensions. 
- height = max_h - width = max_w - vae_image_sizes = [(max_w, max_h)] + # Update global metadata for pipeline + if height is None or width is None: + if all_same_size: + width, height = vae_image_sizes[0] + else: + width, height = force_tgt_w, force_tgt_h has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None @@ -817,15 +853,16 @@ def __call__( generator, latents, ) - img_shapes = [ - [ + img_shapes = [] + for i in range(batch_size): + # Safe access to size + vw, vh = vae_image_sizes[i] if i < len(vae_image_sizes) else vae_image_sizes[0] + + shape_entry = [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), - *[ - (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) - for vae_width, vae_height in vae_image_sizes - ], + (1, vh // self.vae_scale_factor // 2, vw // self.vae_scale_factor // 2), ] - ] * batch_size + img_shapes.append(shape_entry) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
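For reference, a minimal usage sketch of the 1-to-1 batched path this series targets — multiple prompts paired with multiple input images in a single call. This assumes a diffusers build that includes these patches; the checkpoint id, prompts, and generation settings below are illustrative assumptions, while the image URLs reuse the test assets already referenced in the diffs above.

    # Illustrative sketch only: checkpoint id, prompts, and settings are assumptions,
    # not taken from the patches; QwenImageEditPlusPipeline is the class defined in
    # pipeline_qwenimage_edit_plus.py.
    import torch
    from diffusers import QwenImageEditPlusPipeline
    from diffusers.utils import load_image

    pipe = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Two prompts paired 1-to-1 with two input images (the batched case the series fixes).
    images = [
        load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"),
        load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"),
    ]
    prompts = ["Turn the cat into a tiger", "Give the frog a golden crown"]

    # With the patches applied, prompts and images are encoded per item, padded to a
    # common length, and stacked into a single batch before denoising.
    result = pipe(image=images, prompt=prompts, num_inference_steps=40).images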