From 357860c1de18d17b2e46fdce0d0f609b2a0cfadf Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 05:55:48 +0000 Subject: [PATCH 01/11] update --- tests/pipelines/controlnetxs/test_controlnetxs.py | 2 ++ tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/pipelines/controlnetxs/test_controlnetxs.py b/tests/pipelines/controlnetxs/test_controlnetxs.py index 1f184e5bb14c..1e9e523b71db 100644 --- a/tests/pipelines/controlnetxs/test_controlnetxs.py +++ b/tests/pipelines/controlnetxs/test_controlnetxs.py @@ -106,6 +106,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout): out_queue.join() +@unittest.skip("Move to Community Pipelines") class ControlNetXSPipelineFastTests( PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): @@ -243,6 +244,7 @@ def test_controlnet_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 +@unittest.skip("Move to Community Pipelines") @slow @require_torch_gpu class ControlNetXSPipelineSlowTests(unittest.TestCase): diff --git a/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py index dbdc532a6f3b..06e957ce4edc 100644 --- a/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py @@ -48,6 +48,7 @@ enable_full_determinism() +@unittest.skip("Move to Community Pipelines") class StableDiffusionXLControlNetXSPipelineFastTests( PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, @@ -307,6 +308,7 @@ def test_stable_diffusion_xl_prompt_embeds(self): assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 +@unittest.skip("Move to Community Pipelines") @slow @require_torch_gpu class ControlNetSDXLPipelineXSSlowTests(unittest.TestCase): From dc981518b485d93a9f259379abbd534bb9b9aaae Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 09:53:42 +0000 Subject: [PATCH 02/11] update --- .../community/controlnetxs/controlnetxs.py | 1016 ++++++++++++++++ .../controlnetxs/infer_sd_controlnetxs.py | 54 + .../controlnetxs/infer_sdxl_controlnetxs.py | 53 + .../controlnetxs/pipeline_controlnet_xs.py | 901 ++++++++++++++ .../pipeline_controlnet_xs_sd_xl.py | 1078 +++++++++++++++++ 5 files changed, 3102 insertions(+) create mode 100644 examples/community/controlnetxs/controlnetxs.py create mode 100644 examples/community/controlnetxs/infer_sd_controlnetxs.py create mode 100644 examples/community/controlnetxs/infer_sdxl_controlnetxs.py create mode 100644 examples/community/controlnetxs/pipeline_controlnet_xs.py create mode 100644 examples/community/controlnetxs/pipeline_controlnet_xs_sd_xl.py diff --git a/examples/community/controlnetxs/controlnetxs.py b/examples/community/controlnetxs/controlnetxs.py new file mode 100644 index 000000000000..c6419b44daeb --- /dev/null +++ b/examples/community/controlnetxs/controlnetxs.py @@ -0,0 +1,1016 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.normalization import GroupNorm + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.lora import LoRACompatibleConv +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + Downsample2D, + ResnetBlock2D, + Transformer2DModel, + UpBlock2D, + Upsample2D, +) +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetXSOutput(BaseOutput): + """ + The output of [`ControlNetXSModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model + output, but is already the final output. + """ + + sample: torch.FloatTensor = None + + +# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetXSModel(ModelMixin, ConfigMixin): + r""" + A ControlNet-XS model + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic + methods implemented for all models (such as downloading or saving). 
+ + Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation + of [`UNet2DConditionModel`] for them. + + Parameters: + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. an image) + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + time_embedding_input_dim (`int`, defaults to 320): + Dimension of input into time embedding. Needs to be same as in the base model. + time_embedding_dim (`int`, defaults to 1280): + Dimension of output from time embedding. Needs to be same as in the base model. + learn_embedding (`bool`, defaults to `False`): + Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of + the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the + control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. + base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): + Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. + """ + + @classmethod + def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): + """ + Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). + + Parameters: + base_model (`UNet2DConditionModel`): + Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. + is_sdxl (`bool`, defaults to `True`): + Whether passed `base_model` is a StableDiffusion-XL model. + """ + + def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): + """ + Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). + The original ControlNet-XS model, however, define the number of attention heads. + That's why compute the dimensions needed to get the correct number of attention heads. + """ + block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] + dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] + return dim_attn_heads + + if is_sdxl: + return ControlNetXSModel.from_unet( + base_model, + time_embedding_mix=0.95, + learn_embedding=True, + size_ratio=0.1, + conditioning_embedding_out_channels=(16, 32, 96, 256), + num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), + ) + else: + return ControlNetXSModel.from_unet( + base_model, + time_embedding_mix=1.0, + learn_embedding=True, + size_ratio=0.0125, + conditioning_embedding_out_channels=(16, 32, 96, 256), + num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), + ) + + @classmethod + def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): + """To create correctly sized connections between base and control model, we need to know + the input and output channels of each subblock. 
+ + Parameters: + unet (`UNet2DConditionModel`): + Unet of which the subblock channels sizes are to be gathered. + base_or_control (`str`): + Needs to be either "base" or "control". If "base", decoder is also considered. + """ + if base_or_control not in ["base", "control"]: + raise ValueError("`base_or_control` needs to be either `base` or `control`") + + channel_sizes = {"down": [], "mid": [], "up": []} + + # input convolution + channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) + + # encoder blocks + for module in unet.down_blocks: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + for r in module.resnets: + channel_sizes["down"].append((r.in_channels, r.out_channels)) + if module.downsamplers: + channel_sizes["down"].append( + (module.downsamplers[0].channels, module.downsamplers[0].out_channels) + ) + else: + raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") + + # middle block + channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) + + # decoder blocks + if base_or_control == "base": + for module in unet.up_blocks: + if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): + for r in module.resnets: + channel_sizes["up"].append((r.in_channels, r.out_channels)) + else: + raise ValueError( + f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." + ) + + return channel_sizes + + @register_to_config + def __init__( + self, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + time_embedding_input_dim: int = 320, + time_embedding_dim: int = 1280, + time_embedding_mix: float = 1.0, + learn_embedding: bool = False, + base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { + "down": [ + (4, 320), + (320, 320), + (320, 320), + (320, 320), + (320, 640), + (640, 640), + (640, 640), + (640, 1280), + (1280, 1280), + ], + "mid": [(1280, 1280)], + "up": [ + (2560, 1280), + (2560, 1280), + (1920, 1280), + (1920, 640), + (1280, 640), + (960, 640), + (960, 320), + (640, 320), + (640, 320), + ], + }, + sample_size: Optional[int] = None, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + norm_num_groups: Optional[int] = 32, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + upcast_attention: bool = False, + ): + super().__init__() + + # 1 - Create control unet + self.control_model = UNet2DConditionModel( + sample_size=sample_size, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + attention_head_dim=num_attention_heads, + use_linear_projection=True, + upcast_attention=upcast_attention, + time_embedding_dim=time_embedding_dim, + ) + + # 2 - Do model surgery on control model + # 2.1 - Allow to use the same time information as the base model + adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim) + + # 2.2 - 
Allow for information infusion from base model + + # We concat the output of each base encoder subblocks to the input of the next control encoder subblock + # (We ignore the 1st element, as it represents the `conv_in`.) + extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]] + it_extra_input_channels = iter(extra_input_channels) + + for b, block in enumerate(self.control_model.down_blocks): + for r in range(len(block.resnets)): + increase_block_input_in_encoder_resnet( + self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels) + ) + + if block.downsamplers: + increase_block_input_in_encoder_downsampler( + self.control_model, block_no=b, by=next(it_extra_input_channels) + ) + + increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1]) + + # 2.3 - Make group norms work with modified channel sizes + adjust_group_norms(self.control_model) + + # 3 - Gather Channel Sizes + self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") + self.ch_inout_base = base_model_channel_sizes + + # 4 - Build connections between base and control model + self.down_zero_convs_out = nn.ModuleList([]) + self.down_zero_convs_in = nn.ModuleList([]) + self.middle_block_out = nn.ModuleList([]) + self.middle_block_in = nn.ModuleList([]) + self.up_zero_convs_out = nn.ModuleList([]) + self.up_zero_convs_in = nn.ModuleList([]) + + for ch_io_base in self.ch_inout_base["down"]: + self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) + for i in range(len(self.ch_inout_ctrl["down"])): + self.down_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) + ) + + self.middle_block_out = self._make_zero_conv( + self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] + ) + + self.up_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) + ) + for i in range(1, len(self.ch_inout_ctrl["down"])): + self.up_zero_convs_out.append( + self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) + ) + + # 5 - Create conditioning hint embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + # In the mininal implementation setting, we only need the control model up to the mid block + del self.control_model.up_blocks + del self.control_model.conv_norm_out + del self.control_model.conv_out + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + conditioning_channels: int = 3, + conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), + controlnet_conditioning_channel_order: str = "rgb", + learn_embedding: bool = False, + time_embedding_mix: float = 1.0, + block_out_channels: Optional[Tuple[int]] = None, + size_ratio: Optional[float] = None, + num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, + norm_num_groups: Optional[int] = None, + ): + r""" + Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. + conditioning_channels (`int`, defaults to 3): + Number of channels of conditioning input (e.g. 
an image) + conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `controlnet_cond_embedding` layer. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + learn_embedding (`bool`, defaults to `False`): + Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation + of the time embeddings of the control and base model with interpolation parameter + `time_embedding_mix**3`. + time_embedding_mix (`float`, defaults to 1.0): + Linear interpolation parameter used if `learn_embedding` is `True`. + block_out_channels (`Tuple[int]`, *optional*): + Down blocks output channels in control model. Either this or `size_ratio` must be given. + size_ratio (float, *optional*): + When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. + Either this or `block_out_channels` must be given. + num_attention_heads (`Union[int, Tuple[int]]`, *optional*): + The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. + norm_num_groups (int, *optional*, defaults to `None`): + The number of groups to use for the normalization of the control unet. If `None`, + `int(unet.config.norm_num_groups * size_ratio)` is taken. + """ + + # Check input + fixed_size = block_out_channels is not None + relative_size = size_ratio is not None + if not (fixed_size ^ relative_size): + raise ValueError( + "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." + ) + + # Create model + if block_out_channels is None: + block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] + + # Check that attention heads and group norms match channel sizes + # - attention heads + def attn_heads_match_channel_sizes(attn_heads, channel_sizes): + if isinstance(attn_heads, (tuple, list)): + return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) + else: + return all(c % attn_heads == 0 for c in channel_sizes) + + num_attention_heads = num_attention_heads or unet.config.attention_head_dim + if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): + raise ValueError( + f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." + ) + + # - group norms + def group_norms_match_channel_sizes(num_groups, channel_sizes): + return all(c % num_groups == 0 for c in channel_sizes) + + if norm_num_groups is None: + if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): + norm_num_groups = unet.config.norm_num_groups + else: + norm_num_groups = min(block_out_channels) + + if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): + print( + f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." + ) + else: + raise ValueError( + f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). 
Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." + ) + + def get_time_emb_input_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_1.in_features + + def get_time_emb_dim(unet: UNet2DConditionModel): + return unet.time_embedding.linear_2.out_features + + # Clone params from base unet if + # (i) it's required to build SD or SDXL, and + # (ii) it's not used for the time embedding (as time embedding of control model is never used), and + # (iii) it's not set further below anyway + to_keep = [ + "cross_attention_dim", + "down_block_types", + "sample_size", + "transformer_layers_per_block", + "up_block_types", + "upcast_attention", + ] + kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} + kwargs.update(block_out_channels=block_out_channels) + kwargs.update(num_attention_heads=num_attention_heads) + kwargs.update(norm_num_groups=norm_num_groups) + + # Add controlnetxs-specific params + kwargs.update( + conditioning_channels=conditioning_channels, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + time_embedding_input_dim=get_time_emb_input_dim(unet), + time_embedding_dim=get_time_emb_dim(unet), + time_embedding_mix=time_embedding_mix, + learn_embedding=learn_embedding, + base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + return cls(**kwargs) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + return self.control_model.attn_processors + + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + self.control_model.set_attn_processor(processor, _remove_lora) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.control_model.set_default_attn_processor() + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
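+
+        Example (a minimal usage sketch; assumes `controlnet` is an already instantiated `ControlNetXSModel`):
+
+        ```py
+        controlnet.set_attention_slice("auto")
+        ```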
+ """ + self.control_model.set_attention_slice(slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (UNet2DConditionModel)): + if value: + module.enable_gradient_checkpointing() + else: + module.disable_gradient_checkpointing() + + def forward( + self, + base_model: UNet2DConditionModel, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + return_dict: bool = True, + ) -> Union[ControlNetXSOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + base_model (`UNet2DConditionModel`): + The base unet model we want to control. + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + How much the control model affects the base model outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... 
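+            # nothing to do here: the conditioning image is already expected in RGB order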
+ elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # scale control strength + n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) + scale_list = torch.full((n_connections,), conditioning_scale) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = base_model.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + if self.config.learn_embedding: + ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) + base_temb = base_model.time_embedding(t_emb, timestep_cond) + interpolation_param = self.config.time_embedding_mix**0.3 + + temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) + else: + temb = base_model.time_embedding(t_emb) + + # added time & text embeddings + aug_emb = None + + if base_model.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if base_model.config.class_embed_type == "timestep": + class_labels = base_model.time_proj(class_labels) + + class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) + temb = temb + class_emb + + if base_model.config.addition_embed_type is not None: + if base_model.config.addition_embed_type == "text": + aug_emb = base_model.add_embedding(encoder_hidden_states) + elif base_model.config.addition_embed_type == "text_image": + raise NotImplementedError() + elif base_model.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = base_model.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(temb.dtype) + aug_emb = base_model.add_embedding(add_embeds) + elif 
base_model.config.addition_embed_type == "image": + raise NotImplementedError() + elif base_model.config.addition_embed_type == "image_hint": + raise NotImplementedError() + + temb = temb + aug_emb if aug_emb is not None else temb + + # text embeddings + cemb = encoder_hidden_states + + # Preparation + guided_hint = self.controlnet_cond_embedding(controlnet_cond) + + h_ctrl = h_base = sample + hs_base, hs_ctrl = [], [] + it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( + iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) + ) + scales = iter(scale_list) + + base_down_subblocks = to_sub_blocks(base_model.down_blocks) + ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) + base_mid_subblocks = to_sub_blocks([base_model.mid_block]) + ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) + base_up_subblocks = to_sub_blocks(base_model.up_blocks) + + # Cross Control + # 0 - conv in + h_base = base_model.conv_in(h_base) + h_ctrl = self.control_model.conv_in(h_ctrl) + if guided_hint is not None: + h_ctrl += guided_hint + h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + # 1 - down + for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base + hs_base.append(h_base) + hs_ctrl.append(h_ctrl) + + # 2 - mid + h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl + for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock + h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock + h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base + + # 3 - up + for i, m_base in enumerate(base_up_subblocks): + h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder + h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder + h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) + + h_base = base_model.conv_norm_out(h_base) + h_base = base_model.conv_act(h_base) + h_base = base_model.conv_out(h_base) + + if not return_dict: + return h_base + + return ControlNetXSOutput(sample=h_base) + + def _make_zero_conv(self, in_channels, out_channels=None): + # keep running track of channels sizes + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + + return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + @torch.no_grad() + def _check_if_vae_compatible(self, vae: AutoencoderKL): + condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) + vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + compatible = condition_downscale_factor == vae_downscale_factor + return compatible, condition_downscale_factor, vae_downscale_factor + + +class SubBlock(nn.ModuleList): + """A SubBlock is the largest piece of 
either base or control model, that is executed independently of the other model respectively. + Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base. + """ + + def __init__(self, ms, *args, **kwargs): + if not is_iterable(ms): + ms = [ms] + super().__init__(ms, *args, **kwargs) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor, + cemb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """Iterate through children and pass correct information to each.""" + for m in self: + if isinstance(m, ResnetBlock2D): + x = m(x, temb) + elif isinstance(m, Transformer2DModel): + x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample + elif isinstance(m, Downsample2D): + x = m(x) + elif isinstance(m, Upsample2D): + x = m(x) + else: + raise ValueError( + f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`" + ) + + return x + + +def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int): + unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim) + + +def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by): + """Increase channels sizes to allow for additional concatted information from base model""" + r = unet.down_blocks[block_no].resnets[resnet_idx] + old_norm1, old_conv1 = r.norm1, r.conv1 + # norm + norm_args = "num_groups num_channels eps affine".split(" ") + for a in norm_args: + assert hasattr(old_norm1, a) + norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} + norm_kwargs["num_channels"] += by # surgery done here + # conv1 + conv1_args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + conv1_args.append("lora_layer") + + for a in conv1_args: + assert hasattr(old_conv1, a) + + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} + conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
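+    # widen conv1 so it accepts the extra channels concatenated in from the base model's encoder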
+ conv1_kwargs["in_channels"] += by # surgery done here + # conv_shortcut + # as we changed the input size of the block, the input and output sizes are likely different, + # therefore we need a conv_shortcut (simply adding won't work) + conv_shortcut_args_kwargs = { + "in_channels": conv1_kwargs["in_channels"], + "out_channels": conv1_kwargs["out_channels"], + # default arguments from resnet.__init__ + "kernel_size": 1, + "stride": 1, + "padding": 0, + "bias": True, + } + # swap old with new modules + unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs) + unet.down_blocks[block_no].resnets[resnet_idx].conv1 = ( + nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) + ) + unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = ( + nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) + ) + unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here + + +def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by): + """Increase channels sizes to allow for additional concatted information from base model""" + old_down = unet.down_blocks[block_no].downsamplers[0].conv + + args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + args.append("lora_layer") + + for a in args: + assert hasattr(old_down, a) + kwargs = {a: getattr(old_down, a) for a in args} + kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor. + kwargs["in_channels"] += by # surgery done here + # swap old with new modules + unet.down_blocks[block_no].downsamplers[0].conv = ( + nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs) + ) + unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here + + +def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): + """Increase channels sizes to allow for additional concatted information from base model""" + m = unet.mid_block.resnets[0] + old_norm1, old_conv1 = m.norm1, m.conv1 + # norm + norm_args = "num_groups num_channels eps affine".split(" ") + for a in norm_args: + assert hasattr(old_norm1, a) + norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} + norm_kwargs["num_channels"] += by # surgery done here + conv1_args = [ + "in_channels", + "out_channels", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "bias", + "padding_mode", + ] + if not USE_PEFT_BACKEND: + conv1_args.append("lora_layer") + + conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} + conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
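+    # same surgery for the mid-block resnet: widen conv1 to take the extra channels concatenated in from the base encoder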
+ conv1_kwargs["in_channels"] += by # surgery done here + # conv_shortcut + # as we changed the input size of the block, the input and output sizes are likely different, + # therefore we need a conv_shortcut (simply adding won't work) + conv_shortcut_args_kwargs = { + "in_channels": conv1_kwargs["in_channels"], + "out_channels": conv1_kwargs["out_channels"], + # default arguments from resnet.__init__ + "kernel_size": 1, + "stride": 1, + "padding": 0, + "bias": True, + } + # swap old with new modules + unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) + unet.mid_block.resnets[0].conv1 = ( + nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) + ) + unet.mid_block.resnets[0].conv_shortcut = ( + nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) + ) + unet.mid_block.resnets[0].in_channels += by # surgery done here + + +def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): + def find_denominator(number, start): + if start >= number: + return number + while start != 0: + residual = number % start + if residual == 0: + return start + start -= 1 + + for block in [*unet.down_blocks, unet.mid_block]: + # resnets + for r in block.resnets: + if r.norm1.num_groups < max_num_group: + r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) + + if r.norm2.num_groups < max_num_group: + r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) + + # transformers + if hasattr(block, "attentions"): + for a in block.attentions: + if a.norm.num_groups < max_num_group: + a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) + + +def is_iterable(o): + if isinstance(o, str): + return False + try: + iter(o) + return True + except TypeError: + return False + + +def to_sub_blocks(blocks): + if not is_iterable(blocks): + blocks = [blocks] + + sub_blocks = [] + + for b in blocks: + if hasattr(b, "resnets"): + if hasattr(b, "attentions") and b.attentions is not None: + for r, a in zip(b.resnets, b.attentions): + sub_blocks.append([r, a]) + + num_resnets = len(b.resnets) + num_attns = len(b.attentions) + + if num_resnets > num_attns: + # we can have more resnets than attentions, so add each resnet as separate subblock + for i in range(num_attns, num_resnets): + sub_blocks.append([b.resnets[i]]) + else: + for r in b.resnets: + sub_blocks.append([r]) + + # upsamplers are part of the same subblock + if hasattr(b, "upsamplers") and b.upsamplers is not None: + for u in b.upsamplers: + sub_blocks[-1].extend([u]) + + # downsamplers are own subblock + if hasattr(b, "downsamplers") and b.downsamplers is not None: + for d in b.downsamplers: + sub_blocks.append([d]) + + return list(map(SubBlock, sub_blocks)) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/examples/community/controlnetxs/infer_sd_controlnetxs.py b/examples/community/controlnetxs/infer_sd_controlnetxs.py new file mode 100644 index 000000000000..f456e74db8e1 --- /dev/null +++ b/examples/community/controlnetxs/infer_sd_controlnetxs.py @@ -0,0 +1,54 @@ +# !pip install opencv-python transformers accelerate +import argparse + +import cv2 +import numpy as np +import torch +from controlnetxs import ControlNetXSModel +from PIL import Image +from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + +from diffusers.utils import load_image + + +parser = argparse.ArgumentParser() 
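+# Example invocation (hypothetical local path; every flag is optional and falls back to the defaults below):
+#   python infer_sd_controlnetxs.py --prompt "a bird" --image_path ./input.png --num_inference_steps 30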
+parser.add_argument("--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting") +parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") +parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) +parser.add_argument("--image_path", type=str, default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") +parser.add_argument("--num_inference_steps", type=int, default=50) + +args = parser.parse_args() + +prompt = args.prompt +negative_prompt = args.negative_prompt +# download an image +image = load_image(args.image_path) + +# initialize the models and pipeline +controlnet_conditioning_scale = args.controlnet_conditioning_scale +controlnet = ControlNetXSModel.from_pretrained( + "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 +) +pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 +) +pipe.enable_model_cpu_offload() + +# get canny image +image = np.array(image) +image = cv2.Canny(image, 100, 200) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +canny_image = Image.fromarray(image) + +num_inference_steps = args.num_inference_steps + +# generate image +image = pipe( + prompt, + controlnet_conditioning_scale=controlnet_conditioning_scale, + image=canny_image, + num_inference_steps=num_inference_steps +).images[0] +image.save("cnxs_sd.canny.png") diff --git a/examples/community/controlnetxs/infer_sdxl_controlnetxs.py b/examples/community/controlnetxs/infer_sdxl_controlnetxs.py new file mode 100644 index 000000000000..b9ace7959d7a --- /dev/null +++ b/examples/community/controlnetxs/infer_sdxl_controlnetxs.py @@ -0,0 +1,53 @@ +# !pip install opencv-python transformers accelerate +import argparse + +import cv2 +import numpy as np +import torch +from controlnetxs import ControlNetXSModel +from PIL import Image +from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + +from diffusers.utils import load_image + + +parser = argparse.ArgumentParser() +parser.add_argument("--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting") +parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") +parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) +parser.add_argument("--image_path", type=str, default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") +parser.add_argument("--num_inference_steps", type=int, default=50) + +args = parser.parse_args() + +prompt = args.prompt +negative_prompt = args.negative_prompt +# download an image +image = load_image(args.image_path) +# initialize the models and pipeline +controlnet_conditioning_scale = args.controlnet_conditioning_scale +controlnet = ControlNetXSModel.from_pretrained( + "UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 +) +pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 +) +pipe.enable_model_cpu_offload() + +# get canny image +image = np.array(image) +image = cv2.Canny(image, 100, 200) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +canny_image = Image.fromarray(image) + +num_inference_steps = args.num_inference_steps + +# 
generate image +image = pipe( + prompt, + controlnet_conditioning_scale=controlnet_conditioning_scale, + image=canny_image, + num_inference_steps=num_inference_steps +).images[0] +image.save("cnxs_sd.canny.png") diff --git a/examples/community/controlnetxs/pipeline_controlnet_xs.py b/examples/community/controlnetxs/pipeline_controlnet_xs.py new file mode 100644 index 000000000000..8e95306da584 --- /dev/null +++ b/examples/community/controlnetxs/pipeline_controlnet_xs.py @@ -0,0 +1,901 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from controlnetxs import ControlNetXSModel +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. 
+ unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetXSModel`]): + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetXSModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( + vae + ) + if not vae_compatible: + raise ValueError( + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." 
+ ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + 
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
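# --- Illustrative sketch (not part of this patch) ----------------------------
# Batch handling in `prepare_image` above: a single control image is repeated
# to the effective batch size, then duplicated once more when classifier-free
# guidance is enabled so it matches the doubled latent batch. Toy shapes only.
import torch

control = torch.rand(1, 3, 512, 512)  # one conditioning image
batch_size, num_images_per_prompt = 2, 2
control = control.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
control = torch.cat([control] * 2)    # CFG: unconditional + conditional halves
assert control.shape == (2 * batch_size * num_images_per_prompt, 3, 512, 512)
# ------------------------------------------------------------------------------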
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. 
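# --- Illustrative sketch (not part of this patch) ----------------------------
# How `control_guidance_start` / `control_guidance_end` behave: a denoising
# step is routed through the ControlNet only while the fraction of completed
# steps lies inside [start, end]; outside that window the plain UNet is used.
# Standalone restatement of the check used in the loop below, assuming 50
# steps with start=0.0 and end=0.8:
num_inference_steps = 50
control_guidance_start, control_guidance_end = 0.0, 0.8

def is_controlled(i: int, n: int = num_inference_steps) -> bool:
    """Mirrors the per-step `dont_control` condition in the denoising loop."""
    dont_control = i / n < control_guidance_start or (i + 1) / n > control_guidance_end
    return not dont_control

controlled = [i for i in range(num_inference_steps) if is_controlled(i)]
assert controlled[0] == 0 and controlled[-1] == 39  # steps 0..39 use the ControlNet
# ------------------------------------------------------------------------------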
+ """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + if isinstance(controlnet, ControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + dont_control = ( + i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + ) + if dont_control: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + ).sample + else: + noise_pred = self.controlnet( + base_model=self.unet, + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/community/controlnetxs/pipeline_controlnet_xs_sd_xl.py new file mode 100644 index 000000000000..be888d7e1145 --- /dev/null +++ b/examples/community/controlnetxs/pipeline_controlnet_xs_sd_xl.py @@ -0,0 +1,1078 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
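# --- Illustrative usage sketch (not part of this patch) ----------------------
# End-to-end use of the `StableDiffusionControlNetXSPipeline` defined above.
# Import paths assume the script runs from examples/community/controlnetxs;
# the ControlNet-XS checkpoint id and the control-image URL are placeholders,
# not verified repositories.
import torch
from diffusers.utils import load_image

from controlnetxs import ControlNetXSModel
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline

controlnet = ControlNetXSModel.from_pretrained(
    "path/to/controlnet-xs-sd-canny", torch_dtype=torch.float16  # placeholder id
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

canny_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
result = pipe(
    "a futuristic cityscape at night",
    image=canny_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=0.8,
).images[0]
result.save("controlnet_xs_out.png")
# ------------------------------------------------------------------------------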
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.import_utils import is_invisible_watermark_available +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLControlNetXSPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. 
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetXSModel`]: + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetXSModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( + vae + ) + if not vae_compatible: + raise ValueError( + f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. 
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
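# --- Illustrative sketch (not part of this patch) ----------------------------
# For SDXL the default already reads hidden_states[-2], so a user-supplied
# clip_skip is offset by 2 (statement below) instead of the offset of 1 used in
# the SD pipeline above. The per-encoder embeddings collected in this loop are
# later concatenated along the feature axis; with the standard SDXL text
# encoders (hidden sizes 768 and 1280) that yields the 2048-dim context the
# SDXL UNet expects. Toy shapes only.
import torch

embeds_clip_vit_l = torch.randn(1, 77, 768)      # text_encoder branch
embeds_openclip_bigg = torch.randn(1, 77, 1280)  # text_encoder_2 branch
context = torch.concat([embeds_clip_vit_l, embeds_openclip_bigg], dim=-1)
assert context.shape == (1, 77, 2048)
# ------------------------------------------------------------------------------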
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = 
negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetXSModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetXSModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
+ ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
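# --- Illustrative sketch (not part of this patch) ----------------------------
# Micro-conditioning vector built by `_get_add_time_ids` above: six integers
# (original_size + crops_coords_top_left + target_size), each embedded with
# `addition_time_embed_dim` channels and concatenated with the pooled text
# embedding. Assuming the stock SDXL base config (addition_time_embed_dim=256,
# text_encoder_2 projection_dim=1280), the expected
# `add_embedding.linear_1.in_features` works out to 2816.
original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
assert add_time_ids == [1024, 1024, 0, 0, 1024, 1024]

addition_time_embed_dim, projection_dim = 256, 1280  # assumed SDXL base values
assert addition_time_embed_dim * len(add_time_ids) + projection_dim == 2816
# ------------------------------------------------------------------------------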
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. 
+ callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + dont_control = ( + i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end + ) + if dont_control: + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + ).sample + else: + noise_pred = self.controlnet( + base_model=self.unet, + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, 
t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) From 9489cd43b931e27d6924bf65c6c9a5f427d16375 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 09:55:18 +0000 Subject: [PATCH 03/11] update --- src/diffusers/__init__.py | 5 - src/diffusers/models/__init__.py | 2 - src/diffusers/models/controlnetxs.py | 1016 --------------- src/diffusers/pipelines/__init__.py | 10 - .../controlnet_xs/pipeline_controlnet_xs.py | 946 -------------- .../pipeline_controlnet_xs_sd_xl.py | 1119 ----------------- 6 files changed, 3098 deletions(-) delete mode 100644 src/diffusers/models/controlnetxs.py delete mode 100644 src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py delete mode 100644 src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 10c5b0f46565..d6778961dcba 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -80,7 +80,6 @@ "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", - "ControlNetXSModel", "Kandinsky3UNet", "ModelMixin", "MotionAdapter", @@ -256,7 +255,6 @@ "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPipeline", - "StableDiffusionControlNetXSPipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", @@ -280,7 +278,6 @@ "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", @@ -617,7 +614,6 @@ StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, - StableDiffusionControlNetXSPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, @@ -641,7 +637,6 @@ StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, - 
StableDiffusionXLControlNetXSPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 6e7fe72bc949..36dbe14c5053 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,7 +32,6 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] - _import_structure["controlnetxs"] = ["ControlNetXSModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] @@ -67,7 +66,6 @@ ConsistencyDecoderVAE, ) from .controlnet import ControlNetModel - from .controlnetxs import ControlNetXSModel from .dual_transformer_2d import DualTransformer2DModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin diff --git a/src/diffusers/models/controlnetxs.py b/src/diffusers/models/controlnetxs.py deleted file mode 100644 index 41fe624b9b59..000000000000 --- a/src/diffusers/models/controlnetxs.py +++ /dev/null @@ -1,1016 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import functional as F -from torch.nn.modules.normalization import GroupNorm - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging -from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor -from .autoencoders import AutoencoderKL -from .lora import LoRACompatibleConv -from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - Downsample2D, - ResnetBlock2D, - Transformer2DModel, - UpBlock2D, - Upsample2D, -) -from .unet_2d_condition import UNet2DConditionModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetXSOutput(BaseOutput): - """ - The output of [`ControlNetXSModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model - output, but is already the final output. - """ - - sample: torch.FloatTensor = None - - -# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. 
This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - -class ControlNetXSModel(ModelMixin, ConfigMixin): - r""" - A ControlNet-XS model - - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic - methods implemented for all models (such as downloading or saving). - - Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation - of [`UNet2DConditionModel`] for them. - - Parameters: - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. - time_embedding_input_dim (`int`, defaults to 320): - Dimension of input into time embedding. Needs to be same as in the base model. - time_embedding_dim (`int`, defaults to 1280): - Dimension of output from time embedding. Needs to be same as in the base model. - learn_embedding (`bool`, defaults to `False`): - Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of - the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the - control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. - base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): - Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. - """ - - @classmethod - def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): - """ - Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). 
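A short sketch of how `init_original` is used in practice, assuming the class is imported from the community copy of this module; the import path and hub id below are illustrative assumptions.

    from diffusers import UNet2DConditionModel
    from controlnetxs import ControlNetXSModel  # community copy of this module; path is an assumption

    base_unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )
    # For SDXL this applies the paper settings: learn_embedding=True,
    # time_embedding_mix=0.95, size_ratio=0.1, plus matching attention-head dims.
    controlnet = ControlNetXSModel.init_original(base_unet, is_sdxl=True)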
- - Parameters: - base_model (`UNet2DConditionModel`): - Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. - is_sdxl (`bool`, defaults to `True`): - Whether passed `base_model` is a StableDiffusion-XL model. - """ - - def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): - """ - Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). - The original ControlNet-XS model, however, define the number of attention heads. - That's why compute the dimensions needed to get the correct number of attention heads. - """ - block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] - dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] - return dim_attn_heads - - if is_sdxl: - return ControlNetXSModel.from_unet( - base_model, - time_embedding_mix=0.95, - learn_embedding=True, - size_ratio=0.1, - conditioning_embedding_out_channels=(16, 32, 96, 256), - num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), - ) - else: - return ControlNetXSModel.from_unet( - base_model, - time_embedding_mix=1.0, - learn_embedding=True, - size_ratio=0.0125, - conditioning_embedding_out_channels=(16, 32, 96, 256), - num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), - ) - - @classmethod - def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): - """To create correctly sized connections between base and control model, we need to know - the input and output channels of each subblock. - - Parameters: - unet (`UNet2DConditionModel`): - Unet of which the subblock channels sizes are to be gathered. - base_or_control (`str`): - Needs to be either "base" or "control". If "base", decoder is also considered. - """ - if base_or_control not in ["base", "control"]: - raise ValueError("`base_or_control` needs to be either `base` or `control`") - - channel_sizes = {"down": [], "mid": [], "up": []} - - # input convolution - channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) - - # encoder blocks - for module in unet.down_blocks: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - for r in module.resnets: - channel_sizes["down"].append((r.in_channels, r.out_channels)) - if module.downsamplers: - channel_sizes["down"].append( - (module.downsamplers[0].channels, module.downsamplers[0].out_channels) - ) - else: - raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") - - # middle block - channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) - - # decoder blocks - if base_or_control == "base": - for module in unet.up_blocks: - if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): - for r in module.resnets: - channel_sizes["up"].append((r.in_channels, r.out_channels)) - else: - raise ValueError( - f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." 
- ) - - return channel_sizes - - @register_to_config - def __init__( - self, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - time_embedding_input_dim: int = 320, - time_embedding_dim: int = 1280, - time_embedding_mix: float = 1.0, - learn_embedding: bool = False, - base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { - "down": [ - (4, 320), - (320, 320), - (320, 320), - (320, 320), - (320, 640), - (640, 640), - (640, 640), - (640, 1280), - (1280, 1280), - ], - "mid": [(1280, 1280)], - "up": [ - (2560, 1280), - (2560, 1280), - (1920, 1280), - (1920, 640), - (1280, 640), - (960, 640), - (960, 320), - (640, 320), - (640, 320), - ], - }, - sample_size: Optional[int] = None, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - norm_num_groups: Optional[int] = 32, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - upcast_attention: bool = False, - ): - super().__init__() - - # 1 - Create control unet - self.control_model = UNet2DConditionModel( - sample_size=sample_size, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - attention_head_dim=num_attention_heads, - use_linear_projection=True, - upcast_attention=upcast_attention, - time_embedding_dim=time_embedding_dim, - ) - - # 2 - Do model surgery on control model - # 2.1 - Allow to use the same time information as the base model - adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim) - - # 2.2 - Allow for information infusion from base model - - # We concat the output of each base encoder subblocks to the input of the next control encoder subblock - # (We ignore the 1st element, as it represents the `conv_in`.) 
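As a toy illustration of the comment above (the shapes are made up; in the real model the base features additionally pass through a 1x1 zero-convolution before the concatenation):

    import torch

    base_feat = torch.randn(1, 320, 64, 64)  # output of a base encoder subblock
    ctrl_feat = torch.randn(1, 32, 64, 64)   # much thinner control stream (size_ratio = 0.1)

    # The next control subblock therefore has to accept 32 + 320 input channels,
    # which is what the channel "surgery" below prepares for.
    ctrl_in = torch.cat([ctrl_feat, base_feat], dim=1)
    print(ctrl_in.shape)  # torch.Size([1, 352, 64, 64])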
- extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]] - it_extra_input_channels = iter(extra_input_channels) - - for b, block in enumerate(self.control_model.down_blocks): - for r in range(len(block.resnets)): - increase_block_input_in_encoder_resnet( - self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels) - ) - - if block.downsamplers: - increase_block_input_in_encoder_downsampler( - self.control_model, block_no=b, by=next(it_extra_input_channels) - ) - - increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1]) - - # 2.3 - Make group norms work with modified channel sizes - adjust_group_norms(self.control_model) - - # 3 - Gather Channel Sizes - self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") - self.ch_inout_base = base_model_channel_sizes - - # 4 - Build connections between base and control model - self.down_zero_convs_out = nn.ModuleList([]) - self.down_zero_convs_in = nn.ModuleList([]) - self.middle_block_out = nn.ModuleList([]) - self.middle_block_in = nn.ModuleList([]) - self.up_zero_convs_out = nn.ModuleList([]) - self.up_zero_convs_in = nn.ModuleList([]) - - for ch_io_base in self.ch_inout_base["down"]: - self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) - for i in range(len(self.ch_inout_ctrl["down"])): - self.down_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) - ) - - self.middle_block_out = self._make_zero_conv( - self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] - ) - - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) - ) - for i in range(1, len(self.ch_inout_ctrl["down"])): - self.up_zero_convs_out.append( - self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) - ) - - # 5 - Create conditioning hint embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # In the mininal implementation setting, we only need the control model up to the mid block - del self.control_model.up_blocks - del self.control_model.conv_norm_out - del self.control_model.conv_out - - @classmethod - def from_unet( - cls, - unet: UNet2DConditionModel, - conditioning_channels: int = 3, - conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), - controlnet_conditioning_channel_order: str = "rgb", - learn_embedding: bool = False, - time_embedding_mix: float = 1.0, - block_out_channels: Optional[Tuple[int]] = None, - size_ratio: Optional[float] = None, - num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, - norm_num_groups: Optional[int] = None, - ): - r""" - Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. - conditioning_channels (`int`, defaults to 3): - Number of channels of conditioning input (e.g. an image) - conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `controlnet_cond_embedding` layer. 
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - learn_embedding (`bool`, defaults to `False`): - Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation - of the time embeddings of the control and base model with interpolation parameter - `time_embedding_mix**3`. - time_embedding_mix (`float`, defaults to 1.0): - Linear interpolation parameter used if `learn_embedding` is `True`. - block_out_channels (`Tuple[int]`, *optional*): - Down blocks output channels in control model. Either this or `size_ratio` must be given. - size_ratio (float, *optional*): - When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. - Either this or `block_out_channels` must be given. - num_attention_heads (`Union[int, Tuple[int]]`, *optional*): - The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. - norm_num_groups (int, *optional*, defaults to `None`): - The number of groups to use for the normalization of the control unet. If `None`, - `int(unet.config.norm_num_groups * size_ratio)` is taken. - """ - - # Check input - fixed_size = block_out_channels is not None - relative_size = size_ratio is not None - if not (fixed_size ^ relative_size): - raise ValueError( - "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)." - ) - - # Create model - if block_out_channels is None: - block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] - - # Check that attention heads and group norms match channel sizes - # - attention heads - def attn_heads_match_channel_sizes(attn_heads, channel_sizes): - if isinstance(attn_heads, (tuple, list)): - return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) - else: - return all(c % attn_heads == 0 for c in channel_sizes) - - num_attention_heads = num_attention_heads or unet.config.attention_head_dim - if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): - raise ValueError( - f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." - ) - - # - group norms - def group_norms_match_channel_sizes(num_groups, channel_sizes): - return all(c % num_groups == 0 for c in channel_sizes) - - if norm_num_groups is None: - if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): - norm_num_groups = unet.config.norm_num_groups - else: - norm_num_groups = min(block_out_channels) - - if group_norms_match_channel_sizes(norm_num_groups, block_out_channels): - print( - f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information." - ) - else: - raise ValueError( - f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels." 
- ) - - def get_time_emb_input_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_1.in_features - - def get_time_emb_dim(unet: UNet2DConditionModel): - return unet.time_embedding.linear_2.out_features - - # Clone params from base unet if - # (i) it's required to build SD or SDXL, and - # (ii) it's not used for the time embedding (as time embedding of control model is never used), and - # (iii) it's not set further below anyway - to_keep = [ - "cross_attention_dim", - "down_block_types", - "sample_size", - "transformer_layers_per_block", - "up_block_types", - "upcast_attention", - ] - kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} - kwargs.update(block_out_channels=block_out_channels) - kwargs.update(num_attention_heads=num_attention_heads) - kwargs.update(norm_num_groups=norm_num_groups) - - # Add controlnetxs-specific params - kwargs.update( - conditioning_channels=conditioning_channels, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - time_embedding_input_dim=get_time_emb_input_dim(unet), - time_embedding_dim=get_time_emb_dim(unet), - time_embedding_mix=time_embedding_mix, - learn_embedding=learn_embedding, - base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - ) - - return cls(**kwargs) - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - return self.control_model.attn_processors - - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - self.control_model.set_attn_processor(processor, _remove_lora) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - self.control_model.set_default_attn_processor() - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. 
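A quick numeric illustration of the `slice_size` rule described above (the numbers are made up):

    attention_head_dim = 8

    # A numeric `slice_size` must divide `attention_head_dim`; attention is then
    # computed in `attention_head_dim // slice_size` slices.
    for slice_size in (2, 4, 8):
        assert attention_head_dim % slice_size == 0
        print(f"slice_size={slice_size} -> {attention_head_dim // slice_size} slice(s)")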
- """ - self.control_model.set_attention_slice(slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (UNet2DConditionModel)): - if value: - module.enable_gradient_checkpointing() - else: - module.disable_gradient_checkpointing() - - def forward( - self, - base_model: UNet2DConditionModel, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - return_dict: bool = True, - ) -> Union[ControlNetXSOutput, Tuple]: - """ - The [`ControlNetModel`] forward method. - - Args: - base_model (`UNet2DConditionModel`): - The base unet model we want to control. - sample (`torch.FloatTensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - controlnet_cond (`torch.FloatTensor`): - The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - How much the control model affects the base model outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - - Returns: - [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - """ - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... 
- elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # scale control strength - n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) - scale_list = torch.full((n_connections,), conditioning_scale) - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = base_model.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - if self.config.learn_embedding: - ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) - base_temb = base_model.time_embedding(t_emb, timestep_cond) - interpolation_param = self.config.time_embedding_mix**0.3 - - temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) - else: - temb = base_model.time_embedding(t_emb) - - # added time & text embeddings - aug_emb = None - - if base_model.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if base_model.config.class_embed_type == "timestep": - class_labels = base_model.time_proj(class_labels) - - class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) - temb = temb + class_emb - - if base_model.config.addition_embed_type is not None: - if base_model.config.addition_embed_type == "text": - aug_emb = base_model.add_embedding(encoder_hidden_states) - elif base_model.config.addition_embed_type == "text_image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = base_model.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(temb.dtype) - aug_emb = base_model.add_embedding(add_embeds) - elif 
base_model.config.addition_embed_type == "image": - raise NotImplementedError() - elif base_model.config.addition_embed_type == "image_hint": - raise NotImplementedError() - - temb = temb + aug_emb if aug_emb is not None else temb - - # text embeddings - cemb = encoder_hidden_states - - # Preparation - guided_hint = self.controlnet_cond_embedding(controlnet_cond) - - h_ctrl = h_base = sample - hs_base, hs_ctrl = [], [] - it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map( - iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) - ) - scales = iter(scale_list) - - base_down_subblocks = to_sub_blocks(base_model.down_blocks) - ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) - base_mid_subblocks = to_sub_blocks([base_model.mid_block]) - ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) - base_up_subblocks = to_sub_blocks(base_model.up_blocks) - - # Cross Control - # 0 - conv in - h_base = base_model.conv_in(h_base) - h_ctrl = self.control_model.conv_in(h_ctrl) - if guided_hint is not None: - h_ctrl += guided_hint - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base - - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - - # 1 - down - for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks): - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base - hs_base.append(h_base) - hs_ctrl.append(h_ctrl) - - # 2 - mid - h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl - for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks): - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock - h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock - h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base - - # 3 - up - for i, m_base in enumerate(base_up_subblocks): - h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder - h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder - h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) - - h_base = base_model.conv_norm_out(h_base) - h_base = base_model.conv_act(h_base) - h_base = base_model.conv_out(h_base) - - if not return_dict: - return h_base - - return ControlNetXSOutput(sample=h_base) - - def _make_zero_conv(self, in_channels, out_channels=None): - # keep running track of channels sizes - self.in_channels = in_channels - self.out_channels = out_channels or in_channels - - return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)) - - @torch.no_grad() - def _check_if_vae_compatible(self, vae: AutoencoderKL): - condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1) - vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - compatible = condition_downscale_factor == vae_downscale_factor - return compatible, condition_downscale_factor, vae_downscale_factor - - -class SubBlock(nn.ModuleList): - """A SubBlock is the largest piece of 
either base or control model, that is executed independently of the other model respectively. - Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base. - """ - - def __init__(self, ms, *args, **kwargs): - if not is_iterable(ms): - ms = [ms] - super().__init__(ms, *args, **kwargs) - - def forward( - self, - x: torch.Tensor, - temb: torch.Tensor, - cemb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - """Iterate through children and pass correct information to each.""" - for m in self: - if isinstance(m, ResnetBlock2D): - x = m(x, temb) - elif isinstance(m, Transformer2DModel): - x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample - elif isinstance(m, Downsample2D): - x = m(x) - elif isinstance(m, Upsample2D): - x = m(x) - else: - raise ValueError( - f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`" - ) - - return x - - -def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int): - unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim) - - -def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by): - """Increase channels sizes to allow for additional concatted information from base model""" - r = unet.down_blocks[block_no].resnets[resnet_idx] - old_norm1, old_conv1 = r.norm1, r.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - # conv1 - conv1_args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - conv1_args.append("lora_layer") - - for a in conv1_args: - assert hasattr(old_conv1, a) - - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
- conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs) - unet.down_blocks[block_no].resnets[resnet_idx].conv1 = ( - nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) - ) - unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = ( - nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) - ) - unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here - - -def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by): - """Increase channels sizes to allow for additional concatted information from base model""" - old_down = unet.down_blocks[block_no].downsamplers[0].conv - - args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - args.append("lora_layer") - - for a in args: - assert hasattr(old_down, a) - kwargs = {a: getattr(old_down, a) for a in args} - kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor. - kwargs["in_channels"] += by # surgery done here - # swap old with new modules - unet.down_blocks[block_no].downsamplers[0].conv = ( - nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs) - ) - unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here - - -def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by): - """Increase channels sizes to allow for additional concatted information from base model""" - m = unet.mid_block.resnets[0] - old_norm1, old_conv1 = m.norm1, m.conv1 - # norm - norm_args = "num_groups num_channels eps affine".split(" ") - for a in norm_args: - assert hasattr(old_norm1, a) - norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args} - norm_kwargs["num_channels"] += by # surgery done here - conv1_args = [ - "in_channels", - "out_channels", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "bias", - "padding_mode", - ] - if not USE_PEFT_BACKEND: - conv1_args.append("lora_layer") - - conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args} - conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor. 
- conv1_kwargs["in_channels"] += by # surgery done here - # conv_shortcut - # as we changed the input size of the block, the input and output sizes are likely different, - # therefore we need a conv_shortcut (simply adding won't work) - conv_shortcut_args_kwargs = { - "in_channels": conv1_kwargs["in_channels"], - "out_channels": conv1_kwargs["out_channels"], - # default arguments from resnet.__init__ - "kernel_size": 1, - "stride": 1, - "padding": 0, - "bias": True, - } - # swap old with new modules - unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs) - unet.mid_block.resnets[0].conv1 = ( - nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs) - ) - unet.mid_block.resnets[0].conv_shortcut = ( - nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs) - ) - unet.mid_block.resnets[0].in_channels += by # surgery done here - - -def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32): - def find_denominator(number, start): - if start >= number: - return number - while start != 0: - residual = number % start - if residual == 0: - return start - start -= 1 - - for block in [*unet.down_blocks, unet.mid_block]: - # resnets - for r in block.resnets: - if r.norm1.num_groups < max_num_group: - r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group) - - if r.norm2.num_groups < max_num_group: - r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group) - - # transformers - if hasattr(block, "attentions"): - for a in block.attentions: - if a.norm.num_groups < max_num_group: - a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group) - - -def is_iterable(o): - if isinstance(o, str): - return False - try: - iter(o) - return True - except TypeError: - return False - - -def to_sub_blocks(blocks): - if not is_iterable(blocks): - blocks = [blocks] - - sub_blocks = [] - - for b in blocks: - if hasattr(b, "resnets"): - if hasattr(b, "attentions") and b.attentions is not None: - for r, a in zip(b.resnets, b.attentions): - sub_blocks.append([r, a]) - - num_resnets = len(b.resnets) - num_attns = len(b.attentions) - - if num_resnets > num_attns: - # we can have more resnets than attentions, so add each resnet as separate subblock - for i in range(num_attns, num_resnets): - sub_blocks.append([b.resnets[i]]) - else: - for r in b.resnets: - sub_blocks.append([r]) - - # upsamplers are part of the same subblock - if hasattr(b, "upsamplers") and b.upsamplers is not None: - for u in b.upsamplers: - sub_blocks[-1].extend([u]) - - # downsamplers are own subblock - if hasattr(b, "downsamplers") and b.downsamplers is not None: - for d in b.downsamplers: - sub_blocks.append([d]) - - return list(map(SubBlock, sub_blocks)) - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3bf67dfc1cdc..2b456f4c3d08 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -128,12 +128,6 @@ "StableDiffusionXLControlNetPipeline", ] ) - _import_structure["controlnet_xs"].extend( - [ - "StableDiffusionControlNetXSPipeline", - "StableDiffusionXLControlNetXSPipeline", - ] - ) _import_structure["deepfloyd_if"] = [ "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -361,10 +355,6 @@ StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) - from .controlnet_xs import 
( - StableDiffusionControlNetXSPipeline, - StableDiffusionXLControlNetXSPipeline, - ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py deleted file mode 100644 index bf3ac5050506..000000000000 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ /dev/null @@ -1,946 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel -from ...models.lora import adjust_lora_scale_text_encoder -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSModel - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" - >>> negative_prompt = "low quality, bad quality, sketches" - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" - ... ) - - >>> # initialize the models and pipeline - >>> controlnet_conditioning_scale = 0.5 - >>> controlnet = ControlNetXSModel.from_pretrained( - ... "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 - ... ) - >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - >>> pipe.enable_model_cpu_offload() - - >>> # get canny image - >>> image = np.array(image) - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - >>> # generate image - >>> image = pipe( - ... 
prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image - ... ).images[0] - ``` -""" - - -class StableDiffusionControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin -): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]): - Provides additional conditioning to the `unet` during the denoising process. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details - about a model's potential harms. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. - """ - - model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" - _optional_components = ["safety_checker", "feature_extractor"] - _exclude_from_cpu_offload = ["safety_checker"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetXSModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
- ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( - vae - ) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
- deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( -
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - else: - assert False - - start, end = control_guidance_start, control_guidance_end - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
- """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - control_guidance_start: float = 0.0, - control_guidance_end: float = 1.0, - clip_skip: Optional[int] = None, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, - `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. 
- eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set - the corresponding scale as a list. - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. 
- """ - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - image, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - height, width = image.shape[-2:] - else: - assert False - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end - ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - base_model=self.unet, - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=True, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py deleted file mode 100644 index 58f0f544a5ac..000000000000 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ /dev/null @@ -1,1119 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - -from diffusers.utils.import_utils import is_invisible_watermark_available - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel -from ...models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) -from ...models.lora import adjust_lora_scale_text_encoder -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput - - -if is_invisible_watermark_available(): - from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSModel, AutoencoderKL - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" - >>> negative_prompt = "low quality, bad quality, sketches" - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" - ... ) - - >>> # initialize the models and pipeline - >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization - >>> controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16) - >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 - ... ) - >>> pipe.enable_model_cpu_offload() - - >>> # get canny image - >>> image = np.array(image) - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - - >>> # generate image - >>> image = pipe( - ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image - ... 
).images[0] - ``` -""" - - -class StableDiffusionXLControlNetXSPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin -): - r""" - Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): - Second frozen text-encoder - ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - tokenizer_2 ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetXSModel`]: - Provides additional conditioning to the `unet` during the denoising process. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): - Whether the negative prompt embeddings should always be set to 0. Also see the config of - `stabilityai/stable-diffusion-xl-base-1-0`. - add_watermarker (`bool`, *optional*): - Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to - watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no - watermarker is used. - """ - - # leave controlnet out on purpose because it iterates with unet - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - text_encoder_2: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - tokenizer_2: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetXSModel, - scheduler: KarrasDiffusionSchedulers, - force_zeros_for_empty_prompt: bool = True, - add_watermarker: Optional[bool] = None, - ): - super().__init__() - - vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( - vae - ) - if not vae_compatible: - raise ValueError( - f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. 
Consider building the ControlNetXS model with different `conditioning_block_sizes`." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() - - if add_watermarker: - self.watermark = StableDiffusionXLWatermarker() - else: - self.watermark = None - - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: procecss multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. 
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = 
negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - prompt_2, - image, - callback_steps, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
- ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - self.check_image(image, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetXSModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetXSModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - else: - assert False - - start, end = control_guidance_start, control_guidance_end - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
- ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
- """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - control_guidance_start: float = 0.0, - control_guidance_end: float = 1.0, - original_size: Tuple[int, int] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Tuple[int, int] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, - `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be - accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height - and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in - `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. Anything below 512 pixels won't work well for - [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) - and checkpoints that are not specifically fine-tuned on low resolutions. 
- width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. Anything below 512 pixels won't work well for - [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) - and checkpoints that are not specifically fine-tuned on low resolutions. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 5.0): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` - and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, pooled text embeddings are generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt - weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. 
- callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original `unet`. - control_guidance_start (`float`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. - `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - For most cases, `target_size` should be set to the desired height and width of the generated image. If - not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in - section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - To negatively condition the generation process based on a specific image resolution. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's - micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - To negatively condition the generation process based on a target image resolution. It should be as same - as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more - information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is - returned, otherwise a `tuple` is returned containing the output images. - """ - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - image, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt, - prompt_2, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - - # 4. Prepare image - if isinstance(controlnet, ControlNetXSModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - height, width = image.shape[-2:] - else: - assert False - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7.1 Prepare added time ids & embeddings - if isinstance(image, list): - original_size = original_size or image[0].shape[-2:] - else: - original_size = original_size or image.shape[-2:] - target_size = target_size or (height, width) - - add_text_embeds = pooled_prompt_embeds - if self.text_encoder_2 is None: - text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) - else: - text_encoder_projection_dim = self.text_encoder_2.config.projection_dim - - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - - # predict the noise residual - dont_control = ( - i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end - ) - if dont_control: - noise_pred = self.unet( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - else: - noise_pred = self.controlnet( - base_model=self.unet, - sample=latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=True, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, 
t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - # manually for max memory savings - if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - - if not output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - - if needs_upcasting: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - else: - image = latents - - if not output_type == "latent": - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return StableDiffusionXLPipelineOutput(images=image) From 9b81a7fe53b26c9365b64ea7fb77219df9aa2aab Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 10:28:50 +0000 Subject: [PATCH 04/11] update --- .../controlnetxs/infer_sdxl_controlnetxs.py | 2 +- .../pipelines/controlnet_xs/__init__.py | 68 ------------------- 2 files changed, 1 insertion(+), 69 deletions(-) delete mode 100644 src/diffusers/pipelines/controlnet_xs/__init__.py diff --git a/examples/community/controlnetxs/infer_sdxl_controlnetxs.py b/examples/community/controlnetxs/infer_sdxl_controlnetxs.py index b9ace7959d7a..531c8ec99f05 100644 --- a/examples/community/controlnetxs/infer_sdxl_controlnetxs.py +++ b/examples/community/controlnetxs/infer_sdxl_controlnetxs.py @@ -50,4 +50,4 @@ image=canny_image, num_inference_steps=num_inference_steps ).images[0] -image.save("cnxs_sd.canny.png") +image.save("cnxs_sdxl.canny.png") diff --git a/src/diffusers/pipelines/controlnet_xs/__init__.py b/src/diffusers/pipelines/controlnet_xs/__init__.py deleted file mode 100644 index 978278b184f9..000000000000 --- a/src/diffusers/pipelines/controlnet_xs/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] - _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] -try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import 
dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] - - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * - else: - from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline - from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline - - try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline - - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) From ebed1a13e964cbd5b0fe74d623bcde47dd710270 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 10:37:40 +0000 Subject: [PATCH 05/11] update --- tests/pipelines/controlnetxs/__init__.py | 0 .../controlnetxs/test_controlnetxs.py | 313 --------------- .../controlnetxs/test_controlnetxs_sdxl.py | 364 ------------------ 3 files changed, 677 deletions(-) delete mode 100644 tests/pipelines/controlnetxs/__init__.py delete mode 100644 tests/pipelines/controlnetxs/test_controlnetxs.py delete mode 100644 tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py diff --git a/tests/pipelines/controlnetxs/__init__.py b/tests/pipelines/controlnetxs/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/pipelines/controlnetxs/test_controlnetxs.py b/tests/pipelines/controlnetxs/test_controlnetxs.py deleted file mode 100644 index 1e9e523b71db..000000000000 --- a/tests/pipelines/controlnetxs/test_controlnetxs.py +++ /dev/null @@ -1,313 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
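Note: the fast tests in the file that follows build the control model directly from the base UNet via ControlNetXSModel.from_unet rather than loading a trained checkpoint. A standalone sketch of that pattern is shown here, swapping the tiny test UNet for a full SD 2.1 UNet; the checkpoint id and the local import of controlnetxs.py (the file added under examples/community/controlnetxs) are assumptions for illustration, not something this patch sets up.

    import torch
    from diffusers import UNet2DConditionModel
    from controlnetxs import ControlNetXSModel  # assumed local copy of examples/community/controlnetxs/controlnetxs.py

    # Base UNet that the control branch will be derived from (checkpoint id assumed for illustration).
    unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")

    # The from_unet arguments follow the fast tests below: size_ratio shrinks the control
    # branch relative to the base UNet, while learn_embedding / time_embedding_mix decide
    # whether the control branch keeps its own time embedding and how it is mixed with the base one.
    controlnet = ControlNetXSModel.from_unet(
        unet,
        time_embedding_mix=0.95,
        learn_embedding=True,
        size_ratio=0.5,
        conditioning_embedding_out_channels=(16, 32),
    )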
- -import gc -import traceback -import unittest - -import numpy as np -import torch -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer - -from diffusers import ( - AutoencoderKL, - ControlNetXSModel, - DDIMScheduler, - LCMScheduler, - StableDiffusionControlNetXSPipeline, - UNet2DConditionModel, -) -from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import ( - enable_full_determinism, - load_image, - load_numpy, - numpy_cosine_similarity_distance, - require_python39_or_higher, - require_torch_2, - require_torch_gpu, - run_test_in_subprocess, - slow, - torch_device, -) -from diffusers.utils.torch_utils import randn_tensor - -from ..pipeline_params import ( - IMAGE_TO_IMAGE_IMAGE_PARAMS, - TEXT_TO_IMAGE_BATCH_PARAMS, - TEXT_TO_IMAGE_IMAGE_PARAMS, - TEXT_TO_IMAGE_PARAMS, -) -from ..test_pipelines_common import ( - PipelineKarrasSchedulerTesterMixin, - PipelineLatentTesterMixin, - PipelineTesterMixin, -) - - -enable_full_determinism() - - -# Will be run via run_test_in_subprocess -def _test_stable_diffusion_compile(in_queue, out_queue, timeout): - error = None - try: - _ = in_queue.get(timeout=timeout) - - controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny") - - pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet=controlnet - ) - pipe.to("cuda") - pipe.set_progress_bar_config(disable=None) - - pipe.unet.to(memory_format=torch.channels_last) - pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - - pipe.controlnet.to(memory_format=torch.channels_last) - pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "bird" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ).resize((512, 512)) - - output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np") - image = output.images[0] - - assert image.shape == (512, 512, 3) - - expected_image = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy" - ) - expected_image = np.resize(expected_image, (512, 512, 3)) - - assert np.abs(expected_image - image).max() < 1.0 - - except Exception: - error = f"{traceback.format_exc()}" - - results = {"error": error} - out_queue.put(results, timeout=timeout) - out_queue.join() - - -@unittest.skip("Move to Community Pipelines") -class ControlNetXSPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase -): - pipeline_class = StableDiffusionControlNetXSPipeline - params = TEXT_TO_IMAGE_PARAMS - batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS - image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS - - def get_dummy_components(self, time_cond_proj_dim=None): - torch.manual_seed(0) - unet = UNet2DConditionModel( - block_out_channels=(4, 8), - layers_per_block=2, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, - norm_num_groups=1, - time_cond_proj_dim=time_cond_proj_dim, - ) - torch.manual_seed(0) - controlnet = ControlNetXSModel.from_unet( - unet=unet, - time_embedding_mix=0.95, - 
learn_embedding=True, - size_ratio=0.5, - conditioning_embedding_out_channels=(16, 32), - num_attention_heads=2, - ) - torch.manual_seed(0) - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - ) - torch.manual_seed(0) - vae = AutoencoderKL( - block_out_channels=[4, 8], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=4, - norm_num_groups=2, - ) - torch.manual_seed(0) - text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - ) - text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - components = { - "unet": unet, - "controlnet": controlnet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, - } - return components - - def get_dummy_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device=device).manual_seed(seed) - - controlnet_embedder_scale_factor = 2 - image = randn_tensor( - (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), - generator=generator, - device=torch.device(device), - ) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 6.0, - "output_type": "numpy", - "image": image, - } - - return inputs - - def test_attention_slicing_forward_pass(self): - return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_attention_forwardGenerator_pass(self): - self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) - - def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(expected_max_diff=2e-3) - - def test_controlnet_lcm(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - - components = self.get_dummy_components(time_cond_proj_dim=256) - sd_pipe = StableDiffusionControlNetXSPipeline(**components) - sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs) - image = output.images - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array( - [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786] - ) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - -@unittest.skip("Move to Community Pipelines") -@slow -@require_torch_gpu -class ControlNetXSPipelineSlowTests(unittest.TestCase): - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def test_canny(self): - controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny") - - pipe = 
StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet=controlnet - ) - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "bird" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) - - output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) - - image = output.images[0] - - assert image.shape == (768, 512, 3) - - original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701]) - - max_diff = numpy_cosine_similarity_distance(original_image, expected_image) - assert max_diff < 1e-4 - - def test_depth(self): - controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-depth") - - pipe = StableDiffusionControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet=controlnet - ) - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "Stormtrooper's lecture" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" - ) - - output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) - - image = output.images[0] - - assert image.shape == (512, 512, 3) - - original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703]) - - max_diff = numpy_cosine_similarity_distance(original_image, expected_image) - assert max_diff < 1e-4 - - @require_python39_or_higher - @require_torch_2 - def test_stable_diffusion_compile(self): - run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None) diff --git a/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py deleted file mode 100644 index 06e957ce4edc..000000000000 --- a/tests/pipelines/controlnetxs/test_controlnetxs_sdxl.py +++ /dev/null @@ -1,364 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
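For reference, a minimal end-to-end sketch mirroring the slow canny test above, run through the community files instead of the removed diffusers module. It assumes controlnetxs.py and pipeline_controlnet_xs.py from examples/community/controlnetxs have been copied next to the script; those local import paths are an assumption, not part of the patch.

    import torch
    from diffusers.utils import load_image
    from controlnetxs import ControlNetXSModel  # assumed local copy of examples/community/controlnetxs/controlnetxs.py
    from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline  # assumed local copy

    # Checkpoint ids taken from the slow tests: a canny ControlNet-XS on top of SD 2.1.
    controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()

    # Pre-computed canny conditioning image used by the tests.
    canny_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
    )
    image = pipe(
        "bird", image=canny_image, controlnet_conditioning_scale=0.7, num_inference_steps=50
    ).images[0]
    image.save("cnxs_sd_canny.png")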
- -import gc -import unittest - -import numpy as np -import torch -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - -from diffusers import ( - AutoencoderKL, - ControlNetXSModel, - EulerDiscreteScheduler, - StableDiffusionXLControlNetXSPipeline, - UNet2DConditionModel, -) -from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device -from diffusers.utils.torch_utils import randn_tensor - -from ..pipeline_params import ( - IMAGE_TO_IMAGE_IMAGE_PARAMS, - TEXT_TO_IMAGE_BATCH_PARAMS, - TEXT_TO_IMAGE_IMAGE_PARAMS, - TEXT_TO_IMAGE_PARAMS, -) -from ..test_pipelines_common import ( - PipelineKarrasSchedulerTesterMixin, - PipelineLatentTesterMixin, - PipelineTesterMixin, - SDXLOptionalComponentsTesterMixin, -) - - -enable_full_determinism() - - -@unittest.skip("Move to Community Pipelines") -class StableDiffusionXLControlNetXSPipelineFastTests( - PipelineLatentTesterMixin, - PipelineKarrasSchedulerTesterMixin, - PipelineTesterMixin, - SDXLOptionalComponentsTesterMixin, - unittest.TestCase, -): - pipeline_class = StableDiffusionXLControlNetXSPipeline - params = TEXT_TO_IMAGE_PARAMS - batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS - image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS - - def get_dummy_components(self): - torch.manual_seed(0) - unet = UNet2DConditionModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - # SD2-specific config below - attention_head_dim=(2, 4), - use_linear_projection=True, - addition_embed_type="text_time", - addition_time_embed_dim=8, - transformer_layers_per_block=(1, 2), - projection_class_embeddings_input_dim=80, # 6 * 8 + 32 - cross_attention_dim=64, - ) - torch.manual_seed(0) - controlnet = ControlNetXSModel.from_unet( - unet, - time_embedding_mix=0.95, - learn_embedding=True, - size_ratio=0.5, - conditioning_embedding_out_channels=(16, 32), - ) - torch.manual_seed(0) - scheduler = EulerDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - steps_offset=1, - beta_schedule="scaled_linear", - timestep_spacing="leading", - ) - torch.manual_seed(0) - vae = AutoencoderKL( - block_out_channels=[32, 64], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=4, - ) - torch.manual_seed(0) - text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - # SD2-specific config below - hidden_act="gelu", - projection_dim=32, - ) - text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - components = { - "unet": unet, - "controlnet": controlnet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - } - return components - - # copied from test_controlnet_sdxl.py - def 
get_dummy_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device=device).manual_seed(seed) - - controlnet_embedder_scale_factor = 2 - image = randn_tensor( - (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), - generator=generator, - device=torch.device(device), - ) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 6.0, - "output_type": "np", - "image": image, - } - - return inputs - - # copied from test_controlnet_sdxl.py - def test_attention_slicing_forward_pass(self): - return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) - - # copied from test_controlnet_sdxl.py - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_attention_forwardGenerator_pass(self): - self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) - - # copied from test_controlnet_sdxl.py - def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(expected_max_diff=2e-3) - - # copied from test_controlnet_sdxl.py - def test_save_load_optional_components(self): - self._test_save_load_optional_components() - - # copied from test_controlnet_sdxl.py - @require_torch_gpu - def test_stable_diffusion_xl_offloads(self): - pipes = [] - components = self.get_dummy_components() - sd_pipe = self.pipeline_class(**components).to(torch_device) - pipes.append(sd_pipe) - - components = self.get_dummy_components() - sd_pipe = self.pipeline_class(**components) - sd_pipe.enable_model_cpu_offload() - pipes.append(sd_pipe) - - components = self.get_dummy_components() - sd_pipe = self.pipeline_class(**components) - sd_pipe.enable_sequential_cpu_offload() - pipes.append(sd_pipe) - - image_slices = [] - for pipe in pipes: - pipe.unet.set_default_attn_processor() - - inputs = self.get_dummy_inputs(torch_device) - image = pipe(**inputs).images - - image_slices.append(image[0, -3:, -3:, -1].flatten()) - - assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 - assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 - - # copied from test_controlnet_sdxl.py - def test_stable_diffusion_xl_multi_prompts(self): - components = self.get_dummy_components() - sd_pipe = self.pipeline_class(**components).to(torch_device) - - # forward with single prompt - inputs = self.get_dummy_inputs(torch_device) - output = sd_pipe(**inputs) - image_slice_1 = output.images[0, -3:, -3:, -1] - - # forward with same prompt duplicated - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = inputs["prompt"] - output = sd_pipe(**inputs) - image_slice_2 = output.images[0, -3:, -3:, -1] - - # ensure the results are equal - assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 - - # forward with different prompt - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = "different prompt" - output = sd_pipe(**inputs) - image_slice_3 = output.images[0, -3:, -3:, -1] - - # ensure the results are not equal - assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 - - # manually set a negative_prompt - inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt"] = "negative prompt" - output = sd_pipe(**inputs) - image_slice_1 = output.images[0, -3:, -3:, -1] - - # 
forward with same negative_prompt duplicated - inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt"] = "negative prompt" - inputs["negative_prompt_2"] = inputs["negative_prompt"] - output = sd_pipe(**inputs) - image_slice_2 = output.images[0, -3:, -3:, -1] - - # ensure the results are equal - assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 - - # forward with different negative_prompt - inputs = self.get_dummy_inputs(torch_device) - inputs["negative_prompt"] = "negative prompt" - inputs["negative_prompt_2"] = "different negative prompt" - output = sd_pipe(**inputs) - image_slice_3 = output.images[0, -3:, -3:, -1] - - # ensure the results are not equal - assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 - - # copied from test_stable_diffusion_xl.py - def test_stable_diffusion_xl_prompt_embeds(self): - components = self.get_dummy_components() - sd_pipe = self.pipeline_class(**components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - # forward without prompt embeds - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt"] = 2 * [inputs["prompt"]] - inputs["num_images_per_prompt"] = 2 - - output = sd_pipe(**inputs) - image_slice_1 = output.images[0, -3:, -3:, -1] - - # forward with prompt embeds - inputs = self.get_dummy_inputs(torch_device) - prompt = 2 * [inputs.pop("prompt")] - - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = sd_pipe.encode_prompt(prompt) - - output = sd_pipe( - **inputs, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - ) - image_slice_2 = output.images[0, -3:, -3:, -1] - - # make sure that it's equal - assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4 - - -@unittest.skip("Move to Community Pipelines") -@slow -@require_torch_gpu -class ControlNetSDXLPipelineXSSlowTests(unittest.TestCase): - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def test_canny(self): - controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny") - - pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet - ) - pipe.enable_sequential_cpu_offload() - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "bird" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) - - images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images - - assert images[0].shape == (768, 512, 3) - - original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4359, 0.4335, 0.4609, 0.4515, 0.4669, 0.4494, 0.452, 0.4493, 0.4382]) - assert np.allclose(original_image, expected_image, atol=1e-04) - - def test_depth(self): - controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-depth") - - pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet - ) - pipe.enable_sequential_cpu_offload() - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = 
"Stormtrooper's lecture" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" - ) - - images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images - - assert images[0].shape == (512, 512, 3) - - original_image = images[0, -3:, -3:, -1].flatten() - expected_image = np.array([0.4411, 0.3617, 0.2654, 0.266, 0.3449, 0.3898, 0.3745, 0.353, 0.326]) - assert np.allclose(original_image, expected_image, atol=1e-04) From 4c6f1be00d9e13eb0568702839b58f46ed3e210f Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 13:30:35 +0000 Subject: [PATCH 06/11] make style --- .../controlnetxs/infer_sd_controlnetxs.py | 16 ++++++++++------ .../controlnetxs/infer_sdxl_controlnetxs.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/community/controlnetxs/infer_sd_controlnetxs.py b/examples/community/controlnetxs/infer_sd_controlnetxs.py index f456e74db8e1..722b282a3251 100644 --- a/examples/community/controlnetxs/infer_sd_controlnetxs.py +++ b/examples/community/controlnetxs/infer_sd_controlnetxs.py @@ -12,10 +12,16 @@ parser = argparse.ArgumentParser() -parser.add_argument("--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting") +parser.add_argument( + "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" +) parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches") parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) -parser.add_argument("--image_path", type=str, default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") +parser.add_argument( + "--image_path", + type=str, + default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png", +) parser.add_argument("--num_inference_steps", type=int, default=50) args = parser.parse_args() @@ -27,9 +33,7 @@ # initialize the models and pipeline controlnet_conditioning_scale = args.controlnet_conditioning_scale -controlnet = ControlNetXSModel.from_pretrained( - "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 -) +controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 ) @@ -49,6 +53,6 @@ prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, - num_inference_steps=num_inference_steps + num_inference_steps=num_inference_steps, ).images[0] image.save("cnxs_sd.canny.png") diff --git a/examples/community/controlnetxs/infer_sdxl_controlnetxs.py b/examples/community/controlnetxs/infer_sdxl_controlnetxs.py index 531c8ec99f05..e5b8cfd88223 100644 --- a/examples/community/controlnetxs/infer_sdxl_controlnetxs.py +++ b/examples/community/controlnetxs/infer_sdxl_controlnetxs.py @@ -12,10 +12,16 @@ parser = argparse.ArgumentParser() -parser.add_argument("--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting") +parser.add_argument( + "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" +) parser.add_argument("--negative_prompt", type=str, default="low quality, bad 
quality, sketches") parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7) -parser.add_argument("--image_path", type=str, default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") +parser.add_argument( + "--image_path", + type=str, + default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png", +) parser.add_argument("--num_inference_steps", type=int, default=50) args = parser.parse_args() @@ -26,9 +32,7 @@ image = load_image(args.image_path) # initialize the models and pipeline controlnet_conditioning_scale = args.controlnet_conditioning_scale -controlnet = ControlNetXSModel.from_pretrained( - "UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 -) +controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16) pipe = StableDiffusionControlNetXSPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 ) @@ -48,6 +52,6 @@ prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, - num_inference_steps=num_inference_steps + num_inference_steps=num_inference_steps, ).images[0] image.save("cnxs_sdxl.canny.png") From f3fbcac3ad217404e5a79e6eb244a522747eb223 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 14:00:22 +0000 Subject: [PATCH 07/11] remove docs --- docs/source/en/api/pipelines/controlnetxs.md | 39 ---------------- .../en/api/pipelines/controlnetxs_sdxl.md | 45 ------------------- 2 files changed, 84 deletions(-) delete mode 100644 docs/source/en/api/pipelines/controlnetxs.md delete mode 100644 docs/source/en/api/pipelines/controlnetxs_sdxl.md diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md deleted file mode 100644 index 2d4ae7b8ce46..000000000000 --- a/docs/source/en/api/pipelines/controlnetxs.md +++ /dev/null @@ -1,39 +0,0 @@ - - -# ControlNet-XS - -ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. - -Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. - -ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory. - -Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): - -*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. 
We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* - -This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ - - - -Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. - - - -## StableDiffusionControlNetXSPipeline -[[autodoc]] StableDiffusionControlNetXSPipeline - - all - - __call__ - -## StableDiffusionPipelineOutput -[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md deleted file mode 100644 index 31075c0ef96a..000000000000 --- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md +++ /dev/null @@ -1,45 +0,0 @@ - - -# ControlNet-XS with Stable Diffusion XL - -ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. - -Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. - -ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory. - -Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): - -*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* - -This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ - - - -🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve! 
- - - - - -Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. - - - -## StableDiffusionXLControlNetXSPipeline -[[autodoc]] StableDiffusionXLControlNetXSPipeline - - all - - __call__ - -## StableDiffusionPipelineOutput -[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput From 12e0ace62115f22cd80cc46848aef4430a2155fc Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Dec 2023 17:19:35 +0000 Subject: [PATCH 08/11] update --- src/diffusers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d6778961dcba..180b210953c1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -459,7 +459,6 @@ AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, - ControlNetXSModel, Kandinsky3UNet, ModelMixin, MotionAdapter, From db57b0d517ac6127ea2add1fa34254e0db0afb60 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 27 Dec 2023 07:46:56 +0530 Subject: [PATCH 09/11] move to research folder. --- .../research_projects/controlnetxs/README.md | 16 ++++++++++++++++ .../controlnetxs/README_sdxl.md | 15 +++++++++++++++ .../controlnetxs/controlnetxs.py | 0 .../controlnetxs/infer_sd_controlnetxs.py | 0 .../controlnetxs/infer_sdxl_controlnetxs.py | 0 .../controlnetxs/pipeline_controlnet_xs.py | 0 .../controlnetxs/pipeline_controlnet_xs_sd_xl.py | 0 7 files changed, 31 insertions(+) create mode 100644 examples/research_projects/controlnetxs/README.md create mode 100644 examples/research_projects/controlnetxs/README_sdxl.md rename examples/{community => research_projects}/controlnetxs/controlnetxs.py (100%) rename examples/{community => research_projects}/controlnetxs/infer_sd_controlnetxs.py (100%) rename examples/{community => research_projects}/controlnetxs/infer_sdxl_controlnetxs.py (100%) rename examples/{community => research_projects}/controlnetxs/pipeline_controlnet_xs.py (100%) rename examples/{community => research_projects}/controlnetxs/pipeline_controlnet_xs_sd_xl.py (100%) diff --git a/examples/research_projects/controlnetxs/README.md b/examples/research_projects/controlnetxs/README.md new file mode 100644 index 000000000000..72ed91c01db2 --- /dev/null +++ b/examples/research_projects/controlnetxs/README.md @@ -0,0 +1,16 @@ +# ControlNet-XS + +ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. + +Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. + +ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory. 
+ +Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): + +*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.* + +This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ + + +> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. \ No newline at end of file diff --git a/examples/research_projects/controlnetxs/README_sdxl.md b/examples/research_projects/controlnetxs/README_sdxl.md new file mode 100644 index 000000000000..d401c1e76698 --- /dev/null +++ b/examples/research_projects/controlnetxs/README_sdxl.md @@ -0,0 +1,15 @@ +# ControlNet-XS with Stable Diffusion XL + +ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results. + +Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process. + +ControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory. + +Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/): + +*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 
2022] (Model B, 14M Parameters), all under openrail license.* + +This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️ + +> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. \ No newline at end of file diff --git a/examples/community/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py similarity index 100% rename from examples/community/controlnetxs/controlnetxs.py rename to examples/research_projects/controlnetxs/controlnetxs.py diff --git a/examples/community/controlnetxs/infer_sd_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py similarity index 100% rename from examples/community/controlnetxs/infer_sd_controlnetxs.py rename to examples/research_projects/controlnetxs/infer_sd_controlnetxs.py diff --git a/examples/community/controlnetxs/infer_sdxl_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py similarity index 100% rename from examples/community/controlnetxs/infer_sdxl_controlnetxs.py rename to examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py diff --git a/examples/community/controlnetxs/pipeline_controlnet_xs.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py similarity index 100% rename from examples/community/controlnetxs/pipeline_controlnet_xs.py rename to examples/research_projects/controlnetxs/pipeline_controlnet_xs.py diff --git a/examples/community/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py similarity index 100% rename from examples/community/controlnetxs/pipeline_controlnet_xs_sd_xl.py rename to examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py From 16734f71d92f1954d3c250c73ba8a74ec66c7a02 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 27 Dec 2023 07:48:15 +0530 Subject: [PATCH 10/11] fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 ---------- .../dummy_torch_and_transformers_objects.py | 30 ------------------- 2 files changed, 45 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5bd2f493ce08..d306a3575b1f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,21 +92,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ControlNetXSModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ae6c6c916065..2eb9599658d9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -782,21 +782,6 @@ def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionControlNetXSPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1142,21 +1127,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 557f8d90f390c9c7e52218d8a04244b42c4985e6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 27 Dec 2023 07:53:20 +0530 Subject: [PATCH 11/11] remove _toctree entry. --- docs/source/en/_toctree.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 29e085fbeb7c..0c05f0ef7ffa 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -266,10 +266,6 @@ title: ControlNet - local: api/pipelines/controlnet_sdxl title: ControlNet with Stable Diffusion XL - - local: api/pipelines/controlnetxs - title: ControlNet-XS - - local: api/pipelines/controlnetxs_sdxl - title: ControlNet-XS with Stable Diffusion XL - local: api/pipelines/dance_diffusion title: Dance Diffusion - local: api/pipelines/ddim
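
After this series, ControlNet-XS lives only under `examples/research_projects/controlnetxs/`, so the classes are presumably imported from the local modules rather than from `diffusers`. Below is a minimal usage sketch, reusing the checkpoint id, prompt, and defaults from `infer_sd_controlnetxs.py` above; it assumes the script is run from that directory so the local files are importable, and the Canny preprocessing step is an assumption that mirrors the usual ControlNet canny recipe rather than code taken from this series.

```python
import cv2
import numpy as np
import torch
from PIL import Image

from diffusers.utils import load_image

# Local modules from examples/research_projects/controlnetxs/ (assumed importable from that directory).
from controlnetxs import ControlNetXSModel
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline

# Control image: the HF logo used by the inference script's default --image_path.
image = load_image(
    "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)

# Build a Canny edge map as the conditioning image (thresholds here are arbitrary assumptions).
edges = cv2.Canny(np.array(image), 100, 200)
canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Same call shape and defaults as infer_sd_controlnetxs.py.
result = pipe(
    "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
    controlnet_conditioning_scale=0.7,
    image=canny_image,
    num_inference_steps=50,
).images[0]
result.save("cnxs_sd_canny.png")
```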