@@ -198,8 +198,26 @@ class StableDiffusionXLControlNetInpaintPipeline(
"""

model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

_optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"image_encoder",
"feature_extractor",
]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
"mask",
"masked_image_latents",
]

def __init__(
self,
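For context (this note is not part of the diff): once `mask` and `masked_image_latents` are listed in `_callback_tensor_inputs`, a `callback_on_step_end` hook may request and inspect them between denoising steps. A minimal sketch, assuming a `StableDiffusionXLControlNetInpaintPipeline` already loaded as `pipe` with `init_image`, `mask_image`, and `control_image` prepared beforehand; the hook name and print-only body are illustrative:

def log_inpaint_tensors(pipe, step_index, timestep, callback_kwargs):
    # Tensors named in callback_on_step_end_tensor_inputs are delivered here;
    # returning the dict (modified or not) feeds them back into the loop.
    mask = callback_kwargs["mask"]
    masked_image_latents = callback_kwargs["masked_image_latents"]
    print(step_index, tuple(mask.shape), tuple(masked_image_latents.shape))
    return callback_kwargs

result = pipe(
    prompt="a wooden bench in a park",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    callback_on_step_end=log_inpaint_tensors,
    callback_on_step_end_tensor_inputs=["mask", "masked_image_latents"],
)

Before this change, requesting these names would be rejected by the pipeline's `check_inputs` validation against `_callback_tensor_inputs`.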
16 changes: 15 additions & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -236,7 +236,15 @@ class StableDiffusionXLControlNetPipeline(
"feature_extractor",
"image_encoder",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]

def __init__(
self,
@@ -1528,6 +1536,12 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
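The `callback_outputs.pop(...)` handling above is what lets a hook's returned values replace the conditioning tensors for the next step. A minimal usage sketch, assuming a `StableDiffusionXLControlNetPipeline` loaded as `pipe` and a control image in `control_image`; the decay schedule, prompt, and function name are illustrative, not part of this PR:

def decay_guidance(pipe, step_index, timestep, callback_kwargs):
    # The newly exposed tensors arrive when requested below; they are read
    # here only to show they are available. The hook itself ramps the
    # guidance scale down while keeping it above 1.0 so classifier-free
    # guidance (and the batched tensor shapes) stay unchanged.
    add_text_embeds = callback_kwargs["add_text_embeds"]
    add_time_ids = callback_kwargs["add_time_ids"]
    print(step_index, add_text_embeds.shape, add_time_ids.shape)
    progress = (step_index + 1) / pipe.num_timesteps
    pipe._guidance_scale = 7.5 * (1.0 - progress) + 1.5 * progress
    return callback_kwargs

image = pipe(
    "a futuristic cityscape, canny edges preserved",
    image=control_image,
    guidance_scale=7.5,
    callback_on_step_end=decay_guidance,
    callback_on_step_end_tensor_inputs=["add_text_embeds", "add_time_ids"],
).images[0]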
@@ -228,7 +228,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"feature_extractor",
"image_encoder",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
]

def __init__(
self,
@@ -1584,6 +1592,12 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -158,7 +158,15 @@ class StableDiffusionXLControlNetXSPipeline(
"text_encoder_2",
"feature_extractor",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]

def __init__(
self,
45 changes: 44 additions & 1 deletion tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -19,7 +19,15 @@
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)

from diffusers import (
AutoencoderKL,
@@ -34,6 +42,7 @@
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
@@ -55,6 +64,14 @@ class ControlNetPipelineSDXLFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset(IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"mask_image", "control_image"}))
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{
"add_text_embeds",
"add_time_ids",
"mask",
"masked_image_latents",
}
)

def get_dummy_components(self):
torch.manual_seed(0)
@@ -129,6 +146,30 @@ def get_dummy_components(self):
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)

image_encoder = CLIPVisionModelWithProjection(image_encoder_config)

feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)

components = {
"unet": unet,
"controlnet": controlnet,
@@ -138,6 +179,8 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
}
return components

5 changes: 5 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
@@ -34,6 +34,7 @@
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import (
IPAdapterTesterMixin,
@@ -55,9 +56,13 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
):
pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)

def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0)